피너클의 it공부방
백준 4008 특공대 (c++) : 피너클 본문
https://www.acmicpc.net/problem/4008
dp와 Convex Hull Trick 으로 풀었다.
먼저 dp 식을 만들었다.
dp[i] = 0 ~ i까지 특공대를 조직했을때 최대 전투력
대충 이정도로 생각해보자.
dp[i]를 만들었다고 생각하고 dp[i + 1]로 넘어가자
dp[i + 1]은 dp[i]에서 새로운 병사가 추가되어 x[i + 1]이 추가되어야 한다.
dp[i + 1] = dp[i] + a * x[i + 1] ^ 2 + b * x[i + 1] + c
하지만 새로 추가될 병사인 x[i + 1]은 앞의 인원과 그룹이 될수도 있다.
x[i + 1]과 x[i]이 하나의 특공대가 되게 하려면 dp[i]가 아닌 dp[i - 1]을 사용해야한다.
dp[i + 1] = dp[i - 1] + a * (x[i + 1] + x[i]) ^ 2 + b * (x[i + 1] + x[i]) + c
또 앞의 인원과 그룹이 될수 있다.
dp[i + 1] = dp[i - 2] + a * (x[i + 1] + x[i] + x[i - 1]) ^ 2 + b * (x[i + 1] + x[i] + x[i - 1] ) + c
i+1을 i로 바꾸고 0 ~ i를 j로 바꾸면
이걸 하나의 식으로 만들면
dp[i] = max(dp[j] + a * (x[j + 1 ~ i - 1] + x[i]) ^ 2 + b * (x[j + 1 ~ i - 1] + x[i]) + c)
dp[i] = max(dp[j] + a * (x[j + 1 ~ i]) ^ 2 + b * (x[j + 1 ~ i]) + c)
나는 누적합을 이용할 것이다. i까지의 누적합을 p[i]라고 하면 위의 식을
dp[i] = max(dp[j] + a * (p[i] - p[j]) ^ 2 + b * (p[i] - p[j]) + c) 이렇게 바꿀수있다.
이제 위의 식을 하나하나 풀면
dp[j] + a * (p[i] - p[j]) ^ 2 + b * (p[i] - p[j]) + c
dp[j] + a * (p[i] ^ 2 - 2 * p[i] * p[j] + p[j] ^ 2) + b * (p[i] - p[j]) + c
dp[j] + a * p[i] ^ 2 - 2 * a * p[i] * p[j] + a * p[j] ^ 2 + b * p[i] - b * p[j] + c
이렇게 된다. 이건 직접 종이에 손으로 계산해보는걸 추천한다. 그래야 이해가 된다. 내가 그랬다.
위의 식을 p[i]를 기준으로 나눌 것이다.
(- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j] + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c)
이것도 직접 손으로 해봐야한다. 그래야 이해가 된다. 내가 그랬다.
위의 식에서 p[i]를 변수 x로 보면 1차 방정식이 보일 것이다.
y = ax + b라고 했을때
a : (- 2 * a * p[j] * p[i])
x : (- 2 * a * p[j] * p[i])
b : (a * p[j] ^ 2 - b * p[j] + dp[j] )
그럼 뒤에있는 (a * p[i] ^ 2 + b * p[i] + c) 이건 어떻게 될까?
이것들은 나중에 사용된다.
이제부터 Convex Hull Trick 이걸 사용해야하는데 이건 미안하지만 다른곳에서 강좌를 봐주길 바란다.
백준 13263 나무 자르기 (c++) : 피너클
13263번: 나무 자르기Convex Hull Trick 으로 풀었다. 문제를 먼저 봐보자.만약 내가 5번 나무의 높이를 0으로 만들면 5번 보다 번호가 큰 나무를 자르기 전까지는전기톱을 충전할때 b[5]만큼의 비용이
pinacle.tistory.com
이건 Convex Hull Trick이용한 다른 문제인데 여기에서 Convex Hull Trick이거 조금 설명했으니 이걸 봐도 될것이다.
여기서는 진짜 조금 설명했다. 왜 이걸 사용해도 되는지 이런건 설명하지 않았다. 그런건 다른곳에서 봐야한다.
이건 미리보기가 뜨는데 왜 백준사이트는 안뜨는지 모르겠다.
int n;
long long a, b, c;
long long x[1000001];
long long prefix_sum[1000001];
long long dp[1000001];
vector<tuple<long long, long long, double>> powers;
사용할 변수들이다.
double get_intersection_x(tuple<long long, long long, double> a, tuple<long long, long long, double> b) {
return (double)(get<1>(a) - get<1>(b)) / (double)(get<0>(b) - get<0>(a));
}
두 선의 교점을 얻는 함수이다. 갑자기 두 선이 왜 나오는지 모르겠다면 Convex Hull Trick 이걸 공부하고 와야한다.
cin >> n;
cin >> a >> b >> c;
for (int i = 0; i < n; i++) cin >> x[i];
prefix_sum[0] = x[0];
for (int i = 1; i < n; i++) prefix_sum[i] = prefix_sum[i - 1] + x[i];
값들을 입력받고 누적합을 준비해준다.
prefix_sum[-1] = 0;
dp[-1] = 0;
for문을 0부터 시작할것이기 때문에 각각 -1을 준비해줬다.
for (int i = 0; i < n; i++) {
tuple<long long, long long, double> new_power =
make_tuple(-2 * a * prefix_sum[i - 1],
a * prefix_sum[i - 1] * prefix_sum[i - 1] - b * prefix_sum[i - 1] + dp[i - 1],
0);
new_power는 새로 집어넣을 전투력의 선이다.
make_tuple안에 들어있는 것들은 위의 공식
a : (- 2 * a * p[j] * p[i])
x : (- 2 * a * p[j] * p[i])
b : (a * p[j] ^ 2 - b * p[j] + dp[j] )
여기에서 a와 b를 계산해 넣은 것이다. 맨뒤의 0은 이 선이 영향을 끼치기 시작하는 x 좌표이다.
while (!powers.empty()) {
tuple<long long, long long, double> last_power = powers.back();
double intersection_x = get_intersection_x(new_power, last_power);
if (intersection_x < get<2>(last_power)) powers.pop_back();
else break;
}
그후 교점을 계산한 다음 새로 추가한 선이 영향을 끼치기 시작하는 x좌표가
이전에 추가된 선이 영향을 끼치기 시작하는 선보다 왼쪽에 있다면, 즉 더 먼저 영향을 끼친다면
이전에 추가된 선을 제거한다.
그게 아니면 whlie문에서 나간다.
if (!powers.empty()) {
tuple<long long, long long, double> last_power = powers.back();
double intersection_x = get_intersection_x(new_power, last_power);
new_power = make_tuple(get<0>(new_power), get<1>(new_power), intersection_x);
}
powers.push_back(new_power);
while문에서 나왔는데 만약 powers (new_power을 집어넣는 vector)가 비어있지 않다면
new_power의 x좌표를 교점으로 바꿔준다.
그리고 추가한다.
long long x = prefix_sum[i];
int idx = powers.size() - 1;
int left = 0, right = powers.size() - 1;
이제 이진탐색을 준비한다.
long long x는
a : (- 2 * a * p[j] * p[i])
x : (- 2 * a * p[j] * p[i])
b : (a * p[j] ^ 2 - b * p[j] + dp[j] )
여기에서 x이다. 즉 p[i]이다.
while (left <= right) {
int mid = (left + right) / 2;
if (get<2>(powers[mid]) <= x) {
idx = mid;
left = mid + 1;
}
else {
right = mid - 1;
}
}
이진탐색 해준다.
dp[i] = get<0>(powers[idx]) * x + get<1>(powers[idx]) + b * prefix_sum[i] + c + a * prefix_sum[i] * prefix_sum[i];
}
마지막으로 계산해서 넣어주면 된다.
아까 위에서 (a * p[i] ^ 2 + b * p[i] + c) 이것들은 나중에 사용된다고 했는데 그게 여기에 사용된다.
(a * p[i] ^ 2 + b * p[i] + c) 이것들은
a : (- 2 * a * p[j] * p[i])
x : (- 2 * a * p[j] * p[i])
b : (a * p[j] ^ 2 - b * p[j] + dp[j] )
여기에 어떤 영향도 주지 않는다.
b에 추가하면 영향을 주지 않을까 생각할수도 있지만
아까 위에서 만든 공식이
(- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j] + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c) 이거다.
이건
dp[i] = max( (- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j] + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c) ) 이렇게 된다.
max안에 있다.
그리고 max는 j가 0일때부터 i-1일때까지를 기준으로 한다.
(- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j] + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c)
이 식에서 j가 영향을 주는 부분은
(- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j] + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c)
이렇게 뿐이다. 뒤에는 영향을 주지 않는다.
그래서 마지막에 추가해주는 것이다.
(a * p[i] ^ 2 + b * p[i] + c) 이건 선 계산할때 들어오면 안된다.
#include <iostream>
#include <algorithm>
#include <vector>
#include <tuple>
using namespace std;
int n;
long long a, b, c;
long long x[1000001];
long long prefix_sum[1000001];
long long dp[1000001];
vector<tuple<long long, long long, double>> powers;
double get_intersection_x(tuple<long long, long long, double> a, tuple<long long, long long, double> b) {
return (double)(get<1>(a) - get<1>(b)) / (double)(get<0>(b) - get<0>(a));
}
int main(int argc, char** argv)
{
std::ios_base::sync_with_stdio(false);
std::cin.tie(NULL);
cin >> n;
cin >> a >> b >> c;
for (int i = 0; i < n; i++) cin >> x[i];
prefix_sum[0] = x[0];
for (int i = 1; i < n; i++) prefix_sum[i] = prefix_sum[i - 1] + x[i];
prefix_sum[-1] = 0;
dp[-1] = 0;
for (int i = 0; i < n; i++) {
tuple<long long, long long, double> new_power =
make_tuple(-2 * a * prefix_sum[i - 1],
a * prefix_sum[i - 1] * prefix_sum[i - 1] - b * prefix_sum[i - 1] + dp[i - 1],
0);
while (!powers.empty()) {
tuple<long long, long long, double> last_power = powers.back();
double intersection_x = get_intersection_x(new_power, last_power);
if (intersection_x < get<2>(last_power)) powers.pop_back();
else break;
}
if (!powers.empty()) {
tuple<long long, long long, double> last_power = powers.back();
double intersection_x = get_intersection_x(new_power, last_power);
new_power = make_tuple(get<0>(new_power), get<1>(new_power), intersection_x);
}
powers.push_back(new_power);
long long x = prefix_sum[i];
int idx = powers.size() - 1;
int left = 0, right = powers.size() - 1;
while (left <= right) {
int mid = (left + right) / 2;
if (get<2>(powers[mid]) <= x) {
idx = mid;
left = mid + 1;
}
else {
right = mid - 1;
}
}
dp[i] = get<0>(powers[idx]) * x + get<1>(powers[idx]) + b * prefix_sum[i] + c + a * prefix_sum[i] * prefix_sum[i];
}
cout << dp[n - 1] << '\n';
return 0;
}
전체코드다.
'백준' 카테고리의 다른 글
| 백준 11266 단절점 (c++) : 피너클 (1) | 2025.08.11 |
|---|---|
| 백준 11280 2-SAT - 3 (c++) : 피너클 (5) | 2025.07.30 |
| 백준 13263 나무 자르기 (c++) : 피너클 (4) | 2025.07.27 |
| 백준 2295 세 수의 합 (c++) : 피너클 (0) | 2024.11.28 |
| 백준 7785 회사에 있는 사람 (c++) : 피너클 (0) | 2024.11.26 |