피너클의 it공부방

백준 4008 특공대 (c++) : 피너클 본문

백준

백준 4008 특공대 (c++) : 피너클

피너클 2025. 7. 28. 22:22
728x90
반응형

4008번: 특공대

https://www.acmicpc.net/problem/4008

 

dp와 Convex Hull Trick 으로 풀었다.

 

먼저 dp 식을 만들었다.

dp[i] = 0 ~ i까지 특공대를 조직했을때 최대 전투력

대충 이정도로 생각해보자.

 

dp[i]를 만들었다고 생각하고 dp[i + 1]로 넘어가자

dp[i + 1]은 dp[i]에서 새로운 병사가 추가되어 x[i + 1]이 추가되어야 한다.

dp[i + 1] = dp[i] + a * x[i + 1] ^ 2 + b * x[i + 1] + c

 

하지만 새로 추가될 병사인 x[i + 1]은 앞의 인원과 그룹이 될수도 있다.

x[i + 1]과 x[i]이 하나의 특공대가 되게 하려면 dp[i]가 아닌 dp[i - 1]을 사용해야한다.

dp[i + 1] = dp[i - 1] + a * (x[i + 1] + x[i]) ^ 2 + b * (x[i + 1] + x[i]) + c

 

또 앞의 인원과 그룹이 될수 있다.

dp[i + 1] = dp[i - 2] + a * (x[i + 1] + x[i] + x[i - 1]) ^ 2 + b * (x[i + 1] + x[i]  + x[i - 1] ) + c

i+1을 i로 바꾸고 0 ~ i를 j로 바꾸면

이걸 하나의 식으로 만들면 

 

dp[i] = max(dp[j] + a * (x[j + 1 ~ i - 1] + x[i]) ^ 2 + b * (x[j + 1 ~ i - 1] + x[i])  + c)

dp[i] = max(dp[j] + a * (x[j + 1 ~ i]) ^ 2 + b * (x[j + 1 ~ i])  + c)

나는 누적합을 이용할 것이다. i까지의 누적합을 p[i]라고 하면 위의 식을 

dp[i] = max(dp[j] + a * (p[i] - p[j]) ^ 2 + b * (p[i] - p[j])  + c) 이렇게 바꿀수있다.

 

이제 위의 식을 하나하나 풀면

dp[j] + a * (p[i] - p[j]) ^ 2 + b * (p[i] - p[j])  + c

dp[j] + a * (p[i] ^ 2 - 2 * p[i] * p[j] + p[j] ^ 2) + b * (p[i] - p[j]) + c

dp[j] + a * p[i] ^ 2 - 2 * a * p[i] * p[j] + a * p[j] ^ 2 + b * p[i] - b * p[j] + c

이렇게 된다. 이건 직접 종이에 손으로 계산해보는걸 추천한다. 그래야 이해가 된다. 내가 그랬다.

 

위의 식을 p[i]를 기준으로 나눌 것이다.

(- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j]  + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c)

이것도 직접 손으로 해봐야한다. 그래야 이해가 된다. 내가 그랬다.

 

위의 식에서 p[i]를 변수 x로 보면 1차 방정식이 보일 것이다.

y = ax + b라고 했을때 

a : (- 2 * a * p[j] * p[i])

x : (- 2 * a * p[j] * p[i])

b : (a * p[j] ^ 2 - b * p[j]  + dp[j] )

그럼 뒤에있는 (a * p[i] ^ 2 + b * p[i] + c) 이건 어떻게 될까?

이것들은 나중에 사용된다.

 

이제부터 Convex Hull Trick 이걸 사용해야하는데 이건 미안하지만 다른곳에서 강좌를 봐주길 바란다.

백준 13263 나무 자르기 (c++) : 피너클

 

백준 13263 나무 자르기 (c++) : 피너클

13263번: 나무 자르기Convex Hull Trick 으로 풀었다. 문제를 먼저 봐보자.만약 내가 5번 나무의 높이를 0으로 만들면 5번 보다 번호가 큰 나무를 자르기 전까지는전기톱을 충전할때 b[5]만큼의 비용이

