[Scala] Currying

오현우·2022년 12월 25일
0

Currying Pattern

함수형 프로그래밍에서 Currying이라는 디자인 패턴을 사용하여 합성함수를 만들어낸다.

예시를 들어서 보자.

우리는 (a, b)라는 파라미터를 넘겨서 해당 숫자의 범위 만큼 어떠한 함수를 적용시켜 모두 더해주는 함수를 작성하려고 한다.

구체적인 예시를 들어보자.

def sum(f: Int => Int): (Int, Int) => Int = {
  def sumF(a: Int, b: Int): Int = {
    if a > b then 0
    else f(a) + sumF(a+1, b)
  }
  sumF
}

위의 함수는 우리가 각각의 원소에 적용할 함수를 넣어주면 해당 함수를 적용시켜 모두 더해주는 함수를 리턴해주는 함수이다.

위의 함수는 아래와 같이 사용할 수 있다.

def sumInt = sum(x => x) 
def sumSqure = sum(x => x * x)
def sumCube = sum(x => x * x * x)

sumInt(1, 10)
sumSqure(1, 10)
sumCube(1, 10)

여기서 주목할 만한 파트가 존재한다.

sum(function)(a, b) 로 합성이 가능하다.

따라서 스칼라에서는 아래와 같은 커링 디자인 패턴이 존재하여 위와 같은 경우를 더 깔끔하게 만들어준다.

def sum(f: Int => Int)(a: Int, b: Int): Int = {
  if a > b then 0
  else f(a) + sum(f)(a+1, b)
}

val temp: Int = sum(x => x * 3 + 5)(1, 2)
println(temp)

여기서 조금 더 응용을 해보자.
우리는 팩토리얼 함수를 구현하려고 한다.
그럼 범위 product 함수를 먼저 정의해보자.

def product(f: Int => Int)(a: Int, b: Int) =
	if a > b then 1
    else f(a) * product(f)(a+1, b)

factorial(n: Int) = product(x => x)(1, n)

factorial(5) // will return 120

MapReduce 구현하기

이제 위의 product와 sum을 일반화하여 mapReduce로 구현하여 보자.

def mapReduce(f: Int => Int, combine: (Int, Int) => Int, zero: Int)(a: Int, b: Int) = {
	def recur(a: Int): Int =
    	if a > b then zero
    	else combine(f(a), recur(a+1))
  	recur(a)
}

def sumInt(f: Int => Int) = mapReduce(f, (x, y) => x+y, 0)
sumInt(x => x)(1, 5)

def productInt(f: Int => Int) = mapReduce(f, (x, y) => x*y, 1)
productInt(x => x)(1, 5)

FIXED Point 함수형으로 구현하기

부동점은 f(x) => x 를 만족하는 x의 집합이다.
아래 함수는 부동점을 구하기 위해 근사하는 방법을 정의하였다.

import annotation.tailrec

val tolerance: Float = 0.0001

def abs(x: Double) = if x >= 0 then x else -x

def isClosedEnough(x: Double, y: Double): Boolean =
  abs((x - y) / x) < tolerance

def fixedPoint(f: Double => Double)(firstGuess: Double): Double = {
  @tailrec
  def iterate(guess: Double): Double = {
    val next = f(guess)
    if isClosedEnough(guess, next) then next
    else iterate(next)
  }
  iterate(firstGuess)
}

def sqrt(x: Double) = fixedPoint(y => x / y)(1.0)

@main
def test =
  sqrt(2)

하지만 위에서 문제가 발생하는데, 해당 부분은 근사값들이 수렴하지 않고 진동하기 때문에 발생하였다.

때문에 수렴하는 속도를 조금 조절해주면 원할하게 작동한다.

def sqrt(x: Double) = fixedPoint(y => (y + x / y) / 2)(1.0)

공부를 마치며 개인적인 생각

함수형 프로그래밍에서 커링은 매우 중요한 개념으로 느껴졌다.
실제 수학적 개념이 많이 혼용된 느낌이 강했다.
특히 체인룰을 이용하여 함수 => 함수 => 결과를 리턴하는 방식이 정말 수학적으로 느껴졌다.

또 파이썬의 래핑함수와 비슷하다는 느낌을 받았다.

함수형 프로그래밍을 단순하면서 복잡하게 구현가능한 디자인 패턴 커링에 대해 리뷰를 마친다.

profile
핵심은 같게, 생각은 다르게

0개의 댓글