https://www.acmicpc.net/problem/19593
"""
Title : 다도해
Link : https://www.acmicpc.net/problem/19593
"""
import sys
# Rebind input() to sys.stdin.readline for faster reads. The trailing newline
# it keeps is harmless here because every read below is parsed with int() or
# str.split().
input = sys.stdin.readline
def find_parent(a, parents):
    """Return the root of ``a``'s set, compressing the path along the way.

    Written iteratively so a long parent chain cannot exceed Python's
    recursion limit (1000 frames by default in CPython).

    Args:
        a: node index to look up.
        parents: parent array where ``parents[x] == x`` marks a root. It is
            updated in place so every node on the walked path points
            directly at the root.

    Returns:
        The root index of ``a``'s set.
    """
    # First pass: walk up to the root.
    root = a
    while parents[root] != root:
        root = parents[root]
    # Second pass: re-point every node on the path straight at the root.
    # This is the same compression the recursive formulation performed.
    while a != root:
        nxt = parents[a]
        parents[a] = root
        a = nxt
    return root
def union(a, b, parents):
    """Merge the sets containing ``a`` and ``b``, keeping the smaller root.

    The smaller root is always kept, so each set's root is its smallest
    node index. Rather than linking one root under the other, this rewrites
    every entry equal to the discarded root to the kept root.

    That relabel only reaches nodes whose entry is exactly the discarded
    root. It therefore assumes the forest is flat (each node points straight
    at its root). The assumption holds when every merge goes through this
    function starting from ``parents[i] == i``, and the relabel keeps the
    forest flat. Each merge scans the whole array: O(len(parents)).

    Args:
        a, b: node indices to join.
        parents: parent array, modified in place.

    Returns:
        None.
    """
    pa = find_parent(a, parents)
    pb = find_parent(b, parents)
    if pa == pb:
        return
    keep, drop = (pa, pb) if pa < pb else (pb, pa)
    # Bound the scan by the array itself, not a global node count, so the
    # function is self-contained.
    for i in range(len(parents)):
        if parents[i] == drop:
            parents[i] = keep
def solve(N, Seed, A, B):
    """Count generated island pairs until all N islands are connected.

    Pairs come from the sequence E -> (E * A + B) mod N^2, starting at
    Seed mod N^2. Each E encodes the pair (E // N, E % N). Every generated
    value counts toward the answer, including self-pairs (X == Y) and pairs
    that are already in the same component.

    Relies on ``find_parent`` and ``union`` defined in this file.

    Returns:
        The number of values generated up to and including the one that
        joins the last two components. Returns 0 if a value repeats first:
        the sequence is deterministic, so it is then cyclic and can never
        add a new pair.
    """
    nsq = N * N
    parents = list(range(N))
    # One byte per possible E, instead of an 8-byte list slot per entry.
    seen = bytearray(nsq)
    components = N
    answer = 0
    e = Seed % nsq
    while not seen[e]:
        seen[e] = 1
        answer += 1
        x, y = divmod(e, N)
        # Connectivity can only change on a real merge, so track a component
        # counter instead of rescanning all N parents after every step.
        if x != y and find_parent(x, parents) != find_parent(y, parents):
            union(x, y, parents)
            components -= 1
            if components == 1:
                return answer
        e = (e * A + B) % nsq
    return 0
# Driver: the first line is the number of test cases. Each following line
# holds N, Seed, A, B, and each answer is written on its own line.
# N stays a module-level name on purpose (union above reads it).
t = int(input())
remaining = t
while remaining > 0:
    remaining -= 1
    N, Seed, A, B = (int(tok) for tok in input().split())
    sys.stdout.write(f"{solve(N, Seed, A, B)}\n")
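사이클이 생기지 않도록 간선을 추가하면, 노드가 N개일 때 서로 다른 집합을 합치는 연결이 N-1번 일어나는 순간 반드시 모든 노드가 하나로 연결됨.
따라서 union 함수가 실제로 두 집합을 합쳤는지(연결이 이뤄졌는지)를 리턴값(1 또는 0)으로 알려주고,
connections 변수에 실제 연결 횟수를 누적해서 N-1 이상이 되면 반복문을 빠져나오도록(break) 해야 함. 이렇게 하면 매 단계마다 N개 노드를 전부 검사하던 첫 번째 풀이의 O(N) 확인 과정이 필요 없어짐.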
import sys
# Rebind input() to sys.stdin.readline for faster reads. The trailing newline
# it keeps is harmless here because every read below is parsed with int() or
# str.split().
input = sys.stdin.readline
def find_parent(a, parents):
    """Return the root of ``a``'s set, compressing the path along the way.

    Written iteratively: ``union`` links roots by index order without a
    rank, so parent chains can grow as long as the number of nodes. A
    recursive walk would then exceed Python's recursion limit (1000 frames
    by default in CPython).

    Args:
        a: node index to look up.
        parents: parent array where ``parents[x] == x`` marks a root. It is
            updated in place so every node on the walked path points
            directly at the root.

    Returns:
        The root index of ``a``'s set.
    """
    # First pass: walk up to the root.
    root = a
    while parents[root] != root:
        root = parents[root]
    # Second pass: re-point every node on the path straight at the root.
    while a != root:
        nxt = parents[a]
        parents[a] = root
        a = nxt
    return root
def union(a, b, parents):
    """Join the sets of ``a`` and ``b`` and report whether anything merged.

    The root with the larger index is attached under the one with the
    smaller index.

    Args:
        a, b: node indices to join.
        parents: parent array, modified in place.

    Returns:
        1 if two distinct sets were merged, 0 if ``a`` and ``b`` were
        already in the same set.
    """
    root_a, root_b = find_parent(a, parents), find_parent(b, parents)
    if root_a == root_b:
        return 0
    parents[max(root_a, root_b)] = min(root_a, root_b)
    return 1
def solve(N, Seed, A, B):
    """Count generated island pairs until all N islands are connected.

    Pairs come from the sequence E -> (E * A + B) mod N^2, starting at
    Seed mod N^2. Each E encodes the pair (E // N, E % N). Every generated
    value counts toward the answer, including self-pairs (X == Y) and pairs
    that are already in the same component.

    Relies on ``union`` in this file returning 1 for a real merge and 0
    otherwise.

    Returns:
        The number of values generated up to and including the one that
        makes the (N-1)-th real merge. Returns 0 if a value repeats first:
        the sequence is deterministic, so it is then cyclic and can never
        add a new pair.
    """
    nsq = N * N
    parents = list(range(N))
    # One byte per possible E, instead of an 8-byte list slot per entry
    # built by a Python-level comprehension.
    seen = bytearray(nsq)
    connections = 0
    answer = 0
    e = Seed % nsq
    while not seen[e]:
        seen[e] = 1
        answer += 1
        x, y = divmod(e, N)
        if x != y:
            connections += union(x, y, parents)
            # N-1 cycle-free merges on N nodes leave a single component.
            if connections >= N - 1:
                return answer
        e = (e * A + B) % nsq
    return 0
# Driver: the first line is the number of test cases. Each following line
# holds N, Seed, A, B, and each answer is written on its own line.
t = int(input())
remaining = t
while remaining > 0:
    remaining -= 1
    N, Seed, A, B = (int(tok) for tok in input().split())
    sys.stdout.write(f"{solve(N, Seed, A, B)}\n")