【算法】手撕快速排序

发布于:2025-03-31 ⋅ 阅读:(18) ⋅ 点赞:(0)

快速排序的思想

任取一个元素作为枢轴,然后想办法把这个区间划分为两部分,小于等于枢轴的放左边,大于等于枢轴的放右边

然后递归处理左右区间,直到空或只剩一个

具体动画演示详见 

数据结构合集 - 快速排序(算法过程, 效率分析, 稳定性分析)

 Lomuto 分区方案(单边扫描法)

public static void quickSort(int[] nums){
    subSort(nums, 0, nums.length-1);
}

private static void subSort(int[] nums, int low, int high){
    if(low < high){
        int pos = partition(nums, low, high);
        subSort(nums, low, pos-1);
        subSort(nums, pos+1, high);
    }
}



private static int partition(int[] nums, int low, int high){
    int pivot = nums[high];

    int i = low-1;

    for(int j=low; j<high; j++){
        if(nums[j] <= pivot){
            i++;
            int tmp = nums[i];
            nums[i] = nums[j];
            nums[j] = tmp;
        }
    }
    // 将pivot放到正确的位置
    int temp = nums[i + 1];
    nums[i + 1] = nums[high];
    nums[high] = temp;
    return i+1;
}

进一步优化上述代码

优化点:小数组时改用插入排序

当待排序数组较小的时候,快速排序的递归开销会比插入排序更大

public static void quickSort(int[] nums){
        subSort(nums, 0, nums.length-1);
    }

    private static void subSort(int[] nums, int low, int high) {
        if (low < high) {
            // 小数组改用插入排序(阈值可调整,通常 10~20)
            if (high - low < 10) {
                insertionSort(nums, low, high);
                return;
            }
            int pos = partition(nums, low, high);
            subSort(nums, low, pos - 1);
            subSort(nums, pos + 1, high);
        }
    }

    private static void insertionSort(int[] nums, int low, int high) {
        for (int i = low + 1; i <= high; i++) {
            int key = nums[i];
            int j = i - 1;
            while (j >= low && nums[j] > key) {
                nums[j + 1] = nums[j];
                j--;
            }
            nums[j + 1] = key;
        }
    }

    private static int partition(int[] nums, int low, int high){
        int pivot = nums[high];

        int i = low-1;

        for(int j=low; j<high; j++){
            if(nums[j] <= pivot){
                i++;
                int tmp = nums[i];
                nums[i] = nums[j];
                nums[j] = tmp;
            }
        }
        // 将pivot放到正确的位置
        int temp = nums[i + 1];
        nums[i + 1] = nums[high];
        nums[high] = temp;
        return i+1;
    }

优化点:三数取中法选择pivot避免最坏情况