pinacle.tistory.com

이건 Convex Hull Trick이용한 다른 문제인데 여기에서 Convex Hull Trick이거 조금 설명했으니 이걸 봐도 될것이다.

여기서는 진짜 조금 설명했다. 왜 이걸 사용해도 되는지 이런건 설명하지 않았다. 그런건 다른곳에서 봐야한다.

이건 미리보기가 뜨는데 왜 백준사이트는 안뜨는지 모르겠다.

 

int n;
long long a, b, c;
long long x[1000001];
long long prefix_sum[1000001];
long long dp[1000001];
vector<tuple<long long, long long, double>> powers;

사용할 변수들이다.

 

double get_intersection_x(tuple<long long, long long, double> a, tuple<long long, long long, double> b) {
	return (double)(get<1>(a) - get<1>(b)) / (double)(get<0>(b) - get<0>(a));
}

두 선의 교점을 얻는 함수이다. 갑자기 두 선이 왜 나오는지 모르겠다면 Convex Hull Trick 이걸 공부하고 와야한다.

 

	cin >> n;
	cin >> a >> b >> c;
	for (int i = 0; i < n; i++) cin >> x[i];
	prefix_sum[0] = x[0];
	for (int i = 1; i < n; i++) prefix_sum[i] = prefix_sum[i - 1] + x[i];

값들을 입력받고 누적합을 준비해준다.

 

	prefix_sum[-1] = 0;
	dp[-1] = 0;

for문을 0부터 시작할것이기 때문에 각각 -1을 준비해줬다.

 

	for (int i = 0; i < n; i++) {
		tuple<long long, long long, double> new_power =
			make_tuple(-2 * a * prefix_sum[i - 1],
				a * prefix_sum[i - 1] * prefix_sum[i - 1] - b * prefix_sum[i - 1] + dp[i - 1],
				0);

new_power는 새로 집어넣을 전투력의 선이다.

make_tuple안에 들어있는 것들은 위의 공식

a : (- 2 * a * p[j] * p[i])

x : (- 2 * a * p[j] * p[i])

b : (a * p[j] ^ 2 - b * p[j]  + dp[j] )

여기에서 a와 b를 계산해 넣은 것이다. 맨뒤의 0은 이 선이 영향을 끼치기 시작하는 x 좌표이다.

 

		while (!powers.empty()) {
			tuple<long long, long long, double> last_power = powers.back();
			double intersection_x = get_intersection_x(new_power, last_power);

			if (intersection_x < get<2>(last_power)) powers.pop_back();
			else break;
		}

그후 교점을 계산한 다음 새로 추가한 선이 영향을 끼치기 시작하는 x좌표가

이전에 추가된 선이 영향을 끼치기 시작하는 선보다 왼쪽에 있다면, 즉 더 먼저 영향을 끼친다면

이전에 추가된 선을 제거한다.

그게 아니면 whlie문에서 나간다.

 

		if (!powers.empty()) {
			tuple<long long, long long, double> last_power = powers.back();
			double intersection_x = get_intersection_x(new_power, last_power);
			new_power = make_tuple(get<0>(new_power), get<1>(new_power), intersection_x);
		}
        
		powers.push_back(new_power);

while문에서 나왔는데 만약 powers (new_power을 집어넣는 vector)가 비어있지 않다면

new_power의 x좌표를 교점으로 바꿔준다.

그리고 추가한다.

 

		long long x = prefix_sum[i];
		int idx = powers.size() - 1;

		int left = 0, right = powers.size() - 1;

이제 이진탐색을 준비한다.

long long x는 

a : (- 2 * a * p[j] * p[i])

x : (- 2 * a * p[j] * p[i])

b : (a * p[j] ^ 2 - b * p[j]  + dp[j] )

여기에서 x이다. 즉 p[i]이다.

 

		while (left <= right) {
			int mid = (left + right) / 2;
			if (get<2>(powers[mid]) <= x) {
				idx = mid;
				left = mid + 1;
			}
			else {
				right = mid - 1;
			}
		}

이진탐색 해준다.

 

		dp[i] = get<0>(powers[idx]) * x + get<1>(powers[idx]) + b * prefix_sum[i] + c + a * prefix_sum[i] * prefix_sum[i];
	}

