ORM를 사용하는 프레임워크는 동일하게 N+1 문제가 발생 할 수 있다. 이 포스트글에서는 테스트코드에서 N+1문제를 예측하고 수정하는 방법을 공유하고자 한다.
각 프레임워크 마다 n+1 문제를 detection 하는 패키지들은 이미 휼륭한 개발자들 의해서 공유되고 있다. 장고에서는 https://github.com/jmcarp/nplusone 패키지가 있으며 이를 이용하고자 한다.
우선 패키지를 설치하다. 설치방법은 패키지의 readme를 참고한다.
아래 예제 코드는 리소스 기반으로 간단한 rest api를 작성한 샘플 코드이다.
모델을 2개가 있으면 Article - Comment 관계는 일대다이다.
from django.db import models
# Create your models here.
class Article(models.Model):
# user = models.ForeignKey(User, on_delete=models.CASCADE)
title = models.CharField(max_length=144)
subtitle = models.CharField(max_length=144, blank=True)
content = models.TextField()
created_at = models.DateTimeField(auto_now_add=True)
def __str__(self):
return '[{}] {}'.format(self.title, self.subtitle)
class Comment(models.Model):
article = models.ForeignKey(to=Article, related_name="comments", on_delete=models.CASCADE)
content = models.TextField()
일대다 관계 모델이다. Article - Comment
from django.shortcuts import render
# Create your views here.
from requests import Response
from rest_framework import viewsets
from .serializers import ArticleSerializer, CommentSerializer
from .models import Article, Comment
from rest_framework import permissions
class ArticleView(viewsets.ModelViewSet):
queryset = Article.objects.all()
serializer_class = ArticleSerializer
permission_classes = ()
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)
def get_serializer(self, *args, **kwargs):
return super().get_serializer(*args, **kwargs)
def get_serializer_context(self):
context = super(ArticleView, self).get_serializer_context()
return context
class CommentView(viewsets.ModelViewSet):
serializer_class = CommentSerializer
permission_classes = ()
def get_queryset(self):
return Comment.objects.filter(article=self.kwargs['article_pk'])
def perform_create(self, serializer):
serializer.save()
rest 형식으로 리소스를 반환해준다.
from rest_framework import serializers
from .models import Article, Comment
from django.contrib.auth.models import User
class CommentSerializer(serializers.ModelSerializer):
class Meta:
model = Comment
fields = (
'id',
'article',
'content'
)
class ArticleSerializer(serializers.ModelSerializer):
comments = CommentSerializer(many=True, read_only=True)
class Meta:
model = Article
fields = (
'id',
'title',
'subtitle',
'content',
'created_at',
'comments',
)
read_only_fields = ('created_at',)
ArticleSerializer에서 자식 comments 모델들을 가져온다 (중요, prefetch 안 되어 있다.)
from django.urls import path, include
from rest_framework_nested import routers
from .views import ArticleView, CommentView
router = routers.SimpleRouter()
router.register(r'articles', ArticleView, basename='articles')
articles_router = routers.NestedSimpleRouter(router, r'articles', lookup='article')
articles_router.register(r'comments', CommentView, basename='article-comments')
urlpatterns = [
path('', include(router.urls)),
path('', include(articles_router.urls))
]
from unittest import mock
import pytest
from django.conf import settings
from django.test import Client
from api.models import Article, Comment
@pytest.fixture
def logger(monkeypatch):
mock_logger = mock.Mock()
monkeypatch.setattr(settings, 'NPLUSONE_LOGGER', mock_logger)
return mock_logger
def check_nplusone_problem(logger):
if len(logger.log.call_args_list) != 0: # 에러 문자가 포함되어 있다.
args = logger.log.call_args[0]
assert ("Potential n+1 query detected on" in args[1]) is False # prefetch 가 필요한 경우
assert ("Potential unnecessary eager load detected on" in args[1]) is False # 쓸데없이 prefetch를 한 경우
assert not logger.log.called # 정상적인 경우 호출 하면 안 된다.
def create_mock_article() -> Article:
"""
Article 목 데이터를 생성한다.
:return:
"""
# Given
name: str = "TEST_title"
description: str = "description"
# When
article: Article = Article()
article.title = name
article.subtitle = 'subtitle'
article.content = description
article.save()
comment: Comment = Comment()
comment.content = "AAA"
comment.article = article
comment.save()
return article
@pytest.mark.django_db
def test_route_article_GET(logger) -> None:
"""
AlbumLV 뷰의 url method GET 테스트 할 수 있다.
:return:
"""
# Given
create_mock_article()
create_mock_article()
create_mock_article()
client = Client()
# When
response = client.get('/api/articles/')
# Then
print(response.data)
check_nplusone_problem(logger)
assert response.status_code == 200
-> 실패하는 이유는 ArticleSerializer() 에서 자식 모델 Comment를 참조하는데 prefetch를 하고 있지 않다.
기존 ArticleView의 쿼리셋에 comment 모델을 prefetch에 추가 해준다.
class ArticleView(viewsets.ModelViewSet):
queryset = Article.objects.prefetch('comments').all()
=================================================================================== 1 passed in 7.85s ===================================================================================