Geeksforgeeks에서는 꼬리재귀함수를 다음과 같이 정의한다:

Tail recursion is defined as a recursive function in which the recursive call is the last statement that is executed by the function. So basically nothing is left to execute after the recursion call.

즉, 함수가 실행되고 나서 반환되는 값에 연산이 되지 않으면 꼬리재귀함수라고 볼 수 있다. 예를 들어, 팩토리얼 함수를 아래와 같이 구현할 수 있다.

def factorial(n: int) -> int:
	if n == 1:
    	return 1
	return n * factorial(n-1)

위 함수는 반환값에 factorial 함수를 호출하고 나서 n을 곱해야 하므로 꼬리함수가 아니다. 반면 아래 함수는 factorial 함수를 호출하고 나서 바로 반환하므로 꼬리함수라고 볼 수 있다.

def factorial(n: int, total: int = 1) -> int:
	if n == 1:
    	return total
    return factorial(n-1, n*total)

그런데 컴파일 언어에서는 컴파일 시 꼬리재귀함수를 반복문으로 바꾸는 성능최적화가 적용된다고 한다. 하지만 파이썬은 인터프리터 언어라 그런 호사는 누릴 수 없다. 직접 바꿔보도록 하자.

꼬리재귀함수 반복문으로 바꾸기

이번에 백준 16928. 뱀과 사다리 게임 문제를 풀면서 꼬리재귀함수를 사용할 기회가 있었다. 간략하게 설명하면, 주사위를 던져 나온 칸만큼 진행해 1부터 100까지 가야되는 상황에서, 사다리나 뱀을 타고 다른 곳으로 이동하는 게임이다. 사다리나 뱀이 연결된 칸에서는 방문했던 하지 않았던 무조건 이동해야된다. 이를 그래프로 구현하기 위해 아래와 같이 get_destination 함수를 작성했다.

_, _, *ls = map(int, open(0).read().split())

snl = [0] * 101
for start, dest in zip(ls[::2], ls[1::2]):
    snl[start] = dest

def get_destination(start: int) -> int:
    if snl[start]:
    	return get_destination(snl[start])
    return start

graph = [[(get_destination(j), 1) for j in range(i+1, i+7) if j < 101] for i in range(101)]

먼저. snl 리스트에 뱀과 사다리 정보를 넣는다. i번 칸에 사다리나 뱀이 가리키는 도착지 정보를 넣는 식이다. 그 정보들을 사용해 get_destination 함수는 도착지를 추적한다. 예를 들어 아래와 같은 정보가 있다고 하자:

1 3
3 5
5 4

그렇다면 1번이나 3번 칸의 도착지는 4번이어야하고, 5번 칸의 도착지는 그대로 4번이다. 1, 3, 5번 인덱스에 각각 3, 5, 4를 추가하고 위 함수를 호출하면, 1, 3, 5번칸에서 4번을 리턴한다.
반환 되는 값에 추가적인 연산이 되지 않으므로 꼬리함수라고 볼 수 있는 것 같다. while문을 사용해 이렇게 바꿀 수 있다.

def get_destination(start: int) -> int:
	result = start
    while (r := snl[result]):
    	result = r
    return result

과연 성능이 올라갔을까?

다익스트라 알고리즘

백준 기준 52ms에서 48ms이지만 약간(약 8%) 개선되었음을 확인할 수 있다. 100칸짜리 문제라 큰 차이는 아니지만 개선되긴 했다.

플로이드 워셜 알고리즘

플로이드 워셜은 어떨까? 플로이드 워셜 또한 496ms에서 468ms로 성능이 개선되었다(약 6%).

나가면서

아래는 사용한 코드이다. 개인적으로 재귀함수의 매력에 빠져있는 터라 재귀함수를 사용한 코드를 올리고 글을 마친다.

# 데이크스트라
from heapq import *
_, _, *ls = map(int, open(0).read().split())

snl = [0] * 101
for start, dest in zip(ls[::2], ls[1::2]):
    snl[start] = dest

def get_destination(start: int) -> int:
    if snl[start]:
    	return get_destination(snl[start])
    return start
    

graph = [[(get_destination(j), 1) for j in range(i+1, i+7) if j < 101] for i in range(101)]

def dijkstra(start: int, end: int) -> int:
    heap = [(0, start)]
    visited = set()
    while heap:
        cost, current = heappop(heap)  # 시작점에서 현재노드까지 가는 비용, 현재노드
        if current in visited:
            continue
        visited.add(current)
        if current == end:   # 현재 노드가 도착점이라면 비용 반환
            return cost
        for v, c in graph[current]:
            if v in visited:
                continue
            cost_to_proceed = cost + c   # 시작점에서 현재 노드까지 오는 비용 + 현재 노드에서 다음 노드로 가는 비용
            heappush(heap, (cost_to_proceed, v))

print(dijkstra(1, 100))


# 플로이드 워셜
from math import inf
_, _, *ls = map(int, open(0).read().split())

snl = [0] * 101
for start, dest in zip(ls[::2], ls[1::2]):
    snl[start] = dest

def get_destination(start: int) -> int:
    if snl[start]:
    	return get_destination(snl[start])
    return start

board = [[inf] * 101 for _ in range(101)]
for i in range(101):
    board[i][i] = 0
    for j in range(i+1, i+7):
        if j < 101:
            board[i][get_destination(j)] = 1

for k in range(1, 101):
    for a in range(1, 101):
        for b in range(1, 101):
            board[a][b] = min(board[a][b], board[a][k] + board[k][b])

print(board[1][100])

수정사항

2 2
3 99
10 44
99 10
44 19

위와 같은 예를 고려하지 않아서 틀렸다고 생각해 get_destination 함수를 구현한거였는데 다시보니 아니었다. 아래와 같은 코드로도 '맞았습니다!!!'를 받을 수 있었다.

# 데이크스트라
from heapq import *
_, _, *ls = map(int, open(0).read().split())

snl = [0] * 101
for i, j in zip(ls[::2], ls[1::2]):
    snl[i] = j

graph = [[(snl[j] if snl[j] else j, 1) for j in range(i+1, i+7) if j < 101] for i in range(101)]

def dijkstra(start: int, end: int) -> int:
    heap = [(0, start)]
    visited = set()
    while heap:
        cost, current = heappop(heap)  # 시작점에서 현재노드까지 가는 비용, 현재노드
        if current in visited:
            continue
        visited.add(current)
        if current == end:   # 현재 노드가 도착점이라면 비용 반환
            return cost
        for v, c in graph[current]:
            if v in visited:
                continue
            cost_to_proceed = cost + c   # 시작점에서 현재 노드까지 오는 비용 + 현재 노드에서 다음 노드로 가는 비용
            heappush(heap, (cost_to_proceed, v))

print(dijkstra(1, 100))
profile
이토록 멋진 휴식!

0개의 댓글