Unet 논문 네트워크 구현해보기

메롱하는메로나·2024년 7월 9일
0

논문

목록 보기
1/2

Unet 네트워크 구현

Unet 구조

U-Net은 그림과 같이 네트워크가 'U'자 형태로 구성되어 있기 때문에 U-Net이라고 이름이 붙었다.

크게 Contracting path와 Expanding path로 구성되어 있다.

Contracting Path

  • input
  • 3x3 convolution (파란 화살표) 두번 시행
  • 2x2 max pooling (stride=2, 빨간 화살표)
  • Down-sampling
  • 채널 수 2배로 증가

Expanding Path

  • 2x2 up-convolution (초록 화살표)
  • 3x3 convolution 두번 시행
  • Up-sampling
  • 채널 수 반으로 감소
  • ReLU 활성화 함수 사용
  • Up-Convoution된 Feature Map은 Contracting Path의 가장자리를 crop하여 크기를 맞춘 후 concatenation(결합)함
  • 마지막 1x1 convolution 연산



여기까지 간단하게 UNet 구조에 대해서 살펴보았다. 이제 UNet 네트워크를 직접 구현해보자.

UNet 네트워크 코드

import os
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms, datasets

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers)

            return cbr
        

        #Contracting path
        self.enc1_1 = CBR2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)

        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)

        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)

        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)

        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)


        #Expanding path
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)

        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec4_2 = CBR2d(in_channels=2*512, out_channels=512)
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)

        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec3_2 = CBR2d(in_channels=2*256, out_channels=256)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)

        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec2_2 = CBR2d(in_channels=2*128, out_channels=128)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)

        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec1_2 = CBR2d(in_channels=2*64, out_channels=64)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)

        self.fc = nn.Conv2d(in_channels=64, out_channels=2,
                                          kernel_size=1, stride=1, padding=0, bias=True)


    def crop_tensor(self, tensor, target_tensor):
        target_size = target_tensor.size(2)
        tensor_size = tensor.size(2)
        delta = (tensor_size - target_size) // 2
        return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]    

    def forward(self, x):
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)

        dec5_1 = self.dec5_1(enc5_1)

        unpool4 = self.unpool4(dec5_1)
        crop4 = self.crop_tensor(enc4_2, unpool4)
        cat4 = torch.cat((unpool4, crop4), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        unpool3 = self.unpool3(dec4_1)
        crop3 = self.crop_tensor(enc3_2, unpool3)
        cat3 = torch.cat((unpool3, crop3), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        unpool2 = self.unpool2(dec3_1)
        crop2 = self.crop_tensor(enc2_2, unpool2)
        cat2 = torch.cat((unpool2, crop2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        unpool1 = self.unpool1(dec2_1)
        crop1 = self.crop_tensor(enc1_2, unpool1)
        cat1 = torch.cat((unpool1, crop1), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)

        return x

이 코드를 순서대로 살펴보자.

UNet 네트워크 구현 - (1) 라이브러리 import

import os
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms, datasets

필요 라이브러리들을 import 해준다.


UNet 네트워크 구현 - (2) UNet 클래스 생성

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
  • UNet 클래스는 PyTorch의 nn.Module클래스를 상속받음
  • UNet 클래스는 nn.Moudule 클래스의 모든 메서드와 속성을 사용할 수 있게 됨

UNet 네트워크 구현 - (3) CBR2d 함수 생성

def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers)

            return cbr
  • C(Convolution), B(BatchNormalization), R(ReLU)
  • 입력 채널과 출력 채널을 받아 그에 맞는 Convolution Layer, Batch Normalization Layer, ReLU Activation Function을 포함하는 블록을 생성
  • 신경망의 기본 블록을 생성하는 것
  • nn.Sequential을 통해 하나의 모듈로 반환
  • kernel_size : 커널(필터)의 크기
  • stride : 커널이 이동하는 간격
  • padding : 입력에 추가되는 패딩

UNet 네트워크 구현 - (4) Contracting Path

        #Contracting path
        self.enc1_1 = CBR2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)

        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)

        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)

        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)

        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)

  • self.enc1_1 : 첫번째 stage의 첫번째 step(파란화살표)
  • self.enc1_2 : 첫번째 stage의 두번째 step(파란화살표)
  • self.pool1 : 빨간색 pooling 화살표
  • 2x2 max pooling이기 때문에 kernel_size=2가 된다
  • 각 stage의 첫번째 파란 화살표를 지나면서 output channel이 2배씩 증가
  • 이런식으로 마지막 stage까지 계속 진행

