原题:https://leetcode.cn/problems/rotate-array/description/
给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数。
示例 1:
输入: nums = [1,2,3,4,5,6,7], k = 3
输出: [5,6,7,1,2,3,4]
解释:
向右轮转 1 步: [7,1,2,3,4,5,6]
向右轮转 2 步: [6,7,1,2,3,4,5]
向右轮转 3 步: [5,6,7,1,2,3,4]
示例 2:
输入:nums = [-1,-100,3,99], k = 2
输出:[3,99,-1,-100]
解释:
向右轮转 1 步: [99,-1,-100,3]
向右轮转 2 步: [3,99,-1,-100]
开始的思路是通过mod运算完成循环的效果,从第一个位置元素起逐渐以步长k往下一个元素循环递推赋值,直到回到第一个位置元素。
from typing import List
class Solution:
def rotate(self, nums: List[int], k: int) -> None:
"""
Do not return anything, modify nums in-place instead.
"""
arr_len = len(nums)
curr_idx, curr_num = 0, nums[0]
while True:
next_idx = (curr_idx + k) % arr_len
if next_idx == 0:
nums[next_idx] = curr_num
break
nums[next_idx], curr_num = curr_num, nums[next_idx]
curr_idx = next_idx
print(nums)
if __name__ == '__main__':
s = Solution()
s.rotate(nums=[1, 2, 3, 4, 5, 6, 7], k=3)
s.rotate(nums=[-1, -100, 3, 99], k=2)
# [5, 6, 7, 1, 2, 3, 4]
# [3, -100, -1, 99]
可以看到第一个输出正确,第二个输出错误。
原因
上面方式只适用于数组长度n和步长k互质的情况。主要可以通过下面两点证明来理解:
1.循环数组中,从任意位置经过有限步步长移动后,一定会再次回到起始位置。
假设起始位置下标
i
,其实就是需要证明(i + mk) mod n = i
,也就是证明mk mod n = 0
,m
是至少需要移动的次数。
分两种情况:
1)n
、k
互质,此时最小的移动次数显然等于n
,gcd(n, k)=1
。
2)n
、k
不互质,设最大公约数gcd(n, k)=d
, n ′ = n d n'=\frac{n}{d} n′=dn, k ′ = k d k'=\frac{k}{d} k′=dk,其中 n ′ 、 k ′ n'、k' n′、k′互质。此时上面等式等价于 m d k ′ m o d d n ′ = 0 mdk'\mod dn' = 0 mdk′moddn′=0,化简为=> m k ′ m o d n ′ = 0 mk' \mod n'=0 mk′modn′=0,因此当m= n ′ n' n′的时候成立。
综上,得证。
2.循环数组中环的数量等于数组长度n和步长k的最大公约数。
1中的证明其实说明从每个下标起始,m次移动后又会再次回到这个下标,这个过程其实构成了一个环,移动的次数就是环中的元素数量。并且因为是等距移动,所以环和环之间的元素也不会冲突。
同样分为n
、k
互质和不互质两种情况:
1)互质:根据1,需要移动n次,数组长度n,所以环的数量1。
2)不互质:同样根据1,需要移动 n ′ n' n′次,所以每个环中的元素数量也是 n ′ n' n′,所以环的数量= n n ′ \frac{n}{n'} n′n=d。
正确答案
求出环的数量,遍历环的起始下标。
from typing import List
class Solution:
def rotate(self, nums: List[int], k: int) -> None:
"""
Do not return anything, modify nums in-place instead.
"""
import math
arr_len = len(nums)
gcd_ans = math.gcd(arr_len, k)
for idx in range(gcd_ans): # 每个环的起始下标
curr_idx, curr_num = idx, nums[idx]
while True:
next_idx = (curr_idx + k) % arr_len
if next_idx == idx:
nums[next_idx] = curr_num
break
nums[next_idx], curr_num = curr_num, nums[next_idx]
curr_idx = next_idx
print(nums)
if __name__ == '__main__':
s = Solution()
s.rotate(nums=[1, 2, 3, 4, 5, 6, 7], k=3)
s.rotate(nums=[-1, -100, 3, 99], k=2)
# [5, 6, 7, 1, 2, 3, 4]
# [3, 99, -1, -100]