BOJ 1517 버블 소트

LONGNEW·2021년 1월 21일
0

BOJ

목록 보기
77/333

https://www.acmicpc.net/problem/1517
시간 1초, 메모리 512MB
input :

  • N(1≤N≤500,000)
  • Ai

output :

  • Swap 횟수를 출력

조건 :

  • 버블 소트는 서로 인접해 있는 두 수를 바꿔가며 정렬하는 방법

어제 쉬고 해서 그런지 왜 이리 오래 걸렸는지 모르겠다. ㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋ 일단 버블 소트를 이용하면 당연히 시간 초과가 발생한다.

그래서 대부분의 경우 병합정렬이나, 세그 트리를 이용한다고 한다.

병합정렬의 경우. 입력을 받은 리스트 둘 중 하나를 기준으로 잡는다.
병합 정렬이 수행 될 때, 이제 merge 할 리스트들은 정렬이 되어서 올라온다.
그럴 경우 이미 왼쪽에 있던 애들이 오른쪽에 있는 애들보다 숫자가 커서 swap해 주는 것을 기록해야 한다.
그러면 왼쪽에 있는 애들이 들어가기 전에 이미 정렬이 된 개수를 세아려서 swap에 넣어주자.
그리고 이 cnt(오른쪽 애들이 정렬된 개수)는 초기화를 시키지 않는다. 모든 왼쪽 리스트에 대하여 적용이 되어야 하기 때문에 누적이 되어야 한다.

그리고 가장 고민 했던 것은 크기가 같은 숫자이면 정렬을 어떻게 하는가? 였는데
그냥.. 같은 숫자이면 스왑을 안 하기 때문에 왼쪽에 존재하는 숫자를 new_arr에 집어넣고 다시 merge를 진행하면 된다. 즉 따로 생각할 필요가 없다 ......

import sys

input = sys.stdin.readline
sys.setrecursionlimit(10 ** 9)

def merge(start, end):
    # merge
    global swap
    new_arr = []
    mid = (start + end) // 2
    l_idx, r_idx = start, mid
    cnt = 0
    while l_idx < mid and r_idx < end:
        if arr[l_idx] > arr[r_idx]:
            new_arr.append(arr[r_idx])
            r_idx += 1
            cnt += 1
        else:  # arr[idx1] < arr[idx2]
            new_arr.append(arr[l_idx])
            l_idx += 1
            swap += cnt

    while l_idx < mid:
        new_arr.append(arr[l_idx])
        l_idx += 1
        swap += cnt
    while r_idx < end:
        new_arr.append(arr[r_idx])
        r_idx += 1

    # reflect
    for t in range(len(new_arr)):
        arr[start + t] = new_arr[t]

def merge_sort(start, end):
    global swap, arr
    size = end - start
    mid = (start + end) // 2
    if size <= 1:
        return

    # divide
    merge_sort(start, mid)
    merge_sort(mid, end)
    merge(start, end)

n = int(input())
arr = list(map(int, input().split()))
swap = 0
merge_sort(0, n)
print(swap)

배열의 인덱스를 이용했기 때문에 배열은 업데이트를 해줘야 정렬이 된 모양을 가지게 된다.

import sys
sys.setrecursionlimit(10 ** 9)

def merge_sort(start, end):
    if start + 1 >= end:
        return

    mid = (end + start) // 2
    merge_sort(start, mid)
    merge_sort(mid, end)

    merge(start, end)

def merge(left, right):
    global cnt
    mid = (left + right) // 2
    i, j, ret, small_cnt = left, mid, [], 0

    while i < mid and j < right:
        if data[i] > data[j]:
            ret.append(data[j])
            j += 1
            small_cnt += 1
        else:
            ret.append(data[i])
            i += 1
            cnt += small_cnt

    if i == mid:
        ret += data[j:right]
    else:
        ret += data[i:mid]
        cnt += (mid - i) * small_cnt

    for i in range(len(ret)):
        data[left + i] = ret[i]

n = int(sys.stdin.readline())
data = list(map(int, sys.stdin.readline().split()))
cnt = 0

merge_sort(0, n)
print(cnt)

0개의 댓글