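Best Position for a Service Centre (Gradient Descent)
Problem Analysis
Stated more abstractly: given a point set S, find a point A (anywhere in the plane, not necessarily in S) that minimizes distance_sum, the sum of Euclidean distances from A to every point in S, and return that minimum, min_distance_sum.
Solution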
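This is essentially an optimization problem. For the n given points $(x_i, y_i)$, the objective function is
$$f(x, y) = \sum_{i=0}^{n-1} \left[(x - x_i)^2 + (y - y_i)^2\right]^{1/2}$$
Each term is a Euclidean distance and therefore convex, so f is convex and any minimum that gradient descent converges to is the global minimum.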
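As a minimal sketch, $f$ translates directly into Python. The name `distance_sum` and the `[[x, y], ...]` point format are illustrative choices, not part of the original solution; `math.hypot` computes the $[\cdot]^{1/2}$ term.

```python
import math
from typing import Sequence


def distance_sum(x: float, y: float, points: Sequence[Sequence[float]]) -> float:
    """Objective f(x, y): sum of Euclidean distances from (x, y) to every point."""
    # math.hypot(dx, dy) == sqrt(dx**2 + dy**2), one term of the sum
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


if __name__ == "__main__":
    square = [[0, 1], [1, 0], [1, 2], [2, 1]]
    print(distance_sum(1, 1, square))  # 4.0: every point is at distance 1
    print(distance_sum(0, 0, square))  # about 6.472, so (0, 0) is worse
```

On the first sample, $(1, 1)$ is at distance 1 from all four points and gives a sum of 4.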
Taking the partial derivatives with respect to x and y (valid at any point that does not coincide with some $(x_i, y_i)$):
$$\frac{\partial f}{\partial x} = \sum_{i=0}^{n-1} \left[(x - x_i)^2 + (y - y_i)^2\right]^{-1/2}(x - x_i)$$
$$\frac{\partial f}{\partial y} = \sum_{i=0}^{n-1} \left[(x - x_i)^2 + (y - y_i)^2\right]^{-1/2}(y - y_i)$$
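The two partial derivatives can be sketched in Python as follows. The helper names `gradient` and `numeric_gradient` are hypothetical. Zero-distance terms are skipped, as the original `getDelta` does, because $f$ is not differentiable at a data point. The central finite-difference estimate is included only as a sanity check on the formulas.

```python
import math
from typing import Sequence, Tuple

Points = Sequence[Sequence[float]]


def distance_sum(x: float, y: float, points: Points) -> float:
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


def gradient(x: float, y: float, points: Points) -> Tuple[float, float]:
    """Analytic (df/dx, df/dy) of the distance-sum objective."""
    gx = gy = 0.0
    for xi, yi in points:
        d = math.hypot(x - xi, y - yi)
        if d == 0:
            # f is not differentiable at a data point; skip that term
            continue
        gx += (x - xi) / d
        gy += (y - yi) / d
    return gx, gy


def numeric_gradient(x: float, y: float, points: Points, h: float = 1e-6) -> Tuple[float, float]:
    """Central finite differences, used only to cross-check gradient()."""
    gx = (distance_sum(x + h, y, points) - distance_sum(x - h, y, points)) / (2 * h)
    gy = (distance_sum(x, y + h, points) - distance_sum(x, y - h, points)) / (2 * h)
    return gx, gy


if __name__ == "__main__":
    pts = [[0, 1], [3, 2], [4, 5], [7, 6]]
    print(gradient(1.5, 2.5, pts))
    print(numeric_gradient(1.5, 2.5, pts))  # the two should agree closely
```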
Gradient descent then updates x and y:
$$x_{\text{new}} = x - \alpha \cdot dx, \quad y_{\text{new}} = y - \alpha \cdot dy$$
where $dx = \partial f / \partial x$, $dy = \partial f / \partial y$, and $\alpha$ (alpha) is the learning rate.
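A single update step, sketched with hypothetical helper names (`gd_step` is not from the original code); the point list and the two alpha values are arbitrary choices for illustration.

```python
import math


def distance_sum(x, y, points):
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


def gradient(x, y, points):
    gx = gy = 0.0
    for xi, yi in points:
        d = math.hypot(x - xi, y - yi)
        if d > 0:  # skip the undefined term at a data point
            gx += (x - xi) / d
            gy += (y - yi) / d
    return gx, gy


def gd_step(x, y, points, alpha):
    """One plain gradient-descent update: new = old - alpha * gradient."""
    dx, dy = gradient(x, y, points)
    return x - alpha * dx, y - alpha * dy


if __name__ == "__main__":
    square = [[0, 1], [1, 0], [1, 2], [2, 1]]
    for alpha in (0.1, 16):
        nx, ny = gd_step(0.0, 0.0, square, alpha)
        print(alpha, distance_sum(0.0, 0.0, square), "->", distance_sum(nx, ny, square))
```

With alpha = 0.1 the step from $(0, 0)$ lowers the distance sum; with alpha = 16 it jumps far past the optimum and raises it, which is the oscillation problem addressed next.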
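To prevent oscillation, the learning rate is reduced whenever it occurs: if a step increases the distance sum (it overshot the minimum), alpha is halved and the step is retried from the same point.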
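One way to sketch this safeguard, assuming the same halve-until-no-increase rule as the original code. The function name `damped_step` and the `min_alpha` floor are hypothetical additions; the floor only prevents an endless loop.

```python
import math


def distance_sum(x, y, points):
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


def gradient(x, y, points):
    gx = gy = 0.0
    for xi, yi in points:
        d = math.hypot(x - xi, y - yi)
        if d > 0:  # skip the undefined term at a data point
            gx += (x - xi) / d
            gy += (y - yi) / d
    return gx, gy


def damped_step(x, y, points, alpha, min_alpha=1e-12):
    """Step with alpha, halving it until the distance sum does not increase.

    Returns (new_x, new_y, alpha); the shrunken alpha is handed back so the
    caller keeps using it, like the persistent `alpha /= 2` in the original.
    """
    current = distance_sum(x, y, points)
    dx, dy = gradient(x, y, points)
    while alpha > min_alpha:
        nx, ny = x - alpha * dx, y - alpha * dy
        if distance_sum(nx, ny, points) <= current:
            return nx, ny, alpha
        alpha /= 2  # overshoot: the step went past the minimum, so shrink it
    return x, y, alpha  # no acceptable step found; stay where we are


if __name__ == "__main__":
    square = [[0, 1], [1, 0], [1, 2], [2, 1]]
    x, y, alpha = damped_step(0.0, 1.0, square, alpha=16.0)
    print(x, y, alpha, distance_sum(x, y, square))
```

Starting from $(0, 1)$ on the first sample, the steps with alpha = 16, 8, 4, 2 and 1 all overshoot, and alpha = 0.5 is the first that lowers the distance sum.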
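Termination condition: stop once the distance sum after an update differs from the one before it by less than a threshold. The problem accepts answers within $10^{-5}$ of the true value; the code below uses a stricter $10^{-10}$ and also stops when both partial derivatives are that small.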
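Putting the update, the halving rule and this stopping rule together gives a compact loop. It is a sketch, not the original code: the names are hypothetical, it starts from the centroid of the points rather than from `positions[0]`, and it adds a `max_iter` cap as a safety net.

```python
import math


def distance_sum(x, y, points):
    return sum(math.hypot(x - xi, y - yi) for xi, yi in points)


def gradient(x, y, points):
    gx = gy = 0.0
    for xi, yi in points:
        d = math.hypot(x - xi, y - yi)
        if d > 0:  # skip the undefined term at a data point
            gx += (x - xi) / d
            gy += (y - yi) / d
    return gx, gy


def min_dist_sum(points, alpha=16.0, tol=1e-10, max_iter=100_000):
    """Gradient descent that stops once an update changes f by less than tol."""
    n = len(points)
    # Assumption: start at the centroid (the original code starts at positions[0]).
    x = sum(p[0] for p in points) / n
    y = sum(p[1] for p in points) / n
    current = distance_sum(x, y, points)
    for _ in range(max_iter):
        dx, dy = gradient(x, y, points)
        # Halve alpha until the step no longer increases the distance sum.
        while True:
            nx, ny = x - alpha * dx, y - alpha * dy
            new = distance_sum(nx, ny, points)
            if new <= current or alpha < 1e-15:
                break
            alpha /= 2
        if new > current:
            return current  # no improving step at any usable alpha
        x, y = nx, ny
        if current - new < tol:  # stopping rule: the update barely changed f
            return new
        current = new
    return current


if __name__ == "__main__":
    print(min_dist_sum([[1, 1], [0, 0], [2, 0]]))
```

For the triangle sample `[[1, 1], [0, 0], [2, 0]]` the result should be close to the expected 2.73205.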
Code
from typing import List


class Solution:
    def getMinDistSum(self, positions: List[List[int]]) -> float:
        center = positions[0]      # start from the first point
        alpha = 16                 # initial learning rate
        precision_error = 1e-10    # convergence tolerance
        current_res = self.getDistSum(center, positions)
        xs = [p[0] for p in positions]
        ys = [p[1] for p in positions]
        # Partial derivatives at the starting center
        dx = self.getDelta(center[0], xs, center[1], ys)
        dy = self.getDelta(center[1], ys, center[0], xs)
        while abs(dx) > precision_error or abs(dy) > precision_error:
            dx = self.getDelta(center[0], xs, center[1], ys)
            dy = self.getDelta(center[1], ys, center[0], xs)
            new_center = [center[0] - alpha * dx, center[1] - alpha * dy]
            tmp_res = self.getDistSum(new_center, positions)
            # The step overshot (distance sum grew): halve alpha and retry
            # from the same center, so dx and dy are still valid.
            while tmp_res > current_res:
                if abs(tmp_res - current_res) < precision_error:
                    return tmp_res
                alpha /= 2
                new_center = [center[0] - alpha * dx, center[1] - alpha * dy]
                tmp_res = self.getDistSum(new_center, positions)
            current_res = tmp_res
            center = new_center
        return current_res

    def getDelta(self, x, xs, y, ys):
        """Partial derivative of the distance sum w.r.t. the first coordinate
        passed in (swap the arguments to get the other partial derivative)."""
        res = 0
        for x_const, y_const in zip(xs, ys):
            divisor = ((x - x_const) ** 2 + (y - y_const) ** 2) ** 0.5
            if divisor != 0:  # undefined at a data point; skip that term
                res += (x - x_const) / divisor
        return res

    def getDistSum(self, center, positions):
        return sum(self.getEuclideDist(center, position) for position in positions)

    def getEuclideDist(self, a, b):
        return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5


def main():
    inputs = [
        [[0, 1], [1, 0], [1, 2], [2, 1]],
        [[1, 1], [3, 3]],
        [[1, 1]],
        [[1, 1], [0, 0], [2, 0]],
        [[0, 1], [3, 2], [4, 5], [7, 6], [8, 9], [11, 1], [2, 12]],
        [[0, 1], [1, 0], [1, 2], [2, 1], [1, 1]],
        [[44, 23], [18, 45], [6, 73], [0, 76], [10, 50], [30, 7], [92, 59],
         [44, 59], [79, 45], [69, 37], [66, 63], [10, 78], [88, 80], [44, 87]],
    ]
    outputs = [4, 2.82843, 0.00000, 2.73205, 32.94036, 4, 499.28078]
    sol = Solution()
    for points, expected in zip(inputs, outputs):
        actual = sol.getMinDistSum(points)
        # Compare the absolute error; a signed difference would accept any undershoot
        print(actual, expected, abs(actual - expected) <= 1e-5)


if __name__ == '__main__':
    main()
Author: Infinity在写代码
Link: https://juejin.cn/post/7024933004996771854