排序算法的Python实现以及时间复杂度分析

我用Python实现了冒泡排序、选择排序、插入排序、归并排序、快速排序。然后简单讲了讲快速排序的优化,我们可以通过小数组采用插入排序来减少递归的开销;对于有一定顺序的数组,我采用三数取中来提高性能;对于包含大量重复数的数组,我用了三路快速排序来提高性能。

最后,我把这些排序算法应用在随机数组、升序数组、降序数组、包含大量重复数的数组上,比较了一下它们的耗时。

冒泡排序

冒泡排序的大体思想就是通过与相邻元素的比较和交换来把小的数交换到最前面。这个过程类似于水泡向上升一样,因此而得名。

1
2
3
4
5
6
7
8
9
10
def exchange(a,i,j):
temp = a[i]
a[i] = a[j]
a[j] = temp

def BubbleSort(nums):
for i in range(len(nums)-1):
for j in range(len(nums)-i-1):
if nums[j] > nums[j+1]:
exchange(nums,j,j+1)
  • 时间复杂度O(n^2)

选择排序

首先,找到数组中最小的那个元素,然后将它和数组的第一个元素交换位置(如果第一个元素就是最小元素那么它就和自己交换)。然后在剩下的元素中找到最小的元素,将它与数组的第二个元素交换位置。如此往复,直到将整个数组排序。这种方法叫做选择排序,因为它在不断地选择剩余元素之中的最小者

1
2
3
4
5
6
7
8
9
10
11
12
def exchange(a,i,j):
temp = a[i]
a[i] = a[j]
a[j] = temp

def SelectSort(nums):
for i in range(len(nums)):
minIndex = i
for j in range(i,len(nums)):
if nums[j] < nums[minIndex]:
minIndex = j
exchange(nums,i,minIndex)
  • 时间复杂度O(n^2)

插入排序

通常人们整理桥牌的方法是一张一张的来,将每一张牌插入到其他已经有序的牌中的适当位置。在计算机的实现中,为了给要插入的元素腾出空间,我们需要将其余所有元素在插入之前都向右移动一位。这种算法叫做插入排序。

1
2
3
4
5
6
7
8
9
10
11
12
def exchange(a,i,j):
temp = a[i]
a[i] = a[j]
a[j] = temp

def InsertSort(nums):
for i in range(len(nums)-1):
j = i + 1
while i >= 0 and nums[i] > nums[j]:
exchange(nums,i,j)
j -= 1
i -= 1
  • 时间复杂度O(n^2)

归并排序

归并排序体现的是一种分治思想(Divide and conquer),下面是其排序的步骤:

  1. 将数组一分为二(Divide array into two halves)
  2. 对每部分进行递归式地排序(Recursively sort each half)
  3. 合并两个部分(Merge two halves)

merge()函数

具体步骤如下:

  1. 给出原数组a[],该数组的low到mid,mid+1到high的子数组是各自有序的。
  2. 将数组复制到辅助数组(auxiliary array)中,两部分数组的首元素分别以i和j为下标,给原数组首元素以k为下标。
  3. 比较i下标和j下标的元素,将较小值赋到k下标位置的元素内,然后对k和赋值的下标进行递增。
  4. 重复上述过程,直到比较完全部元素。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def merge(a,aux,low,mid,high):
i = low
j = mid+1
k = 0
for k in range(low,high+1):
if i > mid:
a[k] = aux[j]
j += 1
elif j > high:
a[k] = aux[i]
i += 1
else:
if aux[i] > aux[j]:
a[k] = aux[j]
j += 1
else:
a[k] = aux[i]
i += 1

sort()函数

我们要对数组a[low..high]进行排序,先将它分为a[low..mid]a[mid+1..high]两部分,分别递归调用将它们单独排序,最后将有序的子数组归并为最终的排序结果。

1
2
3
4
5
6
7
8
def sort(a,aux,low,high):
# 退出条件
if low >= high:
return
mid = (low + high) // 2
sort(a,aux,low,mid)
sort(a,aux,mid+1,high)
merge(a,aux,low,mid,high)

MergeSort()函数

为了保证归并排序函数MergeSort()输入只有未排序的数组,这里调用前面的辅助函数sort():

1
2
3
4
5
6
def MergeSort(nums):
aux = nums.copy()
low = 0
high = len(nums)-1
sort(nums,aux,low,high)
return nums
  • 时间复杂度:O(nlogn)

快速排序

快速排序是一种分治的排序算法。它将一个数组分成两个子数组,将两部分独立地排序。

分治策略指的是:将原问题分解为若干个规模更小但结构与原问题相似的子问题。递归地解这些子问题,然后将这些子问题的解组合为原问题的解。

