SWEA_분할정복, 백트래킹

김병훈·2021년 4월 20일
0

5204_병합정렬

내 코드

알게된 점

  • 아주 느린 pop(0)

    merge함수에서 pop(0)을 이용하여, left와 right를 앞에서부터 하나씩 빼주었었다.
    이 방식이 생각보다 시간이 굉장히 많이 소요되는 것 같다.
    pop(0) 대신에 각각의 배열에 인덱스를 별도로 선언하여, pop(0)이 일어나는 위치에 index를 증가시켰다.
    index를 다루는 것이 실수가 나올 수 있는 여지가 있다고 생각하지만, 시간이 더 짧게 걸린다는 것을 알 수 있었다.

  • 병합정렬

    1. 나눌 수 있을 때 까지 2개의 배열로 나눈다. (merge_sort)
    2. 나눈 배열을 합친다. (merge)

Pass

def merge(left, right):
    global answer
    result = []
    left_idx = right_idx = 0
    left_length = len(left)
    right_length = len(right)
    # 하나씩 빼서 result에 넣어주는 방식
    while left_idx < left_length or right_idx < right_length:
        if left_idx < left_length and right_idx < right_length:
            # 맨 앞에 있는 값부터 비교한다.
            if left[left_idx] <= right[right_idx]:
                result.append(left[left_idx])
                left_idx += 1
            else:
                result.append(right[right_idx])
                right_idx += 1
        elif right_idx == right_length:
            result.append(left[left_idx])
            left_idx += 1
        elif left_length == left_length:
            result.append(right[right_idx])
            right_idx += 1
    if left[-1] > right[-1]:
        answer += 1
    return result

def merge_sort(arr, arr_length):
    # 리스트를 분할하는 과정
    # 1. 전체를 두 부분으로 나눈 후, 다시 나누는 과정을 진행한다. (merge_sort)
    # 2. 길이가 1이 될 때까지 나누었다면, 병합한다. (merge)

    # 길이가 1이라면, 더이상 나눌 것이 없으므로 리턴한다.
    if arr_length == 1:
        return arr

    mid = arr_length // 2
    left = arr[:mid]
    right = arr[mid:]

    left = merge_sort(left, len(left))
    right = merge_sort(right, len(right))
    return merge(left, right)

T = int(input())