마지막으로 계산해서 넣어주면 된다.

아까 위에서 (a * p[i] ^ 2 + b * p[i] + c) 이것들은 나중에 사용된다고 했는데 그게 여기에 사용된다.

(a * p[i] ^ 2 + b * p[i] + c) 이것들은 

a : (- 2 * a * p[j] * p[i])

x : (- 2 * a * p[j] * p[i])

b : (a * p[j] ^ 2 - b * p[j]  + dp[j] )

여기에 어떤 영향도 주지 않는다.

b에 추가하면 영향을 주지 않을까 생각할수도 있지만

아까 위에서 만든 공식이

(- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j]  + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c) 이거다.

이건 

dp[i] = max( (- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j]  + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c) ) 이렇게 된다.

max안에 있다.

그리고 max는 j가 0일때부터 i-1일때까지를 기준으로 한다.

 

(- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j]  + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c)

이 식에서 j가 영향을 주는 부분은

(- 2 * a * p[j] * p[i]) + (a * p[j] ^ 2 - b * p[j]  + dp[j] ) + (a * p[i] ^ 2 + b * p[i] + c)

이렇게 뿐이다. 뒤에는 영향을 주지 않는다.

그래서 마지막에 추가해주는 것이다.

(a * p[i] ^ 2 + b * p[i] + c) 이건 선 계산할때 들어오면 안된다.

 

#include <iostream>
#include <algorithm>
#include <vector>
#include <tuple>

using namespace std;

int n;
long long a, b, c;
long long x[1000001];
long long prefix_sum[1000001];
long long dp[1000001];
vector<tuple<long long, long long, double>> powers;

double get_intersection_x(tuple<long long, long long, double> a, tuple<long long, long long, double> b) {
	return (double)(get<1>(a) - get<1>(b)) / (double)(get<0>(b) - get<0>(a));
}

int main(int argc, char** argv)
{
	std::ios_base::sync_with_stdio(false);
	std::cin.tie(NULL);

	cin >> n;
	cin >> a >> b >> c;
	for (int i = 0; i < n; i++) cin >> x[i];
	prefix_sum[0] = x[0];
	for (int i = 1; i < n; i++) prefix_sum[i] = prefix_sum[i - 1] + x[i];

	prefix_sum[-1] = 0;
	dp[-1] = 0;

	for (int i = 0; i < n; i++) {
		tuple<long long, long long, double> new_power =
			make_tuple(-2 * a * prefix_sum[i - 1],
				a * prefix_sum[i - 1] * prefix_sum[i - 1] - b * prefix_sum[i - 1] + dp[i - 1],
				0);

		while (!powers.empty()) {
			tuple<long long, long long, double> last_power = powers.back();
			double intersection_x = get_intersection_x(new_power, last_power);

			if (intersection_x < get<2>(last_power)) powers.pop_back();
			else break;
		}

		if (!powers.empty()) {
			tuple<long long, long long, double> last_power = powers.back();
			double intersection_x = get_intersection_x(new_power, last_power);
			new_power = make_tuple(get<0>(new_power), get<1>(new_power), intersection_x);
		}

		powers.push_back(new_power);

		long long x = prefix_sum[i];
		int idx = powers.size() - 1;

		int left = 0, right = powers.size() - 1;

		while (left <= right) {
			int mid = (left + right) / 2;
			if (get<2>(powers[mid]) <= x) {
				idx = mid;
				left = mid + 1;
			}
			else {
				right = mid - 1;
			}
		}

		dp[i] = get<0>(powers[idx]) * x + get<1>(powers[idx]) + b * prefix_sum[i] + c + a * prefix_sum[i] * prefix_sum[i];
	}

	cout << dp[n - 1] << '\n';
	return 0;
}

전체코드다.

728x90
반응형
Comments