Django 테스트코드 단에서 n+1 문제 찾아내기

런던행·2020년 12월 2일
0

TDD in Django

목록 보기
2/2
post-thumbnail

ORM를 사용하는 프레임워크는 동일하게 N+1 문제가 발생 할 수 있다. 이 포스트글에서는 테스트코드에서 N+1문제를 예측하고 수정하는 방법을 공유하고자 한다.

각 프레임워크 마다 n+1 문제를 detection 하는 패키지들은 이미 휼륭한 개발자들 의해서 공유되고 있다. 장고에서는 https://github.com/jmcarp/nplusone 패키지가 있으며 이를 이용하고자 한다.

우선 패키지를 설치하다. 설치방법은 패키지의 readme를 참고한다.
아래 예제 코드는 리소스 기반으로 간단한 rest api를 작성한 샘플 코드이다. 
모델을 2개가 있으면 Article - Comment 관계는 일대다이다.

Model

from django.db import models

# Create your models here.
class Article(models.Model):
    # user = models.ForeignKey(User, on_delete=models.CASCADE)
    title = models.CharField(max_length=144)
    subtitle = models.CharField(max_length=144, blank=True)
    content = models.TextField()
    created_at = models.DateTimeField(auto_now_add=True)

    def __str__(self):
        return '[{}] {}'.format(self.title, self.subtitle)


class Comment(models.Model):
    article = models.ForeignKey(to=Article, related_name="comments", on_delete=models.CASCADE)
    content = models.TextField()

일대다 관계 모델이다. Article - Comment

View

from django.shortcuts import render

# Create your views here.
from requests import Response
from rest_framework import viewsets
from .serializers import ArticleSerializer, CommentSerializer
from .models import Article, Comment
from rest_framework import permissions

class ArticleView(viewsets.ModelViewSet):
    queryset = Article.objects.all()
    serializer_class = ArticleSerializer
    permission_classes = ()

    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)

    def get_serializer(self, *args, **kwargs):
        return super().get_serializer(*args, **kwargs)

    def get_serializer_context(self):
        context = super(ArticleView, self).get_serializer_context()

        return context


class CommentView(viewsets.ModelViewSet):

    serializer_class = CommentSerializer
    permission_classes = ()

    def get_queryset(self):
        return Comment.objects.filter(article=self.kwargs['article_pk'])

    def perform_create(self, serializer):
        serializer.save()

rest 형식으로 리소스를 반환해준다.

Serializer

from rest_framework import serializers
from .models import Article, Comment
from django.contrib.auth.models import User


class CommentSerializer(serializers.ModelSerializer):
    class Meta:
        model = Comment
        fields = (
            'id',
            'article',
            'content'
        )


class ArticleSerializer(serializers.ModelSerializer):
    comments = CommentSerializer(many=True, read_only=True)
    class Meta:
        model = Article
        fields = (
            'id',
            'title',
            'subtitle',
            'content',
            'created_at',
            'comments',
        )
        read_only_fields = ('created_at',)

ArticleSerializer에서 자식 comments 모델들을 가져온다 (중요, prefetch 안 되어 있다.)

Urls

from django.urls import path, include
from rest_framework_nested import routers

from .views import ArticleView, CommentView

router = routers.SimpleRouter()
router.register(r'articles', ArticleView, basename='articles')

articles_router = routers.NestedSimpleRouter(router, r'articles', lookup='article')
articles_router.register(r'comments', CommentView, basename='article-comments')
urlpatterns = [

    path('', include(router.urls)),
    path('', include(articles_router.urls))
]

TestCode

from unittest import mock

import pytest
from django.conf import settings
from django.test import Client

from api.models import Article, Comment


@pytest.fixture
def logger(monkeypatch):
    mock_logger = mock.Mock()
    monkeypatch.setattr(settings, 'NPLUSONE_LOGGER', mock_logger)
    return mock_logger

def check_nplusone_problem(logger):
    if len(logger.log.call_args_list) != 0:  # 에러 문자가 포함되어 있다.
        args = logger.log.call_args[0]
        assert ("Potential n+1 query detected on" in args[1]) is False  # prefetch 가 필요한 경우
        assert ("Potential unnecessary eager load detected on" in args[1]) is False  # 쓸데없이 prefetch를 한 경우

    assert not logger.log.called  # 정상적인 경우 호출 하면 안 된다.



def create_mock_article() -> Article:
    """
    Article 목 데이터를 생성한다.
    :return:
    """
    # Given
    name: str = "TEST_title"
    description: str = "description"

    # When
    article: Article = Article()
    article.title = name
    article.subtitle = 'subtitle'
    article.content = description
    article.save()

    comment: Comment = Comment()
    comment.content = "AAA"
    comment.article = article
    comment.save()

    return article


@pytest.mark.django_db
def test_route_article_GET(logger) -> None:
    """
    AlbumLV 뷰의 url method GET 테스트 할 수 있다.
    :return:
    """

    # Given
    create_mock_article()
    create_mock_article()
    create_mock_article()

    client = Client()


    # When
    response = client.get('/api/articles/')


    # Then
    print(response.data)
    check_nplusone_problem(logger)
    assert response.status_code == 200

테스트코드 수행 결과 실패

-> 실패하는 이유는 ArticleSerializer() 에서 자식 모델 Comment를 참조하는데 prefetch를 하고 있지 않다.

실패 케이스 수정하기

기존 ArticleView의 쿼리셋에 comment 모델을 prefetch에 추가 해준다.

class ArticleView(viewsets.ModelViewSet):
    queryset = Article.objects.prefetch('comments').all()
    

테스트코드 수행 결과 성공~

=================================================================================== 1 passed in 7.85s ===================================================================================

profile
unit test, tdd, bdd, laravel, django, android native, vuejs, react, embedded linux, typescript

0개의 댓글