백준2805

Ena JJJ·2022년 11월 3일
0

나무자르기

이진탐색으로 접근하는 문제이다. 처음 생각한 접근 방법은 예제를 입력받은 부분을 소팅한다음 이진탐색을 실행하려 했지만

시간 제한 때문에 이 방법이 불가능할 것이라 생각했다.
따라서 새로운 접근법이 필요했다.

나무를 자를 때, 최대 값만 알고 있다면 중간 값을 구할 수 있기 때문에, 각 케이스들의 나무를 중간값으로 잘라 합을 얻어 이를 통해 접근하려 했다.

import sys

input = sys.stdin.readline

n,m = map(int,input().split())

l = list(map(int,input().split()))


def binary_search(left,right):
    l = left
    r = right
    
    if l <= r:
        tmp =0
        mid = (l+r)//2
        
        tmp = check(mid)
        if tmp == m : print(mid); return
        elif tmp > m: l = mid
        else :#tmp < m 
            r = mid
        binary_search(l,r)
        

def check(k):
    sum = 0
    for j in l:
        if j > k:
            sum += j -k
    return sum

binary_search(0,max(l))

따라서 잘라야할 나무들 중 최대 값을 알고 있다면, 중간 값으로 나무들을 잘라 check함수를 통해 체크 후 결과를 도출하려 했다. 하지만 시간초과 오류가 발생했다.

시간초과 이유는 나무를 잘랐을 때 입력받은 값이 나오지 않을 수 있기 때문에 해당 값을 체크할 때, ==이 아닌

    if tmp >= m: l = mid
        else :#tmp < m 
            r = mid
        binary_search(l,r)

위 같이 크거나 같을 때로 구분하여 실행해 주어야 한다.

아래 코드는 위의 방식을 적용하여 값을 직접적으로 비교하게 설정했다.

import sys

input = sys.stdin.readline

n,m = map(int,input().split())

l = list(map(int,input().split()))


def binary_search(left,right):
    l = left
    r = right
    
    if l+1 < r:
        tmp =0
        mid = (l+r)//2
        
        tmp = check(mid)
        
        if tmp >= m: l = mid
        else :#tmp < m 
            r = mid
        binary_search(l,r)
    else:
        print(l)

def check(k):
    sum = 0
    for j in l:
        if j > k:
            sum += j -k
    return sum

binary_search(0,max(l))

이와 다른 접근법으로는 값을 직접적으로 비교하는 것이 아닌, 내가 얻으려는 나무길이 보다 길거나 같다면 True, 짧다면 False를 반환하는 방식으로
0~20 사이일 때, TTTTTFFFF와 같은 방식으로 구현했다. 이에 T와 F가 교차할때의 마지막 T가 최대로 자를 수 있는 높이로 아래와 같이 구현했다.
import sys

input = sys.stdin.readline

n,m = map(int,input().split())

l = list(map(int,input().split()))

def binary_search(left,right):
l = left
r = right

if l+1 < r:
    mid = (l+r)//2
    
    tmp = check(mid)
    if tmp : l = mid
    else : r = mid
    
    binary_search(l,r)
else:
    print(l)

def check(k):
sum = 0
for j in l:
if j > k:
sum += j -k
if sum >= m : return True
return False

binary_search(0,max(l))

0개의 댓글