BOJ 2336 굉장한 학생

LONGNEW·2021년 9월 12일
0

BOJ

목록 보기
269/333

https://www.acmicpc.net/problem/2517
시간 2초, 메모리 192MB

input :

  • N (1 ≤ N ≤ 500,000)
  • 세 개의 줄에는 각 시험에서 1등인 학생부터 N등인 학생

output :

  • '굉장한' 학생의 수를 출력

조건 :

  • 조교는 각각의 시험에서 같은 등수의 학생이 한 명도 없도록 성적을 매겼다.

  • A라는 학생이 B라는 학생보다 세 번의 시험에서 모두 성적이 좋다면, A가 B보다 '대단하다'

  • C라는 학생보다 '대단한' 학생이 한 명도 없으면, C를 '굉장하다'


일단 어려웠고, 펜윅 트리를 이용해서 최솟값을 찾는 방법을 공부하는 계기가 되었다.

최솟값을 찾는 방법
동일한 내용을 가지고 있는 논문에서부터 설명을 이어 온 것 같다.

기본적인 펜윅트리에서는 배열의 원소를 계속 더해서 구간합을 구하지만 이를 다르게 사용해 최솟값을 저장하게 할수도 있다.
특정 구간에서의 최솟값을 찾기 위해서는 트리 2개를 사용해서 겹치지 않는 구간을 찾게 하고 1개의 원소는 따로 저장하는 방식을 사용한다.

해석

일단 펜윅 트리는 제쳐두고 문제의 아이디어를 떠올려 보자.
"굉장한" 학생의 수를 찾으라고 하는데 굉장한 학생이 뭐냐?

C 보다 대단한 학생이 한 명도 없으면 된다.
즉 C랑 X를 비교할 떄 C가 모든 시험을 더 잘 본것이다.
반대로 C보다 시험을 잘 친 애가 없으면 "굉장한" 아이가 되는 것이다.

기본적으로 1등을 한 학생의 경우 자신보다 대단한 학생이 없다고 보면 되고 그 외의 경우가 존재하는데 예제에서의 "5"번 학생의 경우이다.

그래서 예제에서 정답은 기본 1등 + "5"번 학생 = 4가 된다.

근데 이렇게만 보면 이걸 어떻게 풀어야 하는지가 더 애매해진다.
어떻게 비교를 해야 하나 싶었는데

일단 기본적으로 첫 번째 시험의 경우에는 정렬을 통해서 나열 할 수 있다.
그러면 자기보다 시험을 잘 친 놈들을 쉽게 앞에 보낼 수 있다.

그러면 남은 두 시험을 비교해야 하는데 이를 펜윅 트리를 통해 나타내자.

트리

정렬을 통해서 나열했기 때문에 남은 두 개의 숫자 (시험 결과)를 어떻게 활용할 지 생각해야 한다.

이에 대한 아이디어로 인덱스, 밸류를 통해 나타내는 것이다. [1 ~ 두번째 시험 등수] 까지의 구간에서 가장 등수가 높은 놈을 찾아 이게 "세번째 시험 등수" 보다 낮으면 "대단한" 학생이 있는 것이다.

이런 방식을 통해서 트리에는 'inf' 값을 저장하게 한 후에 for문을 통해 각 학생들의 등수를 저장하게 한다.

여기에서 우리가 볼 구간은 [1 ~ 두번째 등수]로 고정되어 있기 때문에 원래의 펜윅 트리로도 충분히 search가 가능하다.

기본적으로 update를 할 때는 idx += (idx & -idx)를 수행하면서 이 값들의 최솟값으로 업데이트를 하기 때문에 빠르게 탐색이 가능하다.

그리고 이에대한 리턴 값(세번째 시험을 가장 잘 친 학생의 등수)을 본인의 등수와 비교해서 자기가 더 높으면 "굉장한" 학생이기 때문에 정답을 += 1 한다.

import sys


def update(idx, val):
    """
        idx위치의 등수를 val로 업데이트
        가장 등수가 높은 애를 저장하기 위해 min을 사용한다.
    """
    while idx <= n:
        tree[idx] = min(tree[idx], val)
        idx += (idx & -idx)


def query(end):
    """
        구간은 1 ~ end로 고정되어 있다.
        각 위치에서 최소값을 찾을 수 있도록 하면 된다.
    """
    ret = float('inf')
    while end > 0:
        ret = min(ret, tree[end])
        end -= (end & -end)

    return ret

n = int(sys.stdin.readline())
student = [[] for _ in range(n + 1)]
tree = [float('inf')] * (n + 1)

student[0] = [0, 0, 0]

for _ in range(3):
    temp = list(map(int, sys.stdin.readline().split()))

    for i in range(1, n + 1):
        num = temp[i - 1]
        student[num].append(i)

student.sort(key=lambda x:x[0])

ans = 0
for i in range(1, n + 1):
    sec_exam, thi_exam = student[i][1], student[i][2]
    former_rank = query(sec_exam)

    if former_rank > thi_exam:
        ans += 1

    update(sec_exam, thi_exam)

print(ans)

0개의 댓글