如果pivot选择最左或最右,在已排序或接近排序的数组上会导致最坏的情况O(n^{2}

public static void quickSort(int[] nums){
        subSort(nums, 0, nums.length-1);
    }

    private static void subSort(int[] nums, int low, int high) {
        if (low < high) {
            // 小数组改用插入排序(阈值可调整,通常 10~20)
            if (high - low < 10) {
                insertionSort(nums, low, high);
                return;
            }
            int pos = partition(nums, low, high);
            subSort(nums, low, pos - 1);
            subSort(nums, pos + 1, high);
        }
    }

    private static void insertionSort(int[] nums, int low, int high) {
        for (int i = low + 1; i <= high; i++) {
            int key = nums[i];
            int j = i - 1;
            while (j >= low && nums[j] > key) {
                nums[j + 1] = nums[j];
                j--;
            }
            nums[j + 1] = key;
        }
    }

    private static int partition(int[] nums, int low, int high){
        // 三数取中法选择 pivot
        int mid = low + (high - low) / 2;
        if (nums[low] > nums[mid]) swap(nums, low, mid);
        if (nums[low] > nums[high]) swap(nums, low, high);
        if (nums[mid] > nums[high]) swap(nums, mid, high);
        // 将中位数放到 nums[high]
        swap(nums, mid, high);
        int pivot = nums[high];

        int i = low-1;

        for(int j=low; j<high; j++){
            if(nums[j] <= pivot){
                i++;
                swap(nums, i, j);
            }
        }
        // 将pivot放到正确的位置
        swap(nums, i + 1, high);
        return i+1;
    }

    private static void swap(int[] nums, int i, int j) {
        int tmp = nums[i];
        nums[i] = nums[j];
        nums[j] = tmp;
    }

优化点:尾递归优化避免栈溢出

public static void quickSort(int[] nums){
        subSort(nums, 0, nums.length-1);
    }

    private static void subSort(int[] nums, int low, int high) {
        // 小数组改用插入排序
        if (high - low < 10) {
            insertionSort(nums, low, high);
            return;
        }
        // 尾递归优化
        while (low < high) {
            int pos = partition(nums, low, high);
            if (pos - low < high - pos) {
                subSort(nums, low, pos - 1);
                low = pos + 1;
            } else {
                subSort(nums, pos + 1, high);
                high = pos - 1;
            }
        }
    }

    private static void insertionSort(int[] nums, int low, int high) {
        for (int i = low + 1; i <= high; i++) {
            int key = nums[i];
            int j = i - 1;
            while (j >= low && nums[j] > key) {
                nums[j + 1] = nums[j];
                j--;
            }
            nums[j + 1] = key;
        }
    }

    private static int partition(int[] nums, int low, int high){
        // 三数取中法选择 pivot
        int mid = low + (high - low) / 2;
        if (nums[low] > nums[mid]) swap(nums, low, mid);
        if (nums[low] > nums[high]) swap(nums, low, high);
        if (nums[mid] > nums[high]) swap(nums, mid, high);
        // 将中位数放到 nums[high]
        swap(nums, mid, high);
        int pivot = nums[high];

        int i = low-1;

        for(int j=low; j<high; j++){
            if(nums[j] <= pivot){
                i++;
                swap(nums, i, j);
            }
        }
        // 将pivot放到正确的位置
        swap(nums, i + 1, high);
        return i+1;
    }

    private static void swap(int[] nums, int i, int j) {
        int tmp = nums[i];
        nums[i] = nums[j];
        nums[j] = tmp;
    }

优化点:双轴快排

java的Arrays.sort对小数组用插入排序,对大数组用双轴快排(比单轴更快)

双轴快排的基本思路:

  1. 选取两个 pivot
    • pivot1 = nums[left](较小的 pivot)
    • pivot2 = nums[right](较大的 pivot)
    • 确保 pivot1 ≤ pivot2(否则交换)
  1. 分区
    • [left, i)< pivot1
    • [i, k)≥ pivot1 && ≤ pivot2
    • [k, j]:未处理区间
    • (j, right]> pivot2
  1. 递归处理三个子数组
    • [left, i-1](小于 pivot1 的部分)
    • [i, j](介于 pivot1pivot2 之间的部分)
    • [j+1, right](大于 pivot2 的部分)
public static void quickSort(int[] nums){
        dualPivotQuickSort(nums, 0, nums.length-1);
    }

    private static void dualPivotQuickSort(int[] nums, int left, int right) {
        if (left >= right) return;

        // 确保 pivot1 ≤ pivot2
        if (nums[left] > nums[right]) {
            swap(nums, left, right);
        }
        int pivot1 = nums[left];
        int pivot2 = nums[right];

        int i = left + 1;  // [left+1, i) 存储 < pivot1 的元素
        int k = left + 1;  // [i, k) 存储 ≥ pivot1 && ≤ pivot2 的元素
        int j = right - 1; // (j, right-1] 存储 > pivot2 的元素

        while (k <= j) {
            if (nums[k] < pivot1) {
                swap(nums, i, k);
                i++;
                k++;
            } else if (nums[k] <= pivot2) {
                k++;
            } else {
                swap(nums, k, j);
                j--;
            }
        }

        // 将 pivot1 和 pivot2 放到正确位置
        swap(nums, left, i - 1);
        swap(nums, right, j + 1);

        // 递归处理三个子数组
        dualPivotQuickSort(nums, left, i - 2);   // < pivot1
        dualPivotQuickSort(nums, i, j);          // pivot1 ≤ x ≤ pivot2
        dualPivotQuickSort(nums, j + 2, right);  // > pivot2
    }

    private static void swap(int[] nums, int i, int j) {
        int tmp = nums[i];
        nums[i] = nums[j];
        nums[j] = tmp;
    }