최종적으로 답은 S(1) ~ S(n) 까지의 합이다.
S(m) 은 민균이가 추가로 갚아야하는 돈이고, 해당 금액을 계산하는 방법은
A 배열에서 m개의 원소를 추출해서 m*(그 원소들중 최댓값) - 원소의 합이다.
즉 최댓값에서 각원소를 뺀 값들을 모두 더한것이다.
가장 쉽게 생각해볼만한건 배열을 임의로 정해서 해당 배열에 대해 S(m)을 구하는 것이다.
n개 배열중 1개로 뽑을 수 있는 경우의 수 nC1 +... nCn 까지 하면 2^n -1 이다.
최대 n은 4000이므로 불가능하다.
S[m] 일때 S[m-1] 을 구하기 위해서는 정렬된 배열에서 가장 작은 원소를 제거하거나, 가장 큰 원소를 제거하거나 이다. (중간에서 제거를 할 경우 가장 작은 원소를 제거한 것 보다 반드시 갚아야할 금액이 커진다)
사실 이 방법으로 문제에서 제시해준 테스트케이스가 모두 통과했어서 뭐가 문젠지 고생했다..
결론적으로 S[m] 을 알고 S[m-1]을 알더라도 S[m-2] 는 해당 방법으로 구할 수 없다.
반례는
#Input
1
5 10 14 5 3 17
#Answer
70
이다.
S[m-1]에 쓰이지 않았던 원소가 S[m-2]에 사용될 수 있기 때문에 계속해서 원소를 삭제해나가는건 틀린답이다.
최댓값이 정해진 배열에서는 가장 차이가 작도록 원소를 정해야한다.
즉, 1 3 4 5 8 에서는
m=3 이고 5가 정해져있을때 1 3 5 가 아니라 5에 가까운 3 4 5로 정해야한다.
해당 포인트를 캐치했다면 누적합을 이용하여 각 S(m)을 구할때 start index를 바꿔가면서 O(n) 시간에 S(m)을 구할 수 있다. 총 O(n^2) 시간이 소요된다.
우리는 start index로 부터 end index까지의 합을 구하고 싶다.
여기서 누적합을 어덯게 빼야 해당 값을 구할 수 있을까?
3번부터 5번 index까지의 합을 구하고 싶다면 : 5번까지의 누적합 - 2번까지의 누적합
1번까지 index까지의 합을 구하고 싶다면 : 1번까지의 누적합 - 0번까지의 누적합
0번까지 index까지의 합을 구하고 싶다면 : 0번까지의 누적합 - (-1번까지의 누적합)
이렇게 계산이 된다.
즉, 누적합 배열은 preSum[0] = 0 으로 만들어 둬야 한다.
presum[0] = 0
presum[1] = 0번 index까지의 합
presum[2] = 1번 index까지의 합
...
결국
3번부터 5번 index까지의 합을 구하고 싶다면 : 5번까지의 누적합 - 2번까지의 누적합 : preSum[5+1] - preSum[3]이 된다.
0번부터 1번까지의 합은 : preSum[1+1]-preSum[0]
0번부터 0번까지 합은 : preSum[0+1] - preSum[0]
start ~ end 까지 합 : preSum[end+1] - preSum[start] 로 기억하자.
이렇게 모든 변수를 나타낼 수 있게 된다.
import java.util.*;
public class Main {
static class Pair {
Integer from;
Integer to;
public Pair(Integer from, Integer to) {
this.from = from;
this.to = to;
}
public Pair() {
}
}
static class Three {
Integer from;
Integer to;
Integer distance;
public Three(Integer from, Integer to, Integer distance) {
this.from = from;
this.to = to;
this.distance = distance;
}
public Three() {
}
}
public static void main(String[] args) {
final int MAX = 10001;
Scanner sc = new Scanner(System.in);
int t = Integer.parseInt(sc.nextLine());
for (int i = 0; i < t; i++) {
Long ret = game(sc);
System.out.println(ret);
}
}
public static Long game(Scanner sc) {
String[] s = sc.nextLine().split(" ");
Integer n = Integer.parseInt(s[0]);
Long[] AList = new Long[n];
Long[] preSum = new Long[n+1];
preSum[0]=0L;
for (int i = 0; i < n; i++) {
Long num = Long.parseLong(s[i + 1]);
AList[i]=num;
}
Arrays.sort(AList);
for (int i = 0; i < n; i++) {
preSum[i+1] = preSum[i]+AList[i];
}
Long ans = 0L;
for (int m = 2; m <= n; m++) {
// 2번부터
// S[m] = ?
Long minResult = Long.MAX_VALUE;
for (int start = 0; start + m - 1 < n ; start++) {
int end = start + m - 1;
Long result = m* AList[end] - getSum(start,end,preSum); // 0 ~ 1
minResult = Math.min(result,minResult);
}
ans += minResult;
}
return ans;
}
public static Long getSum(int start, int end , Long[] preSum){
return preSum[end+1] - preSum[start];
}
}