WEEK. 01 2022.04.04 TIL

이진호·2022년 4월 5일
1

TIL

목록 보기
1/11

퀵 정렬

def pivot_sort(a, left, right): # 배열, 정렬 대상의 양쪽 끝 값
    n = len(a)
    pl = left
    pr = right
    pivot = a[(pl+pr)//2]
    while pl <= pr: 
        while a[pl] < pivot: pl += 1
        while a[pr] > pivot: pr -= 1    
        if pl <= pr:
            a[pl], a[pr] = a[pr], a[pl]
            pl += 1
            pr -= 1
    # 종료조건, left 및 right가 pr, pl에 도달하면 정렬이 완료되었다고 판단함.
    if pr > left: pivot_sort(a, left, pr+1)
    if pl < right: pivot_sort(a, pl, right)

스택

스택은 데이터를 임시 저장할 때 사용하는 자료구조로, 데이터의 입력과 출력 순서는 후입선출(LIFO) 방식입니다.

LIFO(ast in first out)란 가장 나중에 넣은 데이터를 가장 먼저 꺼낸다는 의미.

스택 배열: stk

  • 푸시한 데이터를 저장하는 스택 본체인 list형 배열, 인덱스가 0인 원소를 스택의 바닥이라 함.

스택 크기: capacity

  • 스택의 최대 크기를 나타내는 int형 정수이고, 배열 stk의 원소 수인 len(stk)와 일치함.

스택 포인터: ptr

  • 스택에 쌓여있는 데이터의 개수를 나타내는 정숫값으로 비어있으면 0 가득 차 있으면 capacity와 같은 값.

스택을 이용한 비재귀적 퀵 정렬

정렬을 해야 할 원소 범위를 스택에 푸시하고, 스택이 비어있으면(정렬할 범위가 남아있지 않음) while문을 종료하도록 작성.

 range = Stack(right-left+1) # stack 생성
 range.push((left, right))
  while not range.is_empty(): # stack이 비어있으면 정렬 종료
      pl, pr = left, right = range.pop() # stack의 top에 위치한 정렬 범위를 의미하는 원소를 추출
      x = a[(left + right) // 2] # 피봇값

      # 퀵 정렬 알고리즘

      if left < pr: range.push((left, pr)) # pr이 left에 도달하면 정렬이 종료하므로 stack이 비어있게 되어 정렬 종료
      if pl < right: range.push((pl, right))

배열을 스택에 푸시할 때, 원소 수가 많은 쪽을 먼저 푸시(나중에 pop 되도록)하는 것이 스택의 크기를 줄이는 데 효과적임.

피벗 선택 시 한쪽으로 치우친 값을 선택하면 효율이 떨어질 수 있음. 이를 해결하기 위해 배열에서 임의의 원소 3개를 꺼내 중앙값인 원소를 피벗으로 선택하면 좋음.

퀵 정렬 - 피벗 선택 방법 수정

피벗을 효과적으로 선택하기 위해 배열의 첫번째, 중앙값, 마지막 값을 추출하여 정렬시킨 후 피벗을 선택하도록 코드를 수정함. 이 방법까지 백준 2751번 수 정렬하기 문제의 경우 시간초과로 풀지 못했음.

def selection_mid(a, b, c): # 세 값을 정렬하는 함수
    if b > c: b, c = c, b
    if a > b: a, b = b, a
    if b > c: b, c = c, b
    return a, b, c

def quick_sort(a, left, right):
    pv_index = (left + right) // 2
    a[left], a[pv_index], a[right] = selection_mid(a[left], a[pv_index], a[right])
    a[pv_index], a[right-1] = a[right-1], a[pv_index] # pivot에 위치하는 값을 right-1 위치의 원소와 교환
    pl = left + 1 # 첫번째 원소는 이미 정렬을 완료했으므로 pl 범위가 수정됨.
    pr = right - 2 # right 및 right-1 원소 또한 정렬을 완료했으므로 pr값이 수정됨.
    pivot = a[right-1]
    while pl <= pr:
        while a[pl] < pivot: pl += 1
        while a[pr] > pivot: pr -= 1
        if pl <= pr:
            a[pl], a[pr] = a[pr], a[pl]
            pl += 1
            pr -= 1
    if pr > left: quick_sort(a, left, pr+1)
    if pl < right: quick_sort(a, pl, right)

병합 정렬

배열을 특정 기준까지 나누고, 나눈 배열들을 정렬한 후 병합하는 방법. 잘 이해가 안가서 다음과 같이 적어봄.

아래는 생각한 것을 토대로 의식의 흐름에 따라 한번 작성해본 코드인데 정체를 알 수 없는 에러가 엄청 많이 발생했다.

def merge_sort(a, left, right):
    center = (left + right) // 2
    if right > left: # 배열의 끝 인덱스가 첫번째 인덱스와 같아지는 즉, 배열의 크기가 1이 될때까지 실행
        merge_sort(a, left, center) # 배열 분류
        merge_sort(a, center, right)
        
    # 병합 부분
    a1 = a[left:center] # left부터 center 전까지
    a2 = a[center:right+1] # center부터 right까지
    n1 = len(a1) # a1 배열 크기
    n2 = len(a2) # a2 배열 크기
    temp_3 = [] # a1과 a2를 정렬하기 위한 임시 배열
    i = 0 # a1 인덱스
    j = 0 # a2 인덱스
    while i < n1 and j < n2: # a1 혹은 a2의 인덱스가 배열의 끝에 도달하면 종료
        if a1[i] >= a2[j]:
            temp_3 += a1[i]
            i += 1
        elif a1[i] < a2[j]:
            temp_3 += a2[j]
            j += 1
    if i < n1: # 종료 시점에서의 인덱스가 배열의 크기보다 작을 경우 남은 원소들을 temp 배열에 추가
        temp_3 += a1[i:n1]
    if j < n2:
        temp_3 += a2[j:n2]
    a[left:right+1] = temp_3 # 타겟 배열 범위를 temp로 수정
    return

아래는 수정한 코드.
우선 merge_sort를 탈출하는 조건이 원소가 하나가 됐을 경우인데 이때는 병합을 할 수 없어서 오류가 발생했다. 이를 수정해줌.
또한, 현재 내가 적용한 방법은 예를 들어 원소 개수가 두개일 때 두개를 a1, a2로 나누어 작은 수 부터 temp라는 임시 배열에 넣고, 이를 본래 배열인 a의 해당 위치로 넣어주는 식이다 하지만 center 값을 left+right의 몫으로 설정했기 때문에 left와 center 값이 동일해진다.
이 경우, a1[left:center]에서 left와 center가 동일해져 해당 배열이 비어버려 오류가 발생함. 이를 수정해줌.

import sys
sys.setrecursionlimit(10**5)
n = int(input())
a = [int(sys.stdin.readline()) for i in range(n)] 

def merge_sort(a, left, right):
    center = (left + right) // 2
    if right > left:
        merge_sort(a, left, center)
        merge_sort(a, center+1, right)
        a1 = a[left:center+1]
        a2 = a[center+1:right+1]
        n1 = len(a1)
        n2 = len(a2)
        temp_3 = []
        i = 0
        j = 0
        while i < n1 and j < n2:
            if a1[i] >= a2[j]:
                temp_3.append(a2[j])
                j += 1
            elif a1[i] < a2[j]:
                temp_3.append(a1[i])
                i += 1
        if i < n1:
            temp_3 += a1[i:n1]
        if j < n2:
            temp_3 += a2[j:n2]
        a[left:right+1] = temp_3
        
merge_sort(a, 0, len(a)-1)
for x in a:
    print(x)

0개의 댓글