[Spring] Redis 를 사용한 Rate Limit Filter 구현

식빵·2024년 3월 25일
0

Spring Lab

목록 보기
32/35
post-thumbnail

Spring Legacy(Spring boot 를 사용하지 않는 환경)에서 작성한 코드 및 설정입니다.
Spring Boot 쓰시는 분들은 Filter 등록법을 조금만 검색해서
이 글을 응용하시면 될 듯합니다 😎 (절대 귀찮아서 안쓰는 거 아닙니다.)


web.xml 일부

  <filter>
    <filter-name>encodingFilter</filter-name>
    <filter-class>org.springframework.web.filter.CharacterEncodingFilter</filter-class>
    <init-param>
      <param-name>encoding</param-name>
      <param-value>utf-8</param-value>
    </init-param>
  </filter>
  <filter-mapping>
    <filter-name>encodingFilter</filter-name>
    <url-pattern>/*</url-pattern>
  </filter-mapping>


  <!-- 
	RemoteIpFilter 는 proxy 서버가 앞단에 있을 때, Http Header 의 값들을
	분석해서 실제 요청을 보낸 클라이언트의 IP 를 알아내서 request.remoteAddr
	에 값을 설정해주는 편의성 filter 입니다.
  <!-- 
	RemoteIpFilter 는 아마 filter-class 에 작성하면 IDE 가 빨간줄을 그을텐데,
	그럴 때는 maven dependecy 로 ...

	groupId: org.apache.tomcat
	artifactId: tomcat-catalina
	version : 여러분의 톰캣 버전
	scope: provided 

	... 를 추가해주시기 바랍니다.
  -->
  <filter>
    <filter-name>RemoteIpFilter</filter-name>
    <filter-class>org.apache.catalina.filters.RemoteIpFilter</filter-class>
  </filter>

  <filter-mapping>
    <filter-name>RemoteIpFilter</filter-name>
    <url-pattern>/*</url-pattern>
    <dispatcher>REQUEST</dispatcher>
  </filter-mapping>


  <!-- 이게 바로 RATE LIMIT FILTER!! -->
  <filter>
    <filter-name>redisRateLimitFilter</filter-name>
    <filter-class>org.springframework.web.filter.DelegatingFilterProxy</filter-class>
  </filter>
  <filter-mapping>
    <filter-name>redisRateLimitFilter</filter-name>
    <url-pattern>/remote-api/*</url-pattern>
    <dispatcher>REQUEST</dispatcher>
  </filter-mapping>




Rate Limit Filter 구현

RedisRateLimitFilter 를 Bean 으로 등록해야 됩니다.
Bean ID (또는 명칭) 은 반드시 web.xml 에서 표기한 filter-name 과 동일하게
"redisRateLimitFilter" 로 명시해야 됩니다. 너무 쉬운 파트니 이건 Skip!

주의: java 17, jakarta ee 을 사용한다는 점 유의하셔서 보시기 바랍니다.

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.util.StringUtils;

import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * <h2>Redis 기반 Rate Limiting Filter</h2>
 * Redis 와 Lua Script 를 활용한 RateLimit 기능을 제공하는 Servlet Filter 이다.<br>
 * 주목적은 DDOS 공격 방어용으로 만들었으며,<br>
 * 특정 시간 (WINDOW_TIME_FRAME) 동안<br>
 * 최대 요청 횟수(MAX_REQUEST_PER_WINDOW) 를 제한한다.<br>
 */
public class RedisRateLimitFilter implements Filter {
    private static final Logger LOGGER 
    	= LoggerFactory.getLogger(RedisRateLimitFilter.class);

    /**
     * Spring Application 에서 DI 받은 RedisTemplate bean instance
     */
    private final RedisTemplate<String, Object> redisTemplate;

    /**
     * Lua Script Wrapper
     */
    private final RedisScript<Long> script;

    /**
     * Window 하나당 요청할 수 있는 최대 요청 수
     */
    private final String MAX_REQUEST_PER_WINDOW = "8";

    /**
     * Window 의 크기 (=Window 하나의 유지시간, 초단위)<br>
     * 너무 큰 값을 주지 않도록 주의 바람.
     */
    private final String WINDOW_TIME_FRAME = "1";

	// 1초에 8번을 초과해서 요청하면 막는다

    /**
     * lua script 에 사용될 고정된 인자값
     */
    private final Object[] SCRIPT_ARGS 
    	= Arrays.asList(MAX_REQUEST_PER_WINDOW, WINDOW_TIME_FRAME).toArray();

    /**
     * 기본 에러 문구
     */
    private final String DEFAULT_ERROR_MSG 
    	= "짧은 시간 내에 너무 많은 요청을 보냈습니다. 잠시 기다렸다 다시 요청해주세요.";


