[Union Find 응용] 무인도 여행하기

임현규·2023년 9월 7일
0

알고리즘

목록 보기
3/4

문제 접근하기

https://school.programmers.co.kr/learn/courses/30/lessons/154540?language=java

2차원 배열이 존재하고 섬을 숫자들이 상하좌우로 인접한 경우에 인접한 숫자들의 모음을 섬으로 정의한다. 이 때 각 섬들의 value의 합들을 배열로 정렬해서 리턴해야 한다.

이 문제는 DFS, BFS를 이용해서 인접한 value들을 합한다. 그 과정을 반복해서 수행하면 섬들의 value 합을 구할 수 있다.

그러나 이 문제는 유니온 파인드 문제로도 풀 수 있다. UnionFind는 조상을 찾아서 묶어주는 자료구조인데, 이를 활용하면 merge를 통해서 공통 node를 묶을 수 있다.

그리고 공통 노드를 조회하면서 각 값들의 합을 구하면 이 문제를 풀 수 있다.

UnionFind

# python
class UnionFind:

    def __init__(self, n):
        self.parents = [i for i in range(n)]

    def find(self, x):
        if self.parents[x] != x:
            self.parents[x] = self.find(self.parents[x])
        return self.parents[x]

    def union(self, x, y):
        parent_x = self.find(x)
        parent_y = self.find(y)
        self.parents[parent_y] = parent_x

매우 심플한 UnionFind이다. 경로 압축 정도만 해도 문제를 푸는데 충분하다. 만약 더 최적화를 하고 싶다면 rank에 따라 union시 조상을 선택해서 깊이를 줄일 수 있다.

문제 풀기

알고리즘은 간단하다.

  1. UnionFind를 선언한다.
  2. 2차원 배열을 순회하면서 우측, 아래 값들이 같으면 merge한다.
  3. 모든 배열을 순회하며 UnionFind의 같은 부모의 value들의 합을 Map에 저장한다.
  4. 정렬된 배열 형태로 리턴한다.
# 풀이
import sys
from collections import defaultdict

sys.setrecursionlimit(10 ** 7)


class UnionFind:

    def __init__(self, n):
        self.parents = [i for i in range(n)]

    def find(self, x):
        if self.parents[x] != x:
            self.parents[x] = self.find(self.parents[x])
        return self.parents[x]

    def union(self, x, y):
        parent_x = self.find(x)
        parent_y = self.find(y)
        self.parents[parent_y] = parent_x

    def is_connect(self, x, y):
        return self.find(x) == self.find(y)


def solution(maps):
    rows = len(maps)
    cols = len(maps[0])

    def is_in_map(x, y):
        return 0 <= x < rows and 0 <= y < cols

    def is_number(x, y):
        return maps[x][y] != 'X'

    uf = UnionFind(rows * cols)
    for i in range(rows):
        for j in range(cols):
            node = i * cols + j
            right_node = i * cols + j + 1
            down_node = (i + 1) * cols + j
            if is_in_map(i, j + 1) and is_number(i, j + 1):
                uf.union(node, right_node)
            if is_in_map(i + 1, j) and is_number(i + 1, j):
                uf.union(node, down_node)

    islands = defaultdict(int)
    for i in range(rows):
        for j in range(cols):
            if is_number(i, j):
                islands[uf.find(i * cols + j)] += int(maps[i][j])
    return sorted(islands.values()) if islands else [-1]
// java
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;

class Solution {

    public int[] solution(String[] maps) {
        UnionFind uf = merge(maps);
        Map<Integer, Integer> islands = new HashMap<>();
        for (int i = 0; i < maps.length; ++i) {
            for (int j = 0; j < maps[0].length(); ++j) {
                if (maps[i].charAt(j) == 'X') {
                    continue;
                }
                int node = i * maps[0].length() + j;
                int root = uf.find(node);
                int value = maps[i].charAt(j) - '0';
                islands.put(root, islands.getOrDefault(root, 0) + value);
            }
        }
        int[] result =  islands.values().stream().mapToInt(Integer::intValue).sorted().toArray();
        return result.length == 0 ? new int[]{-1} : result;
    }

    private UnionFind merge(String[] maps) {
        int rows = maps.length;
        int cols = maps[0].length();
        UnionFind uf = new UnionFind(rows * cols);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                if (maps[i].charAt(j) == 'X') {
                    continue;
                }
                int node = i * cols + j;
                int rightNode = node + 1;
                int downNode = node + cols;
                if (j + 1 < cols && maps[i].charAt(j + 1) != 'X') {
                    uf.merge(node, rightNode);
                }
                if (i + 1 < rows && maps[i + 1].charAt(j) != 'X') {
                    uf.merge(node, downNode);
                }
            }
        }
        return uf;
    }

    static class UnionFind {

        private final int[] parent;

        public UnionFind(int size) {
            this.parent = IntStream.range(0, size).toArray();
        }

        public int find(int x) {
            if (parent[x] != x) {
                parent[x] = find(parent[x]);
            }
            return parent[x];
        }

        public void merge(int x, int y) {
            x = find(x);
            y = find(y);
            if (x != y) {
                parent[y] = x;
            }
        }
    }
}
profile
엘 프사이 콩그루

0개의 댓글