백준 #1717 - 집합의 표현
문제
초기에 n + 1개의 집합 {0},{1},{2},…,{n 이 있다. 여기에 합집합 연산과, 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산을 수행하려고 한다.
집합을 표현하는 프로그램을 작성하시오.
입력
첫째 줄에 n, m이 주어진다. 은 입력으로 주어지는 연산의 개수이다. 다음 개의 줄에는 각각의 연산이 주어진다. 합집합은 0 a b의 형태로 입력이 주어진다. 이는 가 포함되어 있는 집합과, 가 포함되어 있는 집합을 합친다는 의미이다. 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산은 q a b의 형태로 입력이 주어진다. 이는 a와 b가 같은 집합에 포함되어 있는지를 확인하는 연산이다.
출력
1로 시작하는 입력에 대해서 a와 b가 같은 집합에 포함되어 있으면 "YES" 또는 "yes"를, 그렇지 않다면 "NO" 또는 "no"를 한 줄에 하나씩 출력한다.
제한
- 1≤ n ≤1000000
- 1≤ m ≤100000
- 0 ≤ a, b ≤ n
- a, b는 정수
- 와 는 같을 수도 있다.
예제 입력
7 8
0 1 3
1 1 7
0 7 6
1 7 1
0 3 7
0 4 2
0 1 1
1 1 1
예제 출력
NO
NO
YES
일단 처음에는 이차원 배열을 사용하여 그래프를 구현하고, 같은 집합이면 1, 다른 집합이면 0으로 저장하는 방식을 사용했다.
import sys
input = sys.stdin.readline
n, m = map(int, input().split())
sets = [[0] * (n + 1) for _ in range(n + 1)]
for i in range(m):
operator, a, b = map(int, input().split())
if operator == 0:
sets[a][b] = 1
sets[b][a] = 1
elif operator == 1:
if sets[a][b] == 1:
print("YES")
else:
print("NO")
그런데 메모리 초과가 떴다.
시간 초과는 많이 봤어도 메모리 초과는 처음 봐서 당황스러웠다...
생각해보니 이차원 배열로 그래프를 구현하려면 너무 큰 메모리를 사용하기도 하고,
이런 방식으로는 집합 개념을 저장할 수 없겠다 싶었다.
그러던 중 union-find 방식으로 합집합 문제를 푼다는 것을 알게 되었다.
union-find 방식은 다음과 같다.
우선, 초기 리스트는 위와 같이 생겼다.
길이가 (n + 1)인 리스트를 생성하고, 초기 원소 값으로는 인덱스를 가진다.
여기서 list[2]는 정수 2가 속한 집합을 의미한다.
여기서 정수 1이 속한 집합과 정수 3이 속한 집합을, 정수 4가 속한 집합과 정수 5가 속한 집합을 합치면 어떻게 표시될까?
이렇게 표시될 것이다.
그렇다면 다시 문제를 풀어보자.
import sys
sys.setrecursionlimit(10**6) # 재귀 제한
def union(x, y):
x = sets(x)
y = sets(y)
if x == y: # 두 원소가 속한 집합이 같지 않다면 하나의 집합으로 합치기
return
else:
nums[y] = x
def sets(target): # 재귀적으로 원소기 속한 집합 찾기
if target == nums[target]:
return target
nums[target] = sets(nums[target])
return nums[target]
n, m = map(int, sys.stdin.readline().split())
nums = [i for i in range(n + 1)] # [1 2 3 4 5 ... n]
for _ in range(m):
operator, a, b = map(int, sys.stdin.readline().split())
if operator == 0:
union(a, b)
elif operator == 1:
if sets(a) == sets(b):
print("YES")
else:
print("NO")
sets 함수는 원소가 속한 집합을 찾는 함수이다.
마지막 메인 호출 부분에서 operator이 1이고, 두 원소가 속한 집합이 같다면 YES를, 다르다면 NO를 출력한다.
union 함수는 합집합 연산을 수행하는 함수이다.
만일 두 원소가 속한 집합이 같다면 아무 연산도 수행하지 않지만, 속한 집합이 같지 않다면 하나의 집합으로 합친다.
그런데 sets 함수를 보자.
def sets(target): # 재귀적으로 원소의 붐 찾기
return nums[target]
이렇게 간단하게 구성하면 안될까?
안된다.
챗지피티의 말에 따르면 다음과 같다.
부모를 반환하고 수정하는 이유는 "경로 압축(Path Compression)"을 통해 재귀 호출의 깊이를 줄이고 실행 시간을 최적화하기 위해서입니다.
sets 함수에서 재귀적으로 부모를 찾을 때, 부모를 찾은 후 해당 원소의 부모를 부모의 부모로 갱신함으로써 경로를 짧게 만듭니다. 이렇게 하면 같은 원소에 대한 다음 호출에서는 원소의 부모를 바로 참조할 수 있으므로 재귀 호출의 깊이가 줄어들어 실행 시간이 단축됩니다.
경로 압축은 집합 요소들이 연결되어 있는 경우 그래프의 깊이를 줄여서 탐색이 빨라지게 합니다. 따라서 sets 함수에서 경로 압축을 통해 부모를 수정하는 것이 중요합니다.
오호....
아무래도 이런 집합 문제를 풀 때는 코드를 기억해 가는게 편할 것 같다.
집합을 다루는 특수한 알고리즘이 있는 줄은 처음 알았다.