下面是一个示例:

来源:快速排序python实现

简单实现

下面的代码短小利于理解,但是空间复杂度大,使用了三个列表解析式,而且每次选取进行比较时需要遍历整个序列。

1
2
3
4
5
6
7
8
9
def QuickSort(a):
if len(a) < 2:
return a
else:
pivot = a[0]
less_than_pivot = [x for x in a if x < pivot]
more_than_pivot = [x for x in a if x > pivot]
pivot_list = [x for x in a if x == pivot]
return QuickSort(less_than_pivot) + pivot_list + QuickSort(more_than_pivot)

原地排序实现

  1. 切分——partition()

切分方法:先随意地取a[low]作为切分元素(即那个将会被排定的元素),然后我们从数组的左端开始向右扫描直到找到一个大于等于它的元素,再从数组的右端开始向左扫描直到找到一个小于等于它的元素。这两个元素是没有排定的,因此我们交换它们的位置。如此继续,当两个指针相遇时,我们只需要将切分元素a[low]和左子元素最右侧的元素a[j]交换然后返回j即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def partition(a,low,high):
i = low # 循环内i=i+1
j = high + 1 # 循环内j=j-1
while True:
# 如果a[i]比基准数小,则后移一位直到有大于等于基准数的数出现
i += 1 # 保证i每次循环都变化,不会陷入死循环(所有数都相等时这种情况)
while a[i] < a[low] and i < high:
i += 1
# 如果a[j]比基准数大,则前移一位直到有小于等于基准数的数出现
j -= 1 # 保证j每次循环都变化,不会陷入死循环(所有数都相等时这种情况)
while a[j] > a[low] and j > low:
j -= 1

# 如果两个指针交叉,说明已经排序完了
if i >= j:
break

exchange(a,i,j)

# 指针相遇后,j所在的元素小于low,进行互换
exchange(a,low,j)

return j

这里有个细节需要注意下,这个代码相比我最初的代码改变了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def partition(a,low,high):
- i = low + 1
+ i = low # 循环内i=i+1
- j = high
+ j = high + 1 # 循环内j=j-1
while True:
# 如果a[i]比基准数小,则后移一位直到有大于等于基准数的数出现
+ i += 1 # 保证i每次循环都变化,不会陷入死循环(所有数都相等时这种情况)
while a[i] < a[low] and i < high:
i += 1
# 如果a[j]比基准数大,则前移一位直到有小于等于基准数的数出现
+ j -= 1 # 保证j每次循环都变化,不会陷入死循环(所有数都相等时这种情况)
while a[j] > a[low] and j > low:
j -= 1

# 如果两个指针交叉,说明已经排序完了
if i >= j:
break

exchange(a,i,j)

# 指针相遇后,j所在的元素小于low,进行互换
exchange(a,low,j)

return j

如果没有这些代码,当碰到[2,2,2]这样的情况时,i和j一直不会改变,永远无法满足if i >= j,然后函数就一直在while True里边死循环。

  1. sort()函数

快速排序递归地将子数组a[low..high]排序,先用partition()方法将a[j]放到一个合适位置,然后再用递归调用将其他位置的元素排序。

1
2
3
4
5
6
def sort(a,low,high):
if low >= high:
return
j = partition(a,low,high)
sort(a,low,j-1)
sort(a,j+1,high)
  1. QuickSort()函数

为了保证快速排序函数QuickSort()输入只有未排序的数组,这里调用前面的辅助函数sort():

1
2
3
4
5
def QuickSort(nums):
low = 0
high = len(nums)-1
sort(nums,low,high)
return nums

快速排序的时间复杂度

  • 最优情况:每一次的基准值都正好为序列的中位数,时间复杂度为nlogn
  • 最坏情况:每一次的基准值都恰好是序列的最大值或最小值,时间复杂度为n^2。有意思的是如果每次选第一个数做基准值,但每次这个数又是最小值,那么序列本身就是有序的,但时间复杂度也是最高的

因此,要想优化时间复杂度,关键在于基准值的选择

快速排序的优化

  1. 优化小数组效率

对于规模很小的情况,快速排序的优势并不明显(可能没有优势),而递归型的算法还会带来额外的开销。于是对于这类情况可以选择非递归型的算法来替代。

那就有两个问题:多小的数组算小数组?替换的算法是什么?

通常这个阈值设定为10,替换的算法一般是插入排序。

下面是Python实现,这里只需要在sort()函数中加一个数组大小判断即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
CUTOFF = 10

def sort(a,low,high):
if low >= high:
return

