백준 7453 합이 0인 네 정수

1c2·2025년 2월 10일
0

baekjoon

목록 보기
29/33

문제 링크

Brute Force로 풀면 시간 복잡도가
O(N4)O(N^4)입니다.
N이 최대 4000이므로 시간초과가 납니다.

hash map 사용

처음에는 hash map을 사용해서 풀었습니다.
A와 B의 합이 저장된 hash map을 만들고, C와 D의 합을 구할때마다 이 -1을 곱해서 이 hashmap에 있는지 확인하는 방법입니다.

#include <bits/stdc++.h>
using namespace std;
#define ll long long

int n;
ll A[4000], B[4000], C[4000], D[4000];

void solution() {
    unordered_map<ll, int> AB_map;
    ll ans = 0;

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            AB_map[A[i] + B[j]]++;
        }
    }

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            ll target = -(C[i] + D[j]);
            if (AB_map.find(target) != AB_map.end()) {
                ans += AB_map[target];
            }
        }
    }

    cout << ans << "\n";
}

void input() {
    cin >> n;
    for (int i = 0; i < n; i++) {
        cin >> A[i] >> B[i] >> C[i] >> D[i];
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    input();
    solution();
    return 0;
}

시간복잡도는 O(N2)O(N^2)이 걸려야 하는데 의외로 시간초과가 났습니다. 이유를 찾아보니 Hash가 고르게 분배되지 않은 TC를 넣은것 같습니다. 다른 unordered_map을 쓴 코드를 찾아보니 chaining같은 기법을 적용했습니다.

문제를 푸는 방법을 제한해놓다니 억울하지만 다른 풀이를 생각했습니다.

투 포인터를 사용하면 O(N2)O(N^2), 이분탐색을 사용하면 O(N2logN)O(N^2 logN)이 걸릴 것이라고 생각했고, 두 방법모두 시간초과가 나지 않지만 구현이 쉽고 빠른 투 포인터를 사용해서 풀었습니다.

투 포인터

#include <bits/stdc++.h>
using namespace std;
#define ll long long

int n;
ll A[4000], B[4000], C[4000], D[4000];

void solution() {
    ll ans = 0;
    vector<ll> AB_sum, CD_sum;
    for(int i = 0; i < n;i++){
        for(int j = 0; j < n;j++){
            AB_sum.push_back(A[i] + B[j]);
            CD_sum.push_back(C[i] + D[j]);
        }
    }

    sort(AB_sum.begin(), AB_sum.end());
    sort(CD_sum.begin(), CD_sum.end());

    int AB_idx = 0;
    int CD_idx = CD_sum.size()-1;
    while(AB_idx < AB_sum.size() && CD_idx >= 0){
        ll sum = AB_sum[AB_idx] + CD_sum[CD_idx];
        if(sum == 0){
            int AB_cnt = 0;
            int CD_cnt = 0;
            ll target_AB = AB_sum[AB_idx];
            ll target_CD = CD_sum[CD_idx];

            while(AB_idx < AB_sum.size() && AB_sum[AB_idx] == target_AB) {
                AB_idx++;
                AB_cnt++;
            }
            while(CD_idx >= 0 && CD_sum[CD_idx] == target_CD) {
                CD_idx--;
                CD_cnt++;
            }
            ans += AB_cnt * CD_cnt;
        }else if( sum > 0){
            CD_idx--;
        }else{
            AB_idx++;
        }
    }
    cout << ans << endl;
}

void input() {
    cin >> n;
    for (int i = 0; i < n; i++) {
        cin >> A[i] >> B[i] >> C[i] >> D[i];
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    input();
    solution();
    return 0;
}

0개의 댓글