피너클의 it공부방
백준 13511 트리와 쿼리 2 (c++) : 피너클 본문
https://www.acmicpc.net/problem/13511
lca로 풀었다.
문제에서 2개의 값을 구해야 한다.
경로의 비용과, 경로에 존재하는 정점중 k번째 정점
둘다 기존의 lca를 개조해서 풀수있다.
int n, m;
vector<pair<long long, long long>> graph[100001];
long long parent[21][100001];
long long weight[21][100001];
int depth[100001];
사용할 변수들이다.
대부분 기존의 것에서 사용하는 것이며 weight만 추가된것이다.
weight는 당연히 비용이다.
int main(int argc, char** argv)
{
std::ios_base::sync_with_stdio(false);
std::cin.tie(NULL);
cin >> n;
for (int i = 0; i < n-1; i++) {
long long u, v, w;
cin >> u >> v >> w;
graph[u].push_back({ v, w });
graph[v].push_back({ u, w });
}
memset(parent, -1, sizeof(parent));
memset(weight, -1, sizeof(weight));
parent[0][1] = 0;
weight[0][1] = 0;
depth[1] = 1;
dfs(1, 0);
값을 입력받고 전부 -1로 초기화 한뒤
dfs를 돌린다.
void dfs(int now, int p) {
for (int i = 0; i < graph[now].size(); i++) {
int next = graph[now][i].first;
int next_weight = graph[now][i].second;
if (next == p) continue;
parent[0][next] = now;
weight[0][next] = next_weight;
depth[next] = depth[now] + 1;
dfs(next, now);
}
}
1번에서 시작해서 계속 나아가며 weight와 parent, depth를 채워준다.
dfs(1, 0);
for (int h = 1; h <= 20; h++) {
for (int i = 1; i <= n; i++) {
if (parent[h - 1][i] == -1) continue;
parent[h][i] = parent[h - 1][parent[h - 1][i]];
weight[h][i] = weight[h - 1][i] + weight[h - 1][parent[h - 1][i]];
}
}
dfs가 끝나면 parent를 추가로 채워준다.
weight는 i의 2^(h-1) 정점의 비용 + i의 2^(h-1) 정점 의 2^(h-1)의 정점의 비용이다.
a, b, c가 있고 a는 b의 부모, b는 c의 부모라고 하면
parent[0][c] = b, parent[0][b] = a이다.
parent[1][c] = parent[0][parent[0][c]] = parent[0][b] = parent[a]가 되고
weight[1][c] = weight[0][c] + weight[0][parent[0][c]] = weight[0][c] + weight[0][b] 가 된다.
c에서 2^0 까지의 비용과 b에서 2^0 까지의 비용이 들어가 중간의 모든 비용이 들어가게 된다.
while (m-- > 0) {
int a, u, v, k;
cin >> a;
if (a == 1) {
cin >> u >> v;
cout << getCost(u, v) << '\n';
}
else {
cin >> u >> v >> k;
cout << find(u, v, k) << '\n';
}
}
그 후에는 입력을 받고 출력해주면 된다.
long long getCost(int u, int v) {
long long cost = 0;
if (depth[u] != depth[v]) {
if (depth[u] < depth[v]) swap(u, v);
for (int h = 20; h >= 0; h--) {
if (depth[u] - depth[v] >= (1 << h)) {
cost += weight[h][u];
u = parent[h][u];
}
}
}
if (u == v) return cost;
for (int h = 20; h >= 0; h--) {
if (parent[h][u] != -1 && parent[h][u] != parent[h][v]) {
cost += weight[h][u];
cost += weight[h][v];
u = parent[h][u];
v = parent[h][v];
}
}
return cost + weight[0][u] + weight[0][v];
}
getCost함수다. 기존 lca에 cost구하는것만 추가해주면 된다.
마지막에 weight[0][u] 와 weight[0][v]를 추가해주는걸 잊지 않고
u = parent[h][u]를 하기 전에 cost를 추가해주는걸 잊지 않아야 한다.
u = parent[h][u]하고 cost추가하면 이상한값 들어간다.
int find(int u, int v, int k) {
int l = lca(u, v);
int distU = depth[u] - depth[l];
int distV = depth[v] - depth[l];
if (k <= distU + 1) return solve(u, k - 1);
else return solve(v, distU + distV - k + 1);
}
find는 3개의 단계가 있다.
1. u와 v의 최소 공통 조상을 찾는다. 이게 l이다.
2. 최소 공통 조상까지의 경로의 길이를 구한다.
3. 각자 구한 경로의 길이를 토대로 k를 어디에서 찾아야 하는지를 확인한다.
int lca(int u, int v) {
if (depth[u] != depth[v]) {
if (depth[u] < depth[v]) swap(u, v);
for (int h = 20; h >= 0; h--) {
if (depth[u] - depth[v] >= (1 << h)) {
u = parent[h][u];
}
}
}
if (u == v) return u;
for (int h = 20; h >= 0; h--) {
if (parent[h][u] != -1 && parent[h][u] != parent[h][v]) {
u = parent[h][u];
v = parent[h][v];
}
}
return parent[0][u];
}
먼저 lca이다. 그냥 똑같다.
int find(int u, int v, int k) {
int l = lca(u, v);
int distU = depth[u] - depth[l];
int distV = depth[v] - depth[l];
if (k <= distU + 1) return solve(u, k - 1);
else return solve(v, distU + distV - k + 1);
}
l을 구한 다음 u에서 l까지의 길이, v에서 l까지의 길이를 구한다.
depth를 이용하면 간단하게 구해진다.
이제 조심해야하는게 k는 1부터이다. 우리같은 사람은 0부터 세지만 이놈은 1부터 셌다.
그래서 distU + 1이랑 비교해준다. 이놈보다 작거나 같으면 u에서 k-1만큼 올라가면 된다.
하지만 크다면 얘기가 달라진다.
경로는 u -> l -> v이고 이는 distU + distV이다.
우리는 v에서 올라올것이기 때문에 distu + distV - v + 1로 해준다. -1아니다. +1이다.
#include <iostream>
#include <algorithm>
#include <vector>
#include <cstring>
using namespace std;
int n, m;
vector<pair<long long, long long>> graph[100001];
long long parent[21][100001];
long long weight[21][100001];
int depth[100001];
void dfs(int now, int p) {
for (int i = 0; i < graph[now].size(); i++) {
int next = graph[now][i].first;
int next_weight = graph[now][i].second;
if (next == p) continue;
parent[0][next] = now;
weight[0][next] = next_weight;
depth[next] = depth[now] + 1;
dfs(next, now);
}
}
long long getCost(int u, int v) {
long long cost = 0;
if (depth[u] != depth[v]) {
if (depth[u] < depth[v]) swap(u, v);
for (int h = 20; h >= 0; h--) {
if (depth[u] - depth[v] >= (1 << h)) {
cost += weight[h][u];
u = parent[h][u];
}
}
}
if (u == v) return cost;
for (int h = 20; h >= 0; h--) {
if (parent[h][u] != -1 && parent[h][u] != parent[h][v]) {
cost += weight[h][u];
cost += weight[h][v];
u = parent[h][u];
v = parent[h][v];
}
}
return cost + weight[0][u] + weight[0][v];
}
int lca(int u, int v) {
if (depth[u] != depth[v]) {
if (depth[u] < depth[v]) swap(u, v);
for (int h = 20; h >= 0; h--) {
if (depth[u] - depth[v] >= (1 << h)) {
u = parent[h][u];
}
}
}
if (u == v) return u;
for (int h = 20; h >= 0; h--) {
if (parent[h][u] != -1 && parent[h][u] != parent[h][v]) {
u = parent[h][u];
v = parent[h][v];
}
}
return parent[0][u];
}
int solve(int u, int k) {
for (int h = 20; h >= 0; h--) {
if (k - (1 << h) >= 0) {
k -= (1 << h);
u = parent[h][u];
}
}
return u;
}
int find(int u, int v, int k) {
int l = lca(u, v);
int distU = depth[u] - depth[l];
int distV = depth[v] - depth[l];
if (k <= distU + 1) return solve(u, k - 1);
else return solve(v, distU + distV - k + 1);
}
int main(int argc, char** argv)
{
std::ios_base::sync_with_stdio(false);
std::cin.tie(NULL);
cin >> n;
for (int i = 0; i < n-1; i++) {
long long u, v, w;
cin >> u >> v >> w;
graph[u].push_back({ v, w });
graph[v].push_back({ u, w });
}
memset(parent, -1, sizeof(parent));
memset(weight, -1, sizeof(weight));
parent[0][1] = 0;
weight[0][1] = 0;
depth[1] = 1;
dfs(1, 0);
for (int h = 1; h <= 20; h++) {
for (int i = 1; i <= n; i++) {
if (parent[h - 1][i] == -1) continue;
parent[h][i] = parent[h - 1][parent[h - 1][i]];
weight[h][i] = weight[h - 1][i] + weight[h - 1][parent[h - 1][i]];
}
}
cin >> m;
while (m-- > 0) {
int a, u, v, k;
cin >> a;
if (a == 1) {
cin >> u >> v;
cout << getCost(u, v) << '\n';
}
else {
cin >> u >> v >> k;
cout << find(u, v, k) << '\n';
}
}
return 0;
}
전체코드다.
'백준' 카테고리의 다른 글
| 백준 15927 회문은 회문아니야 (c++) : 피너클 (0) | 2025.09.13 |
|---|---|
| 백준 3392 화성 지도 (c++) : 피너클 (0) | 2025.09.02 |
| 백준 14868 문명 (c++) : 피너클 (0) | 2025.08.30 |
| 백준 5710 거의 최단 경로 (c++) : 피너클 (0) | 2025.08.29 |
| 백준 11378 열혈강호 4 (c++) : 피너클 (2) | 2025.08.23 |