[ 최소 신장 트리 ] 알고리즘

2023. 7. 19. 16:08개발/👾 PS

 

 

1.  개념

 

최소 신장 트리 : 그래프에서 모든 노드를 연결할 때 사용된 에지들의 가중치의 최소 합을 구하는 알고리즘

 

특징

- 사이클이 존재할 때 사용 불가

- 노드가 n개일 때 최소 신장 트리를 구성하는 에지의 개수는 항상 (n-1)

 

 

 

2.  구현

 

( 1 )  초기 설정

- 에지 리스트로 그래프 구현 -> 가중치를 기준으로 오름차순 정렬

- 유니온 파인드 리스트 생성 -> 사이클 유무 판단에 사용

 

( 2 )  경로 생성

-  가중치가 가장 작은 에지부터 탐색 시작

-  유니온 파인드에서 find를 이용해 사이클 유무를 판단

    * 에지를 이루는 두 노드를 find 연산했을 때

        대표 노드가 서로 같으면 -> 사이클 존재

        대표 노드가 서로 다르면 -> 사이클 존재X 

-  사이클이 없다면 union을 이용해 노드를 연결해준다

-  연결한 에지가 n-1 개가 될 때까지 반복

 

 

유니온 파인드 알고리즘

https://sosoeunii.tistory.com/56

 

[유니온 파인드] 알고리즘

1. 개념 유니온 파인드 알고리즘은 기본적으로 그래프 자료 구조를 이용한다 Union( 노드를 한 집합으로 합치고 ) + Find( 선택한 노드가 포함되어 있는 집합의 대표 노드 찾기 ) 유니온 파인드 알고

sosoeunii.tistory.com

 

 

 

3.  문제

 

# 1  최소 스패닝 트리 - 백준 1197번

https://www.acmicpc.net/problem/1197

 

1197번: 최소 스패닝 트리

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이

www.acmicpc.net

 

 

import sys
from queue import PriorityQueue
input = sys.stdin.readline

n, m = map(int, input().split())

edges = PriorityQueue() # 에지 리스트 생성
uf = [i for i in range(n+1)] # 유니온 파인드 리스트 초기화

# 에지 리스트 -> 그래프 구현
for i in range(m):
    u, v, w = map(int, input().split())
    edges.put((w, u, v))


# find 연산
def find(a):
    if a == uf[a]:
        return a
    else:
        uf[a] = find(uf[a])
        return uf[a]       


#union 연산
def union(a, b):
    a = find(a)
    b = find(b)
    if a < b:
    	uf[b] = a
    else:
        uf[a] = b


# MST 실행
edgeNum = 0
ans = 0

#연결한 에지가 n-1 개가 될 때까지
while edgeNum < n-1:
    w, s, e = edges.get()
    
    # 사이클이 없을 때만 union 연산 실행
    if find(s) != find(e):
        union(s, e)
        edgeNum += 1
        ans += w     # 가중치 업데이트

print(ans)

 

 

 

 

 

 

# 2  다리 만들기 - 백준 17472번

https://www.acmicpc.net/problem/17472

 

17472번: 다리 만들기 2

첫째 줄에 지도의 세로 크기 N과 가로 크기 M이 주어진다. 둘째 줄부터 N개의 줄에 지도의 정보가 주어진다. 각 줄은 M개의 수로 이루어져 있으며, 수는 0 또는 1이다. 0은 바다, 1은 땅을 의미한다.

www.acmicpc.net

 

 

from collections import deque
from queue import PriorityQueue
import sys
input = sys.stdin.readline


n, m = map(int, input().split())
island = [[0 for i in range(m)] for j in range(n)]      # 섬 지도
visited = [[False for i in range(m)] for j in range(n)] # 섬 방문 기록
d = [(1, 0), (0, 1), (0, -1), (-1, 0)]					# BFS 이동 
global curr												# 섬 이름
curr = 1
edges = PriorityQueue()									# 에지 리스트


# 섬 지도 입력받기
for i in range(n):
    num = list(map(int, input().split()))
    island[i] = num


# 섬에 이름 부여하기 -> BFS 탐색
def bfs(i, j):
    global curr 	# 섬 이름
    queue = deque()
    queue.append([i, j])	# 큐에 시작 섬 넣기
    visited[i][j] = True	# 시작 섬 방문 체크
    island[i][j] = curr		# 시작 섬에 이름(1) 부여하기
    
    # 섬의 상하 좌우 탐색해서 육지(1)이고 아직 방문하지 않은 섬인 경우
    ## 큐에 해당 좌표(이동 좌표) 삽입
    ## 해당 좌표 방문 체크
    ## 해당 좌표에 이름 부여 
    while queue:
        now_y, now_x = queue.popleft()
        for dc, dr in d:
            next_x = now_x + dc
            next_y = now_y + dr
            # 탐색 좌표가 지도 내에 있어야 함
            if next_x >= 0 and next_y >= 0 and next_x < m and next_y < n:
                if island[next_y][next_x] == 1 and not visited[next_y][next_x]:
                    visited[next_y][next_x] = True
                    queue.append([next_y, next_x])
                    island[next_y][next_x] = curr
  
    # 섬 한 개 탐색 끝났으니 섬 이름 업데이트
    curr += 1 


