_, N, *ls = map(int, open(0).read().split())
l, r = 1, max(ls)
f = lambda x: x//d
result = 0
while l <= r:
d = (l + r) // 2
s = sum(map(f, ls))
if s < N:
r = d - 1
elif s >= N:
if d > result:
result = d
l = d + 1
print(result)
메모리 약 32MB, 시간 72ms
먼저 탐색할 범위를 정한다. 어떤걸 범위로 지정할 것인가? 이 질문에 답하기 위해 다른 질문에 대한 답을 명확히 해야 한다 : "무엇을 구해야 하는가?" 내가 구하고자 하는 것은 잘라진 랜선의 길이이다. 그런데 길이를 잘 구했는지 알기 위한 기준으로 사용할 수 있는 것은 잘라진 랜선의 개수이다. 랜선의 개수는 주어진 랜선 길이들을 자를 길이로 나눈 몫이다. 랜선 개수를 s
라 하고 랜선 길이를 d
라 한다.
랜선 개수를 구하기 위한 함수는 f
에 람다 함수를 할당해 다음과 같이 사용했다.
f = lambda x: x//d
s = sum(map(f, ls))
주어진 랜선들을 어떤 길이 d
로 잘랐을 때 나머지를 버린 수를 합하여 s
에 할당한다. 이 값을 비교해야 할 값은 N
이다: 랜선을 잘라 N
개의 랜선으로 만들었는가?
s
가 N
보다 작으면 d
를 줄여야 한다. 그래야 s
가 커진다. s
가 N
보다 크면 d
를 늘려야 한다. 그래야 s
가 작아진다. s
가 N
보다 같으면? d
를 늘려야 한다. 가능한 가장 큰 d
를 구하고 있기 때문이다. if s < N:
r = d - 1
elif s >= N:
if d > result:
result = d # 최대값을 result에 저장한다
l = d + 1
반복문의 각 단계마다 l
과 r
의 값을 변화시키며 d
를 늘이거나 줄인다. l
이 r
보다 커질 때 반복문을 종료하고, 저장된 최대값을 출력한다.
from sys import stdin
input = stdin.readline
k, n = map(int, input().rstrip().split())
l = [int(input()) for _ in range(k)]
M = sum(l) // n
m = 1
while m <= M:
mid = (m + M) // 2
if sum([i // mid for i in l]) >= n:
m = mid + 1
else:
M = mid - 1
print(M)
메모리 약 31MB, 시간 52ms
크게 범위와 출력 값이 눈에 띈다.
나는 탐색의 범위를 1부터 max(ls)
까지로 했다. sum(l)//n
까지 해도 되는 이유는 무엇일까? 평균값보다 커질 경우가 없다? 30과 100단위 길이의 랜선이 있다고 하자. 평균은 65이다. N
이 1이라면 출력 값은 100이어야 한다. 그런데 상위 바운더리인 M
이 sum(l)//n
부터 작아진다면, 100이 나올 수 없다. 아니나 다를까 테스트해보니 30이 출력된다. 내 코드가 맞았다.
조건을 다시 보니 항상 K <= N
이라고 한다. 위 케이스에서 N
이 2 이상이어야 한다는 뜻이다. 2개만이라도 50이 출력되기 때문에, 65에서 시작해 작아져도 충분하다. 랜선 K
개에서 시작해 적어도 K
개를 잘라내야 하기 때문에 평균에서 시작할 수 있다.
최대값을 이렇게도 구할 수 있는가? 합이 n
보다 같거나 크면 아래쪽 바운더리를 중간값보다 1 더 크게 하고, 아니라면 위쪽 바운더리M
을 중간값보다 1 작게 잡는다. 그리고 M
을 출력한다. 나는 처음에 중간값 d
를 출력하려 했었다. 그래서 자꾸 다른 답이 나와서 result
라는 변수를 따로 마련하고 d
의 최대값을 저장해 해결했다. 만약 상위 바운더리가 정답보다 1만 크더라도 잘라진 랜선의 개수가 구하려는 값보다 작다. 조금씩 줄어들다가 같아지는 순간 M
은 제자리에 있고 m
이 올라가 최대 길이의 M
을 구할 수 있다.