lc凸包二分算法题

原题挺烂的, 卡常. 不过看到一种凸包二分的算法挺有意思的, 记录一下.

3494. 酿造药水需要的最少总时间

给你两个长度分别为 n 和 m 的整数数组 skillmana 。

在一个实验室里,有 n 个巫师,他们必须按顺序酿造 m 个药水。每个药水的法力值为 mana[j],并且每个药水 必须 依次通过 所有 巫师处理,才能完成酿造。第 i 个巫师在第 j 个药水上处理需要的时间为 timeij = skill[i] * mana[j]

由于酿造过程非常精细,药水在当前巫师完成工作后 必须 立即传递给下一个巫师并开始处理。这意味着时间必须保持 同步,确保每个巫师在药水到达时 马上 开始工作。

返回酿造所有药水所需的 最短 总时间。

示例 1:

输入: skill = [1,5,2,4], mana = [5,1,4,2]

输出: 110

举个例子,为什么巫师 0 不能在时间 t = 52 前开始处理第 1 个药水,假设巫师们在时间 t = 50 开始准备第 1 个药水。时间 t = 58 时,巫师 2 已经完成了第 1 个药水的处理,但巫师 3 直到时间 t = 60 仍在处理第 0 个药水,无法马上开始处理第 1个药水。

示例 2:

输入: skill = [1,1,1], mana = [1,1,1]

输出: 5

解释:

  1. 第 0 个药水的准备从时间 t = 0 开始,并在时间 t = 3 完成。
  2. 第 1 个药水的准备从时间 t = 1 开始,并在时间 t = 4 完成。
  3. 第 2 个药水的准备从时间 t = 2 开始,并在时间 t = 5 完成。

示例 3:

输入: skill = [1,2,3,4], mana = [1,2]

输出: 21

提示:

  • n == skill.length
  • m == mana.length
  • 1 <= n, m <= 5000
  • 1 <= mana[i], skill[i] <= 5000

https://leetcode.cn/problems/find-the-minimum-amount-of-time-to-brew-potions/solutions/

基础解法

这题稍微观察下就可以发现, 每一行都和上一行有关, 所以一个O(mn)就可以解决. 我写的第一版

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Solution:
def minTime(self, skill: List[int], mana: List[int]) -> int:
prev = [0] * len(skill)
def max(x, y):
return x if x >= y else y
for idx, m in enumerate(mana):
f = [0] * len(skill)
mx = prev[0]
for i, s in enumerate(skill):
f[i] = f[i - 1] + s * m
if i < len(skill) - 1:
mx = max(mx, prev[i + 1] - f[i])
for i in range(len(skill)):
f[i] = f[i] + mx
prev = f
return f[-1]

国服被卡常了, 美服AC了.

后续看灵神的解答发现还有几种有意思的方案.

递推解法

药水每次必须连续制作, 所以知道开始时间我们就可以算出结束时间.

第i个巫师完成第j瓶药水的公式如下:

如果我们可以从 $ start_{i-1} $ 推出 $ start_{i-1} $ 就好了

从上一次到下一次, 我们其实总是卡在了最慢的那个巫师. 所以对于最慢的巫师我们有:

此时把lastfinish全部代入即有:

由此我们有代码如下:

1
2
3
4
5
6
7
8
class Solution:
def minTime(self, skill: List[int], mana: List[int]) -> int:
n = len(skill)
s = list(accumulate(skill, initial=0)) # skill 的前缀和
start = 0
for pre, cur in pairwise(mana):
start += max(pre * s[i + 1] - cur * s[i] for i in range(n))
return start + mana[-1] * s[-1]

逆序优化

上面的公式还能拆开

我们可以以d = mana[j-1] - mana[j] 做分类讨论

d>0时, 带i的项s[i]递增, d也递增; 如果skill也递增, 那么最大值只能是递增skill的最后一项. 对此我们可以倒序遍历, 一旦倒序找到更大的数, 我们就记录下来, 然后我们遍历这个列表就行了.

d<0时, 类似的, 我们顺序遍历, 找到更大的数就记录下来.

最后依据d来做判断即可.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Solution:
def minTime(self, skill: List[int], mana: List[int]) -> int:
n = len(skill)
s = list(accumulate(skill, initial=0))

suf_record = [n - 1]
for i in range(n - 2, -1, -1):
if skill[i] > skill[suf_record[-1]]:
suf_record.append(i)

pre_record = [0]
for i in range(1, n):
if skill[i] > skill[pre_record[-1]]:
pre_record.append(i)

start = 0
for pre, cur in pairwise(mana):
record = pre_record if pre < cur else suf_record
start += max(pre * s[i + 1] - cur * s[i] for i in record)
return start + mana[-1] * s[-1]

凸包+二分

前面我们已经提取出了公式:

(mana[j−1]−mana[j])⋅s[i]+mana[j−1]⋅skill[i]

我们完全可以改写成点积的形式

v[i] = (s[i], skill[i]) p = (mana[j−1]−mana[j],mana[j−1])

我们要求的即为 $ max_{0}^{n-1}(p \cdot v[i]) $

这里p是一个定值, 而他们的点积则取决于v[i]在p上的投影.

考虑v构成的凸包, 凸包内的点比凸包顶点的投影短, 所以我们把问题转化成了求v的凸包顶点.

此时我们发现, 如果我们顺序或者逆序遍历凸包上的点, 我们的点积构成了一个单峰函数, 由此我们可以二分来解决.

凸包这里使用andrew算法即可, 其最大复杂度在于排序, 但是我们这里前缀和一定是递增的, 所以相当于已经排序完了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Vec:
__slots__ = 'x', 'y'

def __init__(self, x: int, y: int):
self.x = x
self.y = y

def __sub__(self, b: "Vec") -> "Vec":
return Vec(self.x - b.x, self.y - b.y)

def det(self, b: "Vec") -> int:
return self.x * b.y - self.y * b.x

def dot(self, b: "Vec") -> int:
return self.x * b.x + self.y * b.y

class Solution:
# Andrew 算法,计算 points 的上凸包
# 由于横坐标(前缀和)是严格递增的,所以无需排序
def convex_hull(self, points: List[Vec]) -> List[Vec]:
q = []
for p in points:
while len(q) > 1 and (q[-1] - q[-2]).det(p - q[-1]) >= 0:
q.pop()
q.append(p)
return q

def minTime(self, skill: List[int], mana: List[int]) -> int:
s = list(accumulate(skill, initial=0))
vs = [Vec(pre_sum, x) for pre_sum, x in zip(s, skill)]
vs = self.convex_hull(vs) # 去掉无用数据

start = 0
for pre, cur in pairwise(mana):
p = Vec(pre - cur, pre)
# p.dot(vs[i]) 是个单峰函数,二分找最大值
check = lambda i: p.dot(vs[i]) > p.dot(vs[i + 1])
i = bisect_left(range(len(vs) - 1), True, key=check)
start += p.dot(vs[i])
return start + mana[-1] * s[-1]

最后算法复杂度O(n+mlogn)

以前没见过这种做法, 记录一下.