[백준, python] 1197번 - 최소 스패닝 트리
백준 문제풀이 시작
문제
그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.
최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.
입력
첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.
그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.
출력
첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.
풀이
풀이 확인하기
해당 문제는 최소 스패닝 트리(MST)인지 확인하는 문제이다. MST를 구하기 위해 사용하는 알고리즘이 크루스칼 알고리즘(Kruskal Algorithm)인데, 이는 즉 해당 문제가 크루스칼 알고리즘을 구현하는 문제라고 볼 수 있는 것이다. 크루스칼 알고리즘을 위해 cycle을 판단하는 구조인 union-find 구조를 구현하고, 형태에 맞게 입력을 받은 뒤, 입력받은 edge를 가중치를 기준으로 오름차순 정렬을 해준다. 그리고 edge를 순회하며 만약 해당 edge를 추가했을 때 cycle이 발생하지 않는다면 해당 edge를 union하며 값을 추가한다. (cycle이 발생하면 edge를 하나 삭제함으로써 가중치를 줄일 수 있기 때문.)
import sys
input = sys.stdin.readline
def find(p, a):
if(p[a] != a):
p[a] = find(p, p[a])
return p[a]
def union(p, a, b):
a = find(p, a)
b = find(p, b)
if(a < b):
p[b] = a
else:
p[a] = b
v, e = map(int, input().split())
p = [0] * (v + 1)
edge = []
cost = 0
for i in range(1, v+1):
p[i] = i
for _ in range(e):
a, b, c = map(int, input().split())
edge.append((c, a, b))
edge.sort()
for ed in edge:
c, a, b = ed
if(find(p, a) != find(p, b)):
union(p, a, b)
cost += c
print(cost)