UNet 네트워크 구현 - (5) Expanding Path

        #Expanding path
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)

        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec4_2 = CBR2d(in_channels=2*512, out_channels=512)
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)

        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec3_2 = CBR2d(in_channels=2*256, out_channels=256)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)

        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec2_2 = CBR2d(in_channels=2*128, out_channels=128)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)

        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec1_2 = CBR2d(in_channels=2*64, out_channels=64)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)

        self.fc = nn.Conv2d(in_channels=64, out_channels=2,
                                          kernel_size=1, stride=1, padding=0, bias=True)

  • 마찬가지로 각 stage의 두번째 화살표(dec(n)_1)을 지나면서 output channel이 절반이 된다.
  • 주의할 점은 첫번째 파란색 화살표를 지날때 input_size가 unpool된 output_size의 두배라는 것
  • 이는 Contracting Path에서 crop되어 붙는 부분이 존재하기 때문
  • 따라서 self.dec3_2 = CBR2d(in_channels=2*256, out_channels=256) 이런식으로 input channel을 두배로 구성해 주어야 한다.

  • self.fc = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1, stride=1, padding=0, bias=True) : 마지막 부분은 1x1 convolution으로 최종 출력 채널을 2로 만듦

UNet 네트워크 구현 - (6) crop_tensor 함수 생성

def crop_tensor(self, tensor, target_tensor):
        target_size = target_tensor.size(2)
        tensor_size = tensor.size(2)
        delta = (tensor_size - target_size) // 2
        return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]
  • 각 stage에서 crop하는 부분을 구현
  • Contracting Path에서 사용된 convolution 연산과 pooling 연산으로 인해 featrue map의 크기가 점점 작아지기 때문에 맞추는 과정
  • Contracting Path의 feature map 가장자리 부분을 잘라내는 방식

UNet 네트워크 구현 - (7) 각 단계 이어주기

def forward(self, x):
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)

        dec5_1 = self.dec5_1(enc5_1)

        unpool4 = self.unpool3(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)

        return x
  • Contracting Path
  • 입력 이미지를 받아 예측한 출력을 생성하는 과정을 단계별로 정의 및 연결
  • enc1_1, enc1_2: 첫 번째 인코딩 레이어로, 입력 이미지 xenc1_1enc1_2를 통해 처리
  • pool1: 첫 번째 풀링 레이어로, enc1_2의 출력 특성 맵을 축소
  • enc5_1: 마지막 인코딩 레이어로, 더 이상 풀링하지 않음

  • Expanding Path
  • dec5_1: 첫 번째 디코딩 레이어로, 인코딩의 마지막 레이어 enc5_1을 처리
  • unpool4: pool4에 대응되는(이미지상 반대편) 네 번째 업샘플링 레이어로, dec5_1의 출력을 두 배로 확대
  • cat4: 인코딩 경로의 네 번째 레이어(enc4_2)와 디코딩 경로의 출력(unpool4)을 결합
  • torch.cat : 텐서를 연결(concatenate)하는 함수로, Expanding Path에서 unpool된 특성 맵과 Contracting Path의 특성 맵을 연결해줌
  • self.fc: 마지막 컨볼루션 레이어로, 64채널의 특성 맵을 2채널(클래스 수)에 해당하는 출력으로 변환



Crop된 부분과 Unpool된 부분을 결합하는 이유

  • Contracting Path에서 이미지가 축소될 때, 세부 정보가 손실될 가능성이 존재한다. 따라서 Expanding Path에서 Contracting Path의 특성 맵을 결합함으로써 세부 정보를 어느정도 보존하고, 좀 더 정교한 예측이 가능해진다.
  • Image Segmentation을 할 때, 경계 정보가 중요하다. Contracting Path의 특성 맵을 결합함으로써 경계 정보가 풍부해지고 더 정확한 Segmentation 결과를 얻을 수 있다.
profile
올 때 메로나🍧

0개의 댓글