트리에서 리프 노드란, 자식의 개수가 0인 노드를 말한다.
트리가 주어졌을 때, 노드 하나를 지울 것이다. 그 때, 남은 트리에서 리프 노드의 개수를 구하는 프로그램을 작성하시오. 노드를 지우면 그 노드와 노드의 모든 자손이 트리에서 제거된다.
예를 들어, 다음과 같은 트리가 있다고 하자.
현재 리프 노드의 개수는 3개이다. (초록색 색칠된 노드) 이때, 1번을 지우면, 다음과 같이 변한다. 검정색으로 색칠된 노드가 트리에서 제거된 노드이다.
이제 리프 노드의 개수는 1개이다.
첫째 줄에 트리의 노드의 개수 N이 주어진다. N은 50보다 작거나 같은 자연수이다. 둘째 줄에는 0번 노드부터 N-1번 노드까지, 각 노드의 부모가 주어진다. 만약 부모가 없다면 (루트) -1이 주어진다. 셋째 줄에는 지울 노드의 번호가 주어진다.
첫째 줄에 입력으로 주어진 트리에서 입력으로 주어진 노드를 지웠을 때, 리프 노드의 개수를 출력한다.
5
-1 0 0 1 1
2
2
5
-1 0 0 1 1
1
2
5
-1 0 0 1 1
0
0
9
-1 0 0 2 2 4 4 6 6
4
2
문제에서 요구하는 트리구조를 표현하기 위해 각 노드들을 배열에 저장해 트리구조를 생성하기로 정했다.이때, 예제에서 봤다시피 각 노드의 부모를 입력값으로 받고 있다.
n = int(input()) #노드의 개수
arr = list(map(int, input().split())) #각 노드의 부모 배열
k = int(input()) #삭제할 노드
각 노드의 부모 배열을 DFS로 탐색하며, 지워야하는 노드와 해당 노드가 부모 노드인 자식 노드들의 값들을 모두 특정 값으로 변환해야 한다.왜? 특정값으로 변환한 노드들을 제외하고 나머지 노드들이 곧 리프노드이기 떄문이다. 즉, 이러한 방식으로 리프노드를 구별할 수 있다.
난 여기서 특정값을 -2로 정했다.
이후, 각 노드의 부모 배열을 다시 한 번 탐색하며, 값이 -2 가 아니며, 해당 노드를 부모로 하는 노드가 부모 배열에 없을 경우, 리프 노드의 개수를 +1씩 한다.
이러한 방식으로 리프노드의 개수를 뽑아낼 수 있다.
일반적으로 모든 정점을 방문하는 것이 중요한 문제는 DFS 또는 BFS를 사용한다. 이중 검색 대상 그래프가 크겨나 경로의 특징을 저장해둬야 하는 문제는 DFS가 적합하다.
왜일까??
1. 스택 사용: DFS는 스택을 사용하여 깊이 방향으로 노드를 탐색합니다. 따라서 현재 경로를 스택에 저장하면서, 현재까지 탐색한 경로를 기억하게 된다.
2. 재귀 구조: DFS를 재귀적으로 구현할 수 있으며, 재귀 호출은 현재 경로를 계속해서 업데이트하면서 경로의 특징을 저장할 수 있다.
3. Backtracking: DFS는 Backtracking을 사용하여 목표 노드에 도달할 때까지 탐색하며, 경로의 특징을 저장하고 필요한 경우 되돌아가면서 다른 경로를 탐색합니다
아래는 DFS를 활용해 구현한 문제에 대한 코드이다.
def dfs(num, arr):
arr[num] = -2
for i in range(len(arr)):
if num == arr[i]:
dfs(i, arr)
def(k, arr)
print("모든 재귀호출이 끝난 후:", arr)
이런식으로 DFS를 통해 노드의 정보를 갱신하고 나면 아래와 같이 노드들이 구성되어있을 것이다.
ex)예제출력 4
>>> 모든 재귀호출이 끝난 후: [-1, 0, 0, 2, -2, -2, -2, -2, -2]
여기서 -2로 변한 노드는 모두 삭제될 노드이다.
또한 tree에는 부모의 정보가 들어있으므로 i라는 값이 tree안에 있으면 자식이 있는 노드이다. 위에 예시같은 경우에선 i=1때와 i=3일때가 리프노드이다.
이를 통해 다시 정리해보면 리프노드의 조건은 아래와 같다.
1번 : tree[i] 가 -2이면 지운노드이거나 지운노드의 자식의 노드이다.
2번 : tree에는 부모의 정보가 들어있으므로 i라는 값이 tree안에 있으면 자식이 있는 노드이다. 따라서 i not in tree 를 사용해 체크한다.
count = 0
for i in range(len(arr)):
if arr[i] != -2 and i not in arr:
count += 1
print(i)
import sys
input = sys.stdin.readline
n = int(input())
arr = list(map(int, input().split()))
k = int(input())
count = 0
def dfs(num, arr):
arr[num] = -2
for i in range(len(arr)):
if num == arr[i]:
dfs(i, arr)
dfs(k, arr)
for i in range(len(arr)):
if arr[i] != -2 and i not in arr:
count += 1
print(i)
print(count)