    // ApplicationContext 에서 미리 생성한  RedisTemplate 인스턴스
    //  bean 을 주입받습니다.
    public RedisRateLimitFilter(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
        // redis lua script 사용
        String rateLimitScript =
                "local current = redis.call('get', KEYS[1]) " +
                "if current then " +
                    "if tonumber(current) >= tonumber(ARGV[1]) " +
                        "then return 0 " +
                    "else " +
                        "redis.call('incr', KEYS[1]) " +
                        "redis.call('expire', KEYS[1], ARGV[2]) " +
                        "return 1 " +
                    "end " +
                "else " +
                    "redis.call('set', KEYS[1], 1) " +
                    "redis.call('expire', KEYS[1], ARGV[2]) " +
                    "return 1 " +
                "end";
       
        // 설명:
        // 1. 먼저 KEYS[1] (= ip + 요청 uri 를 합친 문자열) 을 조회한다.
        //
        // 2-1. 만약에 조회가 안되면(끝에 있는 else) 신규로 해당 KEYS[1] 에 
        //		"1" 이라는 문자열값을 주고, EXPIRE 값(초 단위)도 준다.
        //
        // 2-2. 만약에 조회가 된다면 (if current then)
        //
        //      2-2-1. 읽어온 값(=current) 가 ARGV[1] (= Maximum 요청 제한 횟수) 
        //			   을 같거나, 넘으면 0 을 반환한다.
        //
        //      2-2-2. 그게 아니라면 current 값을 1 증가시키고, expire 시간을
        //			    재조정한다. 그리고 나서 1을 반환한다.
        //
        // * 참고: 이 모든 과정은 하나의 트랜잭션 내에서 일어난다. 
        //		  Redis + lua script 의 기본 동작 방식이다.

        // Lau Script Wrapper 생성
        script = new DefaultRedisScript<>(rateLimitScript, Long.class);
    }


    @Override
    public void doFilter(ServletRequest request, 
    					 ServletResponse response, 
                         FilterChain chain) throws IOException, ServletException {

        if (request instanceof HttpServletRequest servletRequest
        		&& response instanceof HttpServletResponse servletResponse) {

            String remoteAddr = servletRequest.getRemoteAddr();
            String requestURI = servletRequest.getRequestURI();
            String method = servletRequest.getMethod();

            // 등록(=POST) 요청만 체크하겠다. 
            if ("POST".equalsIgnoreCase(method)) {
				
                // (중요) 막는 타깃은 [IP + 세션] 입니다!
                String sessionId = request.getSession(true);
                
                String clientKey 
                	= "rate_limit:" 
                    + remoteAddr + ":" 
                    + sessionId + ":" 
                    + requestURI;
                    
                List<String> keys = Collections.singletonList(clientKey);
                Long result = redisTemplate.execute(script, keys, SCRIPT_ARGS);

                if (result != null && result == 0) {
                    LOGGER.error("Too Many Request! [ ip: {} , url: {} ]",
                    			  remoteAddr, requestURI);
                    servletResponse.setStatus(429);
                    servletResponse.setCharacterEncoding("UTF-8");
                    servletResponse.setHeader("Content-Type", "application/json");
                    String errorResponseJson = createErrorResponseJson(requestURI);
                    servletResponse.getWriter().write(errorResponseJson);
                    return;
                }
            }
        }
        chain.doFilter(request, response);
    }

    /**
     * error Message 를 담는 Json String 을 반환한다.
     * @param currentRequestUrl 현재 에러를 일으키는 요청 URL
     * @return 에러 문구를 담는 json 포맷 형식의 string
     */
    private String createErrorResponseJson(String currentRequestUrl) {
        // java 17 text block 문법 사용
        return """
          {"errorMsg" : "%s","blockedUrl" : "%s"}"""
          .formatted(DEFAULT_ERROR_MSG, currentRequestUrl);
    }

}

주의사항

LAN 환경에서는 여러 호스트가 공용으로 사용하는 하나의 외부 IP 를 갖는 경우가
많습니다. 이런 경우를 생각해서 위의 코드에서는 타깃 코드(=clientKey)를 생성할 때
절대로 IP 만으로 막으면 안됩니다!

// 막는 타깃이 [IP] 일 경우
String clientKey = "rate_limit:" + remoteAddr + ":" + requestURI;

그러니 꼭 아래처럼 세션(또는 각 호스트를 구별할 수 있는 어떤것이든)값을 하나
끼워서 clientKey 를 생성하시기 바랍니다!

String sessionId = request.getSession(true);
String clientKey = 
		"rate_limit:" + 
        remoteAddr + ":" + 
        sessionId + ":" // 이게 핵심!
        + requestURI;
profile
백엔드를 계속 배우고 있는 개발자입니다 😊

0개의 댓글