# 다리 이어주기 (BFS 탐색) -> 에지리스트 생성
def bridge(i, j):
    queue = deque()
    queue.append((i, j))
    while queue:
        now_y, now_x = queue.popleft()
        for dc, dr in d:
            next_x = now_x + dr
            next_y = now_y + dc
            length = 0         # 섬과 섬 사이 거리
            
            # 탐색 좌표가 바다(0)면
            ## 탐색 방향으로 탐색 계속 이어가기 -> 지도 끝에 도달하거나 육지 만날 때까지 
            while True:
            	# 탐색 좌표가 지도 내에 있어야 함
                if next_x >= 0 and next_y >= 0 and next_x < m and next_y < n:
                    if island[next_y][next_x] == 0:
                        next_x += dr
                        next_y += dc
                        length += 1
                    # '다른' 섬에 도달했고 섬과 섬 사이 거리가 2이상인 경우 -> 에지리스트에 추가
                    elif island[next_y][next_x] != island[now_y][now_x] and length > 1:
                        edges.put([length, island[now_y][now_x], island[next_y][next_x]])
                        break
                    else:
                        break
                else:
                    break


# 최소 신장 트리 알고리즘

def find(a):
    if a == uf[a]:
        return a
    else:
        uf[a] = find(uf[a])
        return uf[a]


def union(a, b):
    a = find(a)
    b = find(b)
    if a < b:
        uf[b] = a
    else:
        uf[a] = b


for i in range(n):
    for j in range(m):
        if island[i][j] == 1 and not visited[i][j]:
            bfs(i, j)

for i in range(n):
    for j in range(m):
        if island[i][j] != 0:
            bridge(i, j)

uf = [i for i in range(curr)]
edgeNum = 0
ans = 0
connect = False
while edges.qsize() > 0:
    w, s, e = edges.get()
    if find(s-1) != find(e-1):
        union(s-1, e-1)
        edgeNum += 1
        ans += w
    if edgeNum == curr - 2:
        connect = True

if connect:
    print(ans)
else:
    print(-1)

 

 

으아아아ㅏㄹ어링머;리머;ㅣ아ㅓ;니ㅏㅓㄹㄴ;ㅣㅏㅓ;ㅣㅏㅓ 풀었다알ㅇㄹ;ㅣ아러;미ㅏㄹ어ㅠㅠㅠㅠㅠㅠㅠㅠ

이것 때문에 열받아서 코딩 며칠간 놓고 있었는데 

머리 완전 비우고 다시 봤더니 완전 멍청한 실수를 하고 있었다 ㅇ러미;ㅇ로ㅓ오...

유니온 파인드 리스트를 섬의 개수대로 만들었어야 했는데

무슨 지도의 가로 길이 이딴 걸로 만들었다

으아아아 어쨌든 너무 후련하구 뿌듯하당 히히

 

 

# 3  불우이웃돕기 - 백준 1414번

https://www.acmicpc.net/problem/1414

 

1414번: 불우이웃돕기

첫째 줄에 컴퓨터의 개수 N이 주어진다. 둘째 줄부터 랜선의 길이가 주어진다. i번째 줄의 j번째 문자가 0인 경우는 컴퓨터 i와 컴퓨터 j를 연결하는 랜선이 없음을 의미한다. 그 외의 경우는 랜선

www.acmicpc.net

 

 

import sys
from queue import PriorityQueue
input = sys.stdin.readline

n = int(input())
edges = PriorityQueue()
total = 0
lan = []
answer = False

for i in range(n):
    al = input()
    lan.append(al)

for i in range(n):
    for j in range(n):
        if lan[i][j] >= 'A' and lan[i][j] <= 'Z':
            length = ord(lan[i][j]) - 38
            total += length
            edges.put((length, i, j))
        elif lan[i][j] >= 'a' and lan[i][j] <= 'z':
            length = ord(lan[i][j]) - 96
            total += length
            edges.put((length, i, j))

uf = [i for i in range(n+1)]


def find(a):
    if a == uf[a]:
        return a
    else:
        uf[a] = find(uf[a])
        return uf[a]


def union(a, b):
    a = find(a)
    b = find(b)
    if a < b:
        uf[b] = a
    else:
        uf[a] = b


edgeNum = 0
ans = 0

while edges.qsize() > 0:
    w, s, e = edges.get()
    if find(s) != find(e):
        union(s, e)
        edgeNum += 1
        ans += w
    if edgeNum == n-1:
        answer = True

if n == 1:
    if al == '0':
        print(0)
    else:
        print(total - ans)
else:
    if answer:
        print(total - ans)
    else:
        print(-1)

 

얘도 똑같이 풀면 된당

'개발 > 👾 PS' 카테고리의 다른 글

[ 트라이 ] 알고리즘  (0) 2023.07.22
[ 트리 ] 알고리즘  (0) 2023.07.19
[ 플로이드 - 워셜 ] 알고리즘  (0) 2023.07.18
[ 벨만 - 포드 ] 알고리즘  (0) 2023.07.15
🌱스트릭 잇기🌱 7월 1주차  (0) 2023.07.10