# 当数组大小小于CUTOFF时,调用插入排序
if high - low <= CUTOFF - 1:
InsertSort(a[low:high+1])
return

j = partition(a,low,high)
sort(a,low,j-1)
sort(a,j+1,high)
  1. 合理选择pivot

前面也讨论过,直接选择分区的第一个或最后一个元素做 pivot 是不合适的。对于已经排好序,或者接近排好序的情况,会进入最差情况,时间复杂度退化到n^2。

pivot选取的理想情况是:让分区中比 pivot 小的元素数量和比 pivot 大的元素数量差不多。较常用的做法是三数取中( median of three ),即从第一项、最后一项、中间一项中取中位数作为 pivot。当然这并不能完全避免最差情况的发生。所以很多时候会采取更小心、更严谨的 pivot 选择方案(对于大数组特别重要)。比如先把大数组平均切分成左中右三个部分,每个部分用三数取中得到一个中位数,再从得到的三个中位数中找出中位数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
CUTOFF = 10

def get_median(nums,low,high):
# 计算数组中间的元素的下标
mid = (low + high) // 2

# 目标: arr[mid] <= arr[high]
if nums[mid] > nums[high]:
exchange(nums,mid,high)
# 目标: arr[low] <= arr[high]
if nums[low] > nums[high]:
exchange(nums,low,high)
# 目标: arr[low] >= arr[mid]
if nums[low] < nums[mid]:
exchange(nums,low,mid)

# 此时,arr[mid] <= arr[low] <= arr[high],low的位置上保存这三个位置中间的值
return nums[low]

def sort(a,low,high):
if low >= high:
return

# 当数组大小小于CUTOFF时,调用插入排序
if high - low <= CUTOFF - 1:
InsertSort(a[low:high+1])
return

# 三数取中(median of three),low的位置上保存这三个位置中间的值
_ = get_median(a,low,high)

j = partition(a,low,high)
sort(a,low,j-1)
sort(a,j+1,high)
  1. 处理重复元素问题

当一个数组里的元素全部一样大(或者存在大量相同元素)会令快速排序进入最差情况,因为不管怎么选 pivot,都会使分区结果一边很大一边很小。

为了解决这个问题,我们需要修改分区过程,思路跟上面说的两路分区(基本的快排)类似,只是现在我们需要小于 pivot、等于 pivot、大于 pivot 三个分区。

举个例子,待分割序列:6 4 6 7 1 6 7 6 8 6,其中pivot=6

  • 未对与key元素相等处理的划分结果:1 4 6 6 7 6 7 6 8 6
    • 下次的两个子序列为:1 4 67 6 7 6 8 6
  • 对与key元素相等处理的划分结果:1 4 6 6 6 6 6 7 8 7
    • 下次的两个子序列为:1 47 8 7

经过对比,我们可以看出,在一次划分后,把与key相等的元素聚在一起,能减少迭代次数,效率会提高不少

具体过程:

如下图,我们可以设置四个游标,左端p、i,右端j、q。i、j的作用跟之前两路划分时候的左右游标相同,就是从两端向中间遍历序列,并将遍历到的元素与pivot比较,如果等于pivot,则移到两端(i对应的元素移到左端,j对应的元素移到右端。移动的方式就是拿此元素和a或d对应的元素进行交换,所以p和q的作用就是记录等于pivot的元素移动过后的边界),反之,如果大于或小于pivot,还按照之前两路划分的方式进行移动。这样一来,中间部分就和两路划分相同,两头是等于pivot的部分,我们只需要将这两部分移动到中间即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def partition(a,low,high):
p = low + 1
i = low + 1
j = high
q = high
while True:
# 如果a[i]比基准数小,则后移一位直到有大于等于基准数的数出现
while a[i] <= a[low] and i < high:
# 与pivot相等的元素将其交换到p所在的位置
if a[i] == a[low]:
exchange(a,p,i)
p += 1
i += 1
# 如果a[j]比基准数大,则前移一位直到有小于等于基准数的数出现
while a[j] >= a[low] and j > low:
# 与pivot相等的元素将其交换到q所在的位置
if a[j] == a[low]:
exchange(a,j,q)
q -= 1
j -= 1

# 如果两个指针交叉,说明已经排序完了
if i >= j:
break

exchange(a,i,j)

# 因为工作指针i指向的是当前需要处理元素的下一个元素,故而需要退回到当前元素的实际位置,然后将等于pivot元素交换到序列中间
i -= 1
p -= 1
while p >= low:
exchange(a, i, p)
i -= 1
p -= 1

