[BOJ 2213] - 트리의 독립집합 (DP, 트리, C++, Python)

보양쿠·2023년 9월 21일
0

BOJ

목록 보기
196/252

BOJ 2213 - 트리의 독립집합 링크
(2023.09.21 기준 G1)

문제

정점 n개로 이루어진 트리와 각 정점마다 가중치가 주어진다. 인접한 정점 쌍이 없게끔 정점들을 선택했을 때의 최대 가중치의 합 출력

알고리즘

include 상태에 따른 트리에서의 DP

풀이

하나의 정점이 선택되었다면, 그 정점과 인접한 정점들은 선택이 불가능하다.
반대로 선택되지 않는다면, 그 정점과 인접한 정점들은 선택이 가능하다. 물론 무조건 선택해야 하는 것은 아니다.

dp[vertex][exclude]로 정의하자. exclude가 0이면 vertex가 포함, 1이면 vertex가 포함되지 않는 상태다.
그리고 vertex의 부모를 제외한 인접한 정점들을 children, 그 중 하나를 child라고 하자. 그리고 가중치는 w라는 배열에 담겨져 있다고 하자.

  • 먼저, vertex가 포함되는 상태의 dp 값은
    dp[vertex][0] = w[vertex] + sum(dp[child][1] for child in children) 이다.
    위에서 서술했듯이, vertex가 포함되면 인접한 정점들은 선택이 불가능하다. 그러므로 각 child가 선택이 되지 않았을 때의 dp 값의 합 + vertex의 가중치가 곧, vertex가 선택되었을 때의 dp 값이 된다.
  • vertex가 포함되지 않는 상태의 dp 값은?
    dp[vertex][1] = sum(max(dp[child]) for child in children) 이다.
    vertex가 포함되지 않았다고 무조건 child를 선택해야 하는 것이 아니다. child를 선택하지 않았을 때의 dp 값이 더 높을 수도 있기 때문.
    그러므로 dp[child]의 더 높은 값으로 하여금, child의 dp 값의 합이 vertex가 선택되지 않았을 때의 dp 값이 된다.

위 두 상태에 따라 dp 테이블을 채워 나가면 된다.


그리고 이제, 실제로 어떤 정점이 선택되었는지 알기 위해서 역추적을 해야 한다.
dp 테이블을 채워 나간 순서대로 살펴보면 되는데, 만약 vertex가 포함된 상태의 dp 값이 더 높다? 그러면 vertex가 선택되고, children은 절대 선택될 수 없다.
반대로, vertex가 포함되지 않은 상태가 dp 값이 더 높다? 그러면 vertex는 선택되지 않고, children은 선택이 가능해진다.

선택이 가능한지 판별하기 위한 possible 변수를 하나 더 추가해 탐색하자.

function dfs(int vertex, int parent, bool possible):
	if (possible is true) -> 선택 가능
    {
    	if (dp[vertex][0] > dp[vertex][1]) -> 포함된 상태의 dp 값이 더 높다.
        	include vertex
            nxt = false -> vertex를 선택하고 다음 possible의 변수는 false가 된다.
        else -> 포함되지 않은 상태의 dp 값이 더 높다.
        	exclude vertex
        	nxt = true -> vertex를 선택하지 않고 다음 possible의 변수는 true가 된다.
    }
    else -> 선택 불가능
    {
        exclude vertex
        nxt = true -> vertex를 선택하지 않고 다음 possible의 변수는 true가 된다.
    }

	for (child: children)
    	dfs(child, vertex, nxt)

코드

  • C++
#include <bits/stdc++.h>
using namespace std;

const int MAXN = 1e4;

int w[MAXN], dp[MAXN][2]; // 포함 o, 포함 x
vector<int> graph[MAXN], result;

void dfs1(int u, int p){
    dp[u][0] = w[u]; // u번이 포함되는 경우는 u번의 가중치도 포함된다.
    dp[u][1] = 0;
    for (int v: graph[u]){
        if (v == p) continue;
        dfs1(v, u);
        dp[u][0] += dp[v][1]; // u번이 포함되면 v번은 무조건 포함되지 않아야 한다.
        dp[u][1] += max(dp[v][0], dp[v][1]); // u번이 포함되지 않으면 v번은 어느 쪽을 선택해도 상관이 없다.
    }
}

void dfs2(int u, int p, bool possible){ // possible : 포함이 가능한 지 체크
    bool nxt;
    if (possible){ // 포함이 가능하다면 dp 값이 높은 쪽을 선택
        if (dp[u][0] > dp[u][1]){
            result.push_back(u + 1);
            nxt = false; // u번을 선택하게 되면 다음 정점들은 선택하지 못한다.
        }
        else nxt = true; // u번을 선택하지 않으면 다음 정점들은 선택이 가능해진다.
    }
    else // 포함이 불가능하다면 포함되지 않으므로
        nxt = true; // 다음 정점들은 선택이 가능해진다.

    for (int v: graph[u]){
        if (v == p) continue;
        dfs2(v, u, nxt);
    }
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int n; cin >> n;
    for (int i = 0; i < n; i++) cin >> w[i];
    for (int i = 1, u, v; i < n; i++){
        cin >> u >> v;
        graph[--u].push_back(--v);
        graph[v].push_back(u);
    }

    dfs1(0, -1);

    // 루트로 잡은 0번의 결과 중 높은 값 출력
    cout << max(dp[0][0], dp[0][1]) << '\n';

    // dp 역추적
    dfs2(0, -1, true);
    sort(result.begin(), result.end());
    for (int u: result) cout << u << ' ';
}
  • Python
import sys; input = sys.stdin.readline

def dfs1(u, p):
    dp[u][0] = w[u] # u번이 포함되는 경우는 u번의 가중치도 포함된다.
    for v in graph[u]:
        if v == p:
            continue
        dfs1(v, u)
        dp[u][0] += dp[v][1] # u번이 포함되면 v번은 무조건 포함되지 않아야 한다.
        dp[u][1] += max(dp[v]) # u번이 포함되지 않으면 v번은 어느 쪽을 선택해도 상관이 없다.

def dfs2(u, p, possible): # possible : 포함이 가능한 지 체크
    if possible: # 포함이 가능하다면 dp 값이 높은 쪽을 선택
        if dp[u][0] > dp[u][1]:
            result.append(u + 1)
            nxt = False # u번을 선택하게 되면 다음 정점들은 선택하지 못한다.
        else:
            nxt = True # u번을 선택하지 않으면 다음 정점들은 선택이 가능해진다.
    else: # 포함이 불가능하다면 포함되지 않으므로
        nxt = True # 다음 정점들은 선택이 가능해진다.

    for v in graph[u]:
        if v == p:
            continue
        dfs2(v, u, nxt)

n = int(input())
w = list(map(int, input().split()))
graph = [[] for _ in range(n)]
for _ in range(n - 1):
    u, v = map(int, input().split())
    u -= 1; v -= 1
    graph[u].append(v)
    graph[v].append(u)

dp = [[0] * 2 for _ in range(n)] # 포함 o, 포함 x
dfs1(0, -1)

# 루트로 잡은 0번의 결과 중 높은 값 출력
print(max(dp[0]))

# dp 역추적
result = []
dfs2(0, -1, True)
print(*sorted(result))
profile
GNU 16 statistics & computer science

0개의 댓글