지금까지 기본적인 시간 복잡도가 $\Omicron(n^2)$인 정렬 알고리즘들을 알아봤습니다. 선택 정렬, 삽입 정렬, 버블 정렬은 모두 평균 $\Theta(n^2)$의 시간 복잡도를 가졌고 삽입정렬과 개선된 버블 정렬은 최선의 경우 $\Theta(n)$의 시간 복잡도를 가졌습니다. 이번에는 평균 $\Theta(n\log n)$의 시간 복잡도인 고급 정렬 알고리즘들에 대해 알아봅시다.
병합 정렬
`병합 정렬(merge sort)`은 입력을 반으로 나누어 각각 정렬한 뒤 합치는 재귀 알고리즘입니다. "정렬하는 방법이 반으로 나눈 뒤 각각 정렬하는거면 각각은 어떻게 정렬하라는 거지?" 싶을 수 있습니다. 재귀적으로 입력을 반으로 나눠 정렬하다보면 어느 순간 입력의 크기는 1이 됩니다. 크기가 1인 배열은 그 자체로 정렬된 배열입니다. 병합 정렬을 다시 표현하면 크기가 1이 될 때까지 나눈 뒤 다시 합치며(merge) 정렬하는 방법입니다.
- Pseudocode
mergeSort(A[], first, last){
if( first < last ){
mid ← ⌊(first+last)/2⌋;
mergeSort(A, first, mid);
mergeSort(A, mid+1, last);
merge(A, first, mid, last);
}
}
merge(A[], first, mid, last){
각각 정렬되어 있는 두 배열 A[first ⋯ mid]와 A[mid+1 ⋯ last]를 합쳐
하나의 정렬된 배열 A[first ⋯ last]를 만든다.;
}
merge 함수를 조금 더 자세히 표현하면 아래와 같습니다.
mergeSort(A[], first, last){
if( first < last ){
mid ← ⌊(first+last)/2⌋;
mergeSort(A, first, mid);
mergeSort(A, mid+1, last);
merge(A, first, mid, last);
}
}
merge(A[], first, mid, last){
i ← first; // 왼쪽 배열 A[first ⋯ mid]의 커서
j ← mid+1; // 오른쪽 배열 A[mid+1 ⋯ last]의 커서
k ← 1; // 합친 배열 tmp[first ⋯ last]의 커서
while( i ≤ mid and j ≤ last ) // 크기 순서대로 끼워 넣기
if( A[i] ≤ A[j] ) tmp[k++] ← A[i++];
else tmp[k++] ← A[j++];
while( i ≤ mid ) // 왼쪽 배열이 남은 경우
tmp[k++] ← A[i++];
while( j ≤ last ) // 오른쪽 배열이 남은 경우
tmp[k++] ← A[j++];
i ← first; k ← 1;
while( i ≤ last ) // A에 반영
A[i++] ← tmp[k++];
}
- 시간 복잡도
알고리즘만 보면 "이게 더 빠를까?"라는 생각이 듭니다. 병합 정렬의 시간 복잡도를 계산해 봅시다. 입력의 크기가 n인 병합 정렬의 시간 복잡도 $T(n)$은 다음과 같은 점화식으로 나타낼 수 있습니다.
$$T(n)=2T(\frac{n}{2})+\Theta(n)$$
$\Theta(n)$은 병합(merge 함수)에서 소모되는 시간입니다. 위 점화식을 마스터 정리로 계산하면 병합 정렬의 수행 시간은 최악의 경우에도 $\Theta(n\log n)$입니다.
구현
- C++
void merge(int arr[], int start, int mid, int end) {
int i = start, j = mid + 1, k = 0;
int tmp[end - start + 1];
while(i <= mid and j <= end)
tmp[k++] = arr[i] < arr[j] ? arr[i++] : arr[j++];
while(i <= mid)
tmp[k++] = arr[i++];
while(j <= end)
tmp[k++] = arr[j++];
i = start; k = 0;
while(i <= end)
arr[i++] = tmp[k++];
}
void merge_sort(int arr[], int start, int end) {
if(start < end) {
int mid = (start + end) / 2;
merge_sort(arr, start, mid);
merge_sort(arr, mid+1, end);
merge(arr, start, mid, end);
}
}
- Java
class Sort {
public static void mergeSort(int[] arr, int first, int last){
if(first < last){
int mid = (first + last) / 2;
mergeSort(arr, first, mid);
mergeSort(arr, mid+1, last);
merge(arr, first, mid, last);
}
}
private static void merge(int[] arr, int first, int mid, int last){
int i=first, j=mid+1, k=0;
int[] tmp = new int[last - first + 1];
while(i <= mid && j <= last)
tmp[k++] = (arr[i] < arr[j]) ? arr[i++] : arr[j++];
while(i <= mid)
tmp[k++] = arr[i++];
while(j <= last)
tmp[k++] = arr[j++];
i=first; k=0;
while(i <= last)
arr[i++] = tmp[k++];
}
}
last는 마지막 인덱스입니다. 초기값으로 arr.length-1을 전달해줘야 합니다. 메서드로 한번 더 감싸거나 mergeSort와 merge가 새로운 배열을 만들어 반환하도록 하면 인자로 배열만 받는 mergeSort를 구현할 수도 있습니다.
- Python
def merge_sort(lst, start, last):
if start < last:
mid = (start + last) // 2
merge_sort(lst, start, mid)
merge_sort(lst, mid+1, last)
merge(lst, start, mid, last)
def merge(lst, start, mid, last):
left, left_end = 0, mid - start
right, right_end = left_end+1, last - start
tmp = lst[start:last+1]
for k in range(start, last+1):
if left > left_end:
lst[k] = tmp[right]
right += 1
elif right > right_end:
lst[k] = tmp[left]
left += 1
elif tmp[left] < tmp[right]:
lst[k] = tmp[left]
left += 1
else:
lst[k] = tmp[right]
right += 1
python 코드에서는 merge 함수를 다르게 작성해 봤습니다.
특징
병합 정렬은 모든 경우에서 $\Theta(n\log n)$의 시간 복잡도를 갖는 빠르고 안정적인 정렬 알고리즘입니다. 위에서 구현한 것처럼 배열을 사용할 경우 입력에 크기와 같은 크기의 임시 배열을 생성해야 하고 대입연산이 많아 다른 $\Omicron(n\log n)$ 정렬 알고리즘에 비해 오버헤드가 큽니다. 배열 대신 연결 리스트(Linked List)를 사용하면 데이터이동이 효율적이고 `제자리 정렬(in-place sorting)`로 구현이 가능합니다.
'Computer Science > Algorithm' 카테고리의 다른 글
요세푸스 문제(Josephus problem) (0) | 2023.09.25 |
---|---|
[Algorithm] 퀵 정렬(Quick Sort) - C++, Java, Python 구현 (0) | 2023.08.25 |
[Algorithm] 버블 정렬(Bubble Sort) - C++, Java, Python 구현 (0) | 2023.08.22 |
[Algorithm] 삽입 정렬(Insertion Sort) - C++, Java, Python 구현 (0) | 2023.08.22 |
[Algorithm] 선택 정렬(Selection Sort) - C++, Java, Python 구현 (0) | 2023.08.21 |