Leetcode 3321. Find X-Sum of All K-Long Subarrays II

Alpha, Orderly·2025년 11월 9일
0

leetcode

목록 보기
175/176

문제

You are given an array nums of n integers and two integers k and x.

The x-sum of an array is calculated by the following procedure:

Count the occurrences of all elements in the array.
Keep only the occurrences of the top x most frequent elements. 
If two elements have the same number of occurrences, 
the element with the bigger value is considered more frequent.
Calculate the sum of the resulting array.
Note that if an array has less than x distinct elements, its x-sum is the sum of the array.

Return an integer array answer of length n - k + 1 where answer[i] is the x-sum of the subarray nums[i..i + k - 1].
  • 정수 배열 nums와 두 정수 k, x 가 주어진다.
  • k길이의 sliding window를 nums에서 순회하며 다음 값을 찾는다.
    • window 안의 모든 숫자를 카운팅
    • 숫자들중 가장 많이 나온 숫자를 x개 선택한다.
    • 그 숫자들의 합
  • 해당 구한 합을 ans로 배열로 만들어 리턴하시오
  • 다만, 등장한 횟수가 동일하다면 큰 숫자를 선택한다.

예시

nums = [1,1,2,2,3,4,2,3], k = 6, x = 2

Subarray 1: [1, 1, 2, 2, 3, 4]

  • 상위 2개 빈도 원소: 1, 2
  • 합: 1 + 1 + 2 + 2 = 6

Subarray 2: [1, 2, 2, 3, 4, 2]

  • 상위 2개 빈도 원소: 2, 4
  • 합: 2 + 2 + 2 + 4 = 10

Subarray 3: [2, 2, 3, 4, 2, 3]

  • 상위 2개 빈도 원소: 2, 3
  • 합: 2 + 2 + 2 + 3 + 3 = 12

제한

  • nums.length==nnums.length == n
  • 1<=n<=1051 <= n <= 10^5
  • 1<=nums[i]<=1091 <= nums[i] <= 10^9
  • 1<=x<=k<=nums.length1 <= x <= k <= nums.length

풀이

from sortedcontainers import SortedList

class Solver:
    def __init__(self, limit: int):
        self.r = SortedList()
        self.t = SortedList()
        self.c = 0
        self.f = dict()

        self.limit = limit

    def update(self, val: int, update: int):
        if val in self.f:
            if (self.f[val], val) in self.t:
                self.t.remove((self.f[val], val))
                self.c -= self.f[val] * val
            else:
                self.r.remove((self.f[val], val))

        if val not in self.f:
            self.f[val] = 0

        self.f[val] += update

        if self.f[val] == 0:
            del self.f[val]
        else:
            self.t.add((self.f[val], val))
            self.c += self.f[val] * val

        self.balance()

    def balance(self):
        if len(self.t) > self.limit:
            f, v = self.t.pop(0)
            self.c -= f * v
            self.r.add((f, v))

        if len(self.t) < self.limit and self.r:
            f, v = self.r.pop()
            self.c += f * v
            self.t.add((f, v))

        if self.t and self.r and self.t[0] < self.r[-1]:
            top = self.t.pop(0)
            remain = self.r.pop()

            self.t.add(remain)
            self.r.add(top)

            self.c = self.c - top[0] * top[1] + remain[0] * remain[1]




class Solution:
    def findXSum(self, nums: List[int], k: int, x: int) -> List[int]:
        ans = []

        s = Solver(x)

        for i in range(k):
            s.update(nums[i], 1)

        ans.append(s.c)

        for i in range(k, len(nums)):
            s.update(nums[i], 1)
            s.update(nums[i - k], -1)
            ans.append(s.c)

        return ans

슬라이딩 윈도우 내에서 각 원소의 빈도를 계속 갱신하면서, 그중 상위 x개의 (freq, val)만 유지하고 이들의 freq * val 합을 효율적으로 계산하는 방식으로 동작합니다.

정렬된 두 그룹을 사용합니다:

  • top(t): 빈도가 높은(동률이면 값이 큰) 상위 x개의 원소
  • rest(r): 나머지 원소

원소를 추가하거나 제거하면 빈도가 바뀌므로, 이전 위치(top 또는 rest)에서 제거한 뒤 새 빈도로 다시 넣습니다. 이때 일단 임시로 top에 넣고 균형을 맞춰줍니다.

균형(balancing)은 다음 규칙으로 유지됩니다:

  • top의 크기가 x보다 크면 가장 약한 원소를 빼서 rest로 이동
  • top의 크기가 x보다 작으면 rest에서 가장 강한 후보를 top으로 올림
  • 경계 비교로 rest에 더 강한 항목이 있으면 서로 교환하여 항상 top이 상위 x개가 되도록 보장

top 안의 원소들에 대해 freq * val의 합을 c로 유지하기 때문에, 슬라이딩 이동 시 전체 정렬을 다시 하지 않고도 O(log U) 비용으로 결과를 업데이트할 수 있습니다.

profile
만능 컴덕후 겸 번지 팬

0개의 댓글