for tc in range(1, T + 1):
    N = int(input())
    NUMS = list(map(int, input().strip().split()))
    answer = 0
    sorted_arr = merge_sort(NUMS, N)
    print(f"#{tc}", sorted_arr[N // 2], answer)

시간초과

pop(0)은 생각보다 많은 시간이 소요된다.

def merge(left, right):
    global answer
    result = []
    # 하나씩 빼서 result에 넣어주는 방식
    while len(left) > 0 or len(right) > 0:
        left_length = len(left)
        right_length = len(right)
        if left_length > 0 and right_length > 0:
            if left[0] <= right[0]:
                result.append(left.pop(0))
            else:
                result.append(right.pop(0))
        elif right_length == 0:
            if left_length == 1:
                answer += 1
            result.append(left.pop(0))
        elif left_length == 0:
            result.append(right.pop(0))
    return result

def merge_sort(arr, arr_length):
    # 리스트를 분할하는 과정
    if arr_length == 1:
        return arr

    left = []
    right = []
    mid = arr_length // 2
    for x in arr[:mid]:
        left.append(x)
    for x in arr[mid:]:
        right.append(x)

    left = merge_sort(left, len(left))
    right = merge_sort(right, len(right))
    return merge(left, right)

T = int(input())

for tc in range(1, T + 1):
    N = int(input())
    NUMS = list(map(int, input().split()))
    answer = 0
    sorted_arr = merge_sort(NUMS, N)
    # print(sorted_arr)
    print(f"#{tc}", sorted_arr[N // 2], answer)

5205_퀵 정렬

내 코드

알게된 점

  • 퀵 정렬

    1. 정렬할 배열과 왼쪽 / 오른쪽 인덱스를 이용해 파티션을 만든다. (hoare_partition)
    2. 만든 파티션을 기준으로 왼쪽 배열과 오른쪽 배열에 대해 다시 퀵 정렬을 수행한다. (quick_sort)
  • 파티션을 만드는 방법

    1. 가장 왼쪽 값을 기준 값으로 설정한다.

    2. 왼쪽 인덱스는 오른쪽 방향으로 "기준 값 보다 큰 값을 찾는" 탐색을 진행한다.

    3. 오른쪽 인덱스는 왼쪽 방향으로 "기준 값 보다 작은 값을 찾는" 탐색을 진행한다.

    4. 찾은 위치가 반전되지 않았다면, 왼쪽 인덱스와 오른쪽 인덱스의 값을 서로 교환한다.

      큰 값은 오른쪽으로 / 작은 값은 왼쪽으로 보내기 위해

    5. 탐색이 종료되었다면, 오른쪽 인덱스가 가장 마지막으로 찾은 값과 기준 값(가장 왼쪽 값)을 교환한다.

      오른쪽 인덱스가 찾은 값은 항상 기준 값 보다 작으며, 탐색을 진행하는 동안 기준 값보다 작은 값들은 현재 오른쪽 인덱스보다 왼쪽에 위치함 (4번 과정에 의해)

    6. 기준 값의 위치를 반환한다. 이 인덱스가 파티션이 된다.

Pass

# 퀵 정렬
# 주어진 배열을 두 개로 분할하고, 각각을 정렬한다.

def hoare_partition(arr, left_idx, right_idx):
    # 가장 왼쪽 값을 기준 값으로 설정
    pivot_value = arr[left_idx]
    # left와 right가 반전될 때 까지 교환을 위한 탐색을 진행한다.
    l_idx = left_idx
    r_idx = right_idx

    while l_idx < r_idx:
        # l_idx는 pivot_value보다 큰 값을 찾는다. (작거나 같은 값이면, 다음 칸으로 넘어간다.)
        # 인덱스를 벗어나지 않는다.
        while left_idx <= l_idx < right_idx and arr[l_idx] <= pivot_value:
            l_idx += 1
        # r_idx는 pivot_value보다 작은 값을 찾는다. (작은 값을 찾을 때까지 멈추지 않는다.)
        while left_idx < r_idx <= right_idx and arr[r_idx] >= pivot_value:
            r_idx -= 1
        # 두 탐색을 거친 후, 각각 값을 찾았다면 교환을 진행한다.
        if l_idx < r_idx:
            arr[l_idx], arr[r_idx] = arr[r_idx], arr[l_idx]
    # 모든 탐색이 끝난 뒤에 기준 값을 마지막에 찾은 자신보다 작은 값과 교환한다. (구분을 위해)
    arr[r_idx], arr[left_idx] = arr[left_idx], arr[r_idx]
    # pivot_value가 위치한 인덱스를 반환한다.
    return r_idx


def quick_sort(arr, left, right):
    # 주어진 arr, left, right를 이용해 기준 파티션을 만들고
    # 파티션을 기준으로 배열을 분할하여 각각을 정렬한다.
    # arr의 길이가 1인 경우, left와 right가 반전될 수 있다.
    if left < right:
        p = hoare_partition(arr, left, right)
        quick_sort(arr, left, p - 1)        # 파티션 기준 왼쪽 배열을 정렬
        quick_sort(arr, p + 1, right)       # 파티션 기준 오른쪽 배열을 정렬


T = int(input())

for tc in range(1, T + 1):
    # N: 정수의 개수
    N = int(input())

    NUMS = list(map(int, input().split()))
    quick_sort(NUMS, 0, N - 1)
    print(f"#{tc} {NUMS[N // 2]}")

5207_이진 탐색

내 코드

알게 된 점

  • 이진 탐색을 수행하기 위한 필수 조건
    탐색할 배열은 정렬되어 있어야 한다.

Pass!!!

A 배열이 정렬이 안되어 있을 줄은 몰랐다...

def binary_search(arr, left, right, find_num):
    global flag
    start = left
    end = right
    if left > right:
        return False
    mid = (left + right) // 2
    if find_num >= arr[mid]:
        if arr[mid] == find_num:
            return True
        if flag == "R":
            return False
        flag = "R"
        start = mid + 1
    elif arr[mid] > find_num:
        if flag == "L":
            return False
        flag = "L"
        end = mid - 1
    return binary_search(arr, start, end, find_num)


T = int(input())

for tc in range(1, T + 1):
    # N: A의 개수
    # M: B의 개수
    N, M = map(int, input().split())

    A = sorted(list(map(int, input().split())))
    B = list(map(int, input().split()))

    answer = 0

    for num in B:
        flag = 0
        if binary_search(A, 0, N - 1, num):
            answer += 1
    print(f"#{tc} {answer}")

오답!

def binary_search(arr, left, right, find_num):
    if left > right:
        return False
    mid = (left + right) // 2
    if arr[mid] == find_num:
        return True
    elif arr[mid] > find_num:
        if not step:
            step.append("L")
        else:
            if step[-1] == "L":
                return False
            else:
                step.append("L")
        return binary_search(arr, left, mid - 1, find_num)
    else:
        if not step:
            step.append("R")
        else:
            if step[-1] == "R":
                return False
            else:
                step.append("R")
        return binary_search(arr, mid + 1, right, find_num)


T = int(input())

for tc in range(1, T + 1):
    # N: A의 개수
    # M: B의 개수
    N, M = map(int, input().split())

    A = list(map(int, input().split()))
    B = list(map(int, input().split()))

    answer = 0

    for num in B:
        step = []
        # flag = 0
        if binary_search(A, 0, N - 1, num):
            answer += 1
    print(f"#{tc} {answer}")

오답 2

step 대신 flag 사용 (코드가 한층 간결해짐)

def binary_search(arr, left, right, find_num):
    global flag
    if left > right:
        return False
    mid = (left + right) // 2
    if arr[mid] == find_num:
        return True
    elif arr[mid] > find_num:
        if flag == "L":
            return False
        flag = "L"
        return binary_search(arr, left, mid - 1, find_num)
    else:
        if flag == "R":
            return False
        flag = "R"
        return binary_search(arr, mid + 1, right, find_num)


T = int(input())

for tc in range(1, T + 1):
    # N: A의 개수
    # M: B의 개수
    N, M = map(int, input().split())

    A = list(map(int, input().split()))
    B = list(map(int, input().split()))

    answer = 0

    for num in B:
        flag = 0
        if binary_search(A, 0, N - 1, num):
            answer += 1
    print(f"#{tc} {answer}")

5208_전기버스2

내 코드

알게 된 것

  • 이중 반복문은 생각보다 연산횟수가 많아지는 듯 하다.

Pass!!

  • 최대로 갈 수 있는 정류장을 별도의 인자로 관리하여, 충전하지 않아도 되는 경우를 판단할 때 사용하였다.
def solve(current, cnt, max_stop):
    global min_cnt
    # 충전 횟수가 최소 횟수만큼 쌓였다면, 탐색할 필요가 없음
    if cnt >= min_cnt:
        return
    
    # 정류장에 도달했다면, 최소 횟수만에 도착한 것
    if current == N - 1:
        min_cnt = cnt
        return
    
    # 이번 충전소에서 충전지 교체
    visited[current] = True
    solve(current + 1, cnt + 1, max(max_stop, current + M[current]))
    # 이번 충전소에서 충전지를 교체하지 않음
    # 교체하지 않아도 괜찮나?
    visited[current] = False
    # 다음 충전소를 갈 수 있을까?
    if max_stop >= current + 1:
        solve(current + 1, cnt, max_stop)


# 방전되지 않아도 충전지를 교체할 수 있다.
T = int(input())

for tc in range(1, T + 1):
    # M: 각 정류장 별 충전지 용량
    N, *M = list(map(int, input().split()))
    min_cnt = N

    # 충전지를 교체하면 갈 수 있는 배열이 변경된다.
        # 갈 수 있는 거리를 표시하는 배열을 만들어야겠다.
    current = 0
    visited = [False] * (N - 1)
    can_go = [False] * N
    # 첫 번째 정류장에서 충전지를 받은 것은 카운트하지 않는다.
    visited[0] = True
    # 두 번째 정류장부터 시작
    # 최대로 갈 수 있는 정류장은 0 + M[0]
    solve(1, 0, M[0])
    print(f"#{tc} {min_cnt}")

시간초과

  • can_go 부분 때문에 발생하는 건가?
def solve(current, cnt):
    global min_cnt
    # print(current, cnt)
    if cnt >= min_cnt:
        return
    
    if current == N - 1:
        min_cnt = cnt
        return
    
    # 이번 충전소에서 충전지 교체
    visited[current] = True
    solve(current + 1, cnt + 1)
    # 이번 충전소에서 충전지를 교체하지 않음
    # 교체하지 않아도 괜찮나?
    visited[current] = False
    # 다음 충전소를 갈 수 있을까?
    can_go = [False] * N
    for i in range(current - 1, -1, -1):
        if visited[i]:
            for j in range(i, min(i + M[i] + 1, N)):
                can_go[j] = True
        if can_go[current + 1]:
            break
    else:
        return
    solve(current + 1, cnt)


# 방전되지 않아도 충전지를 교체할 수 있다.
T = int(input())

for tc in range(1, T + 1):
    # M: 각 정류장 별 충전지 용량
    N, *M = list(map(int, input().split()))
    min_cnt = N

    # 충전지를 교체하면 갈 수 있는 배열이 변경된다.
        # 갈 수 있는 거리를 표시하는 배열을 만들어야겠다.
    current = 0
    visited = [False] * (N - 1)
    can_go = [False] * N
    # 첫 번째 정류장에서 충전지를 받은 것은 카운트하지 않는다.
    visited[0] = True
    # 두 번째 정류장부터 시작
    solve(1, 0)
    print(f"#{tc} {min_cnt}")

5209_최소 생산 비용

  • 2차원 배열을 이용해 순열을 생성하여, 최소 생산 비용을 찾아내는 문제

내 코드

Pass!

def solve(product_idx, cost):
    global min_cost
    if cost >= min_cost:
        return

    if product_idx == N:
        min_cost = cost

    # 이번 제품을 어떤 공장에서 생산할까요?
    for i in range(N):
        if visited[i]:
            continue
        # i 공장에서 생산하는 루트
        visited[i] = True
        solve(product_idx + 1, cost + data[product_idx][i])
        # i 공장에서 생산하지 않는 루트
        visited[i] = False

T = int(input())

for tc in range(1, T + 1):
    N = int(input())
    data = [list(map(int, input().split())) for _ in range(N)]

    # 열 단위로 순회하여, 어떤 행을 선택하는 지
    # index: 열(제품) 0 ~ N-1
    # value: 행(공장) 0 ~ N-1
    visited = [False] * N
    min_cost = 100 * N
    solve(0, 0)
    print(f"#{tc} {min_cost}")
profile
재밌는 걸 만드는 것을 좋아하는 메이커

1개의 댓글

comment-user-thumbnail
2021년 8월 9일

;

답글 달기