# 因为工作指针j指向的是当前需要处理元素的上一个元素,故而需要退回到当前元素的实际位置,然后将等于pivot元素交换到序列中间
j += 1
q += 1
while q <= high:
exchange(a, q, j)
j += 1
q += 1

return i,j

下面是sort()函数,这里我只写了修改的部分:

1
2
3
4
5
6
7
def sort(a,low,high):

# ...

i,j = partition(a,low,high)
sort(a,low,i)
sort(a,j,high)

整体代码实现

下面是经过优化的快速排序代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
CUTOFF = 10

def exchange(a,i,j):
temp = a[i]
a[i] = a[j]
a[j] = temp

def InsertSort(nums):
for i in range(len(nums)-1):
j = i + 1
while i >= 0 and nums[i] > nums[j]:
exchange(nums,i,j)
j -= 1
i -= 1

def partition(a,low,high):
p = low + 1
i = low + 1
j = high
q = high
while True:
# 如果a[i]比基准数小,则后移一位直到有大于等于基准数的数出现
while a[i] <= a[low] and i < high:
# 与pivot相等的元素将其交换到p所在的位置
if a[i] == a[low]:
exchange(a,p,i)
p += 1
i += 1
# 如果a[j]比基准数大,则前移一位直到有小于等于基准数的数出现
while a[j] >= a[low] and j > low:
# 与pivot相等的元素将其交换到q所在的位置
if a[j] == a[low]:
exchange(a,j,q)
q -= 1
j -= 1

# 如果两个指针交叉,说明已经排序完了
if i >= j:
break

exchange(a,i,j)

# 因为工作指针i指向的是当前需要处理元素的下一个元素,故而需要退回到当前元素的实际位置,然后将等于pivot元素交换到序列中间
i -= 1
p -= 1
while p >= low:
exchange(a, i, p)
i -= 1
p -= 1

# 因为工作指针j指向的是当前需要处理元素的上一个元素,故而需要退回到当前元素的实际位置,然后将等于pivot元素交换到序列中间
j += 1
q += 1
while q <= high:
exchange(a, q, j)
j += 1
q += 1

return i,j

def get_median(nums,low,high):
# 计算数组中间的元素的下标
mid = (low + high) // 2

# 目标: arr[mid] <= arr[high]
if nums[mid] > nums[high]:
exchange(nums,mid,high)
# 目标: arr[low] <= arr[high]
if nums[low] > nums[high]:
exchange(nums,low,high)
# 目标: arr[low] >= arr[mid]
if nums[low] < nums[mid]:
exchange(nums,low,mid)

# 此时,arr[mid] <= arr[low] <= arr[high],low的位置上保存这三个位置中间的值
return nums[low]

def sort(a,low,high):
if low >= high:
return

# 当数组大小小于CUTOFF时,调用插入排序
if high - low <= CUTOFF - 1:
InsertSort(a[low:high+1])
return

# 三数取中(median of three),low的位置上保存这三个位置中间的值
_ = get_median(a,low,high)

i,j = partition(a,low,high)
sort(a,low,i)
sort(a,j,high)

def QuickSort3Ways(nums):
low = 0
high = len(nums)-1
sort(nums,low,high)
return nums

nums = [4,5,6,1,2,3,3,3,1,2]
print(QuickSort(nums))

快速排序和归并排序对比

快速排序和归并排序是互补的:

  • 归并排序:
    1. 将数组分成两个子数组分别排序,并将有序的子数组归并以将整个数组排序;
    2. 递归调用发生在处理整个数组之前;
    3. 一个数组被等分为两半。
  • 快速排序:
    1. 当两个子数组都有序时,整个数组也就自然有序了;
    2. 递归调用发生在处理整个数组之后;
    3. 切分(partition)的位置取决于数组的内容。

各大排序算法测试

下面我们对各大排序算法在不同数据集上进行对比,看看它们的优劣。

计时函数

不同数据集可以用同一个计时函数,具体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import time

# 计时函数
def count_time(a,sortname):
time_start = time.time()
if sortname == 'BubbleSort':
BubbleSort(a)
if sortname == 'SelectSort':
SelectSort(a)
if sortname == 'InsertSort':
InsertSort(a)
if sortname == 'MergeSort':
MergeSort(a)
if sortname == 'QuickSort':
QuickSort(a)
if sortname == 'QuickSort3Ways':
QuickSort3Ways(a)
time_end = time.time()
return (time_end - time_start)

随机数据集

随机数据生成器:

1
2
3
4
5
6
7
8
9
10
11
12
import random

def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
a.append(random.randint(1, 1000000)) # 测试数据范围
totalTime += count_time(a,sortname)
return totalTime

这里我们生成一个长度为5000的数组,然后重复测试10次,最后计算各个排序算法用时:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
length = 5000
numberOfArrays = 10

print("BubbleSort's total time:")
print(timeRandomInput('BubbleSort',length,numberOfArrays))

print("SelectSort's total time:")
print(timeRandomInput('SelectSort',length,numberOfArrays))

print("InsertSort's total time:")
print(timeRandomInput('InsertSort',length,numberOfArrays))

print("MergeSort's total time:")
print(timeRandomInput('MergeSort',length,numberOfArrays))

print("QuickSort's total time:")
print(timeRandomInput('QuickSort',length,numberOfArrays))

print("QuickSort3Ways's total time:")
print(timeRandomInput('QuickSort3Ways',length,numberOfArrays))
1
2
3
4
5
6
7
8
9
10
11
12
BubbleSort's total time:
30.023681640625
SelectSort's total time:
11.03202223777771
InsertSort's total time:
24.185371160507202
MergeSort's total time:
0.1900651454925537
QuickSort's total time:
0.1554875373840332
QuickSort3Ways's total time:
0.19011521339416504

降序数据集

这里我们看下这些排序算法在降序数据集下的表现,首先改变数据生成函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
import random

def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
a.append(random.randint(1, 1000000)) # 测试数据范围
+ a.sort(reverse = True)
totalTime += count_time(a,sortname)
return totalTime

这里如果生成一个长度为10000的数组,快速排序会出现RecursionError: maximum recursion depth exceeded in comparison错误。这个因为Python中默认的最大递归深度是989。解决方案:手动设置递归调用深度,具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import random
+import sys

+sys.setrecursionlimit(1000000)

def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
a.append(random.randint(1, 1000000)) # 测试数据范围
a.sort(reverse = True)
totalTime += count_time(a,sortname)
return totalTime

数组大小改变为5000,重复10次,下面是测试结果:

1
2
3
4
5
6
7
8
9
10
11
12
BubbleSort's total time:
45.00776267051697
SelectSort's total time:
11.393858909606934
InsertSort's total time:
48.275355100631714
MergeSort's total time:
0.18087530136108398
QuickSort's total time:
14.895536661148071
QuickSort3Ways's total time:
0.10853052139282227

升序数据集

这里我们看下这些排序算法在升序数据集下的表现,首先改变数据生成函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import random
import sys

sys.setrecursionlimit(1000000)

def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
a.append(random.randint(1, 1000000)) # 测试数据范围
+ a.sort(reverse = False)
totalTime += count_time(a,sortname)
return totalTime

同样的,这里数组大小为5000,重复10次,下面是测试结果:

1
2
3
4
5
6
7
8
9
10
11
12
BubbleSort's total time:
14.935291051864624
SelectSort's total time:
11.371372699737549
InsertSort's total time:
0.008459329605102539
MergeSort's total time:
0.15901756286621094
QuickSort's total time:
16.011647939682007
QuickSort3Ways's total time:
0.10053849220275879

含有大量重复数的数组

这里我们看下这些排序算法在含有大量重复数的数据集下的表现,首先改变数据生成函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import random
import sys

sys.setrecursionlimit(1000000)

def timeRandomInput(sortname,length,numberOfArrays):
totalTime = 0
#测试数组数
for _ in range(numberOfArrays):
#数组大小
a = []
for _ in range(length):
- a.append(random.randint(1, 1000000)) # 测试数据范围
+ a.append(random.randint(999990, 1000000)) # 测试数据范围
totalTime += count_time(a,sortname)
return totalTime

同样的,这里数组大小为5000,重复10次,下面是测试结果:

1
2
3
4
5
6
7
8
9
10
11
12
BubbleSort's total time:
28.813392877578735
SelectSort's total time:
11.362754821777344
InsertSort's total time:
22.454782247543335
MergeSort's total time:
0.1563563346862793
QuickSort's total time:
0.15424251556396484
QuickSort3Ways's total time:
0.08862972259521484

总结

BubbleSort SelectSort InsertSort MergeSort QuickSort QuickSort3Ways
随机数据集 30.023 11.032 24.185 0.190 0.155 0.190
升序数据集 14.935 11.371 0.008 0.159 16.011 0.100
降序数据集 45.007 11.393 48.275 0.180 14.895 0.108
大量重复数的数据集 28.813 11.362 22.454 0.156 0.154 0.088

经过优化后的三路快速排序在升序、降序、包含大量重复数的情况下表现均非常优异。

参考


----------- 本文结束啦感谢您阅读 -----------

赞赏一杯咖啡