real-ESRGAN으로 사진 해상도 올리기(데이콘 SR 대회 우승자 방식으로)

Heesu Ahn·2022년 12월 22일
1

Computer Vision

목록 보기
2/2

저번 포스트에 이어 SR(Super Resolution) task를 위해 real-ESRGAN을 사용하는 방법을 소개하려고 한다. 이 모델은 basicsr 과 realesrgan 라이브러리를 사용하여 학습하기 때문에 기존 pytorch기반의 베이스라인 주피터노트북 파일로 진행하지 않고, 프로그램 작동하듯이 파이썬 파일로 진행하게 된다.

데이콘 대회 : AI양재허브 인공지능 오픈소스 경진대회

1등팀 코드 : [Private 1위 25.00327] RRDBNet + Geometric Self-Ensemble

--실험 환경--
Main
OS : Ubuntu 20.04.5 LTS
CPU : Intel(R) Xeon(R) Gold 6230R CPU @ 2.10GHz
GPU : A5000 * 4

Sub
Colab Pro+
(아마도 학습을 main에서 돌리고, 결과확인을 Colab으로 한 듯하다.)

우선 학습 환경은 최소 V100그래픽카드는 필요하다. 1등팀 환경은 A5000 x 4였으니 24GB x 4 = 96GB GPU 메모리 환경에서 905000iter 이상을 돌렸다. V100(32GB)에서 20만 iter를 돌리는데 3일이 걸렸으니 학습을 마치려면 A5000 x 4에서도 최소 4일이상 걸린다는 계산이 나온다.

1. 기본 setting

  git clone https://github.com/Jinwoo1126/Dacon_SR.git
  cd Dacon_SR

터미널 환경에서 git을 clone해온다. 더 하드코어하게 real-ESRGAN을 밑바닥에서 부터 컨트롤해보고 싶다면 여기 real-ESRGAN github를 참조바란다. 1등팀도 여기를 기준으로 실험을 설계해 나갔다.

# Install basicsr - https://github.com/xinntao/BasicSR
# We use BasicSR for both training and inference
!pip install basicsr
# facexlib and gfpgan are for face enhancement
!pip install facexlib
!pip install gdown
!pip install gfpgan
!pip install -r requirements.txt
!python setup.py develop

필요한 라이브러리와 실험환경을 설정한다.

  • basicsr : SR에 관련된 딥러닝 모델들이 라이브러리로 정리되어서 편하게 쓸 수 있다.
  • facexlib, gfpan : 얼굴 관련 화질 복원 라이브러리이다. realesrgan라이브러리는 setup.py에서 설치할 수 있도록 설정되어있는데 이것까지 전부 Xintao(Researcher at Tencent ARC Lab)님의 작품이다.
  • gdown : 구글드라이브 공유로 파일을 다운 받을 수 있는 라이브러리
## can download .pth file manually at https://drive.google.com/file/d/1piw_MOIE5bTH3-o9rmWqp3uIZoYcc5Wl/view
gdown 1piw_MOIE5bTH3-o9rmWqp3uIZoYcc5Wl    #get pth file 
mv net_g_905000.pth weights                #move pth file to weights folder

pretrained model을 다운 받는 방법이다. 여기서는 이 파일을 학습으로 만들어 내야한다. 905000은 돌린 iteration 숫자이다.

## can download compressed img file manually at https://drive.google.com/file/d/17Ui8Pc6NiPTd6dBsZJr8XS4l9S9F-wyL/view
gdown 17Ui8Pc6NiPTd6dBsZJr8XS4l9S9F-wyL    #get compressed img file(.zip)
mv open.zip inputs                         #move zip file to inpus folder
unzip inputs/open.zip -d inputs/           #unzip zip file

학습시킬 이미지 데이터를 다운 받는 방법인데, clone한 레포에 train이미지들만 수동으로 넣어줘도 된다.
폴더 구조는 다음과 같다.

  ./Dacon_SR
├── ...
├── inputs
│   ├── train # 이미지 데이터 넣어줘야 함
|   |   ├── hr
│   │   |   ├── 0000.png
│   │   |   ├── 0001.png
│   │   |   ├── 0002.png
|   |   |    ...
│   │   ├── lr
|   |       ├── 0000.png
|   |       ├── 0001.png
|   |       ├── 0002.png
|   |        ...
│   ├── test # 이미 있음
│       ├── 20000.png
│       ├── 20001.png
│       ├── ...
├── options
│   ├── ...
│   ├── finetune_realesrnet_x4plus_pairdata.yml # argparser 고칠 yaml 파일
├── realesrgan
│   ├── ...
│   ├── train.py
├── ...
├── results # inference한 파일이 저장됨, inference 시 생성
├── scripts
├── submission # results의 파일들 zip파일로 묶음
├── weights
│   ├── net_g_905000.pth # 다운받아서 넣어줘야 함
├── ...
├── inference_rrdbnetrot.py # inference 실행 파일

추가적으로 real-ESRGAN 제작자가 pretrain시켜놓은 모델이 있는데 다운받아서 적용하는걸 추천했다.

  wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights

Download pre-trained models: RealESRGAN_x4plus.pth

weights 폴더나 experiments/pretrained_models/를 만들어서 넣어놓으면 된다.

2. 학습 환경 설정

라이브러리를 불러와서 학습하는 방법이기 때문에 yaml파일을 수정해서 학습시키게된다. ./options/finetune_realesrnet_x4plus_pairdata.yml파일을 수정한다.

  # general settings
name: finetune_RealESRNetx4plus_400k_pairdata
model_type: RealESRNetModel
scale: 4 # 4배로 화질 상승!
num_gpu: auto
manual_seed: 0

# USM the ground-truth 
# 원본 사진 사용 여부이다.
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False

high_order_degradation: False # do not use the high-order degradation generation process

# dataset and data loader settings
datasets:

  train: # the 1st test dataset
    name: Dacon_train
    type: PairedImageDataset
    # 만약 학습 이미지 데이터 경로를 바꾸고 싶으면 여기서 수정한다.
    dataroot_gt: ./inputs/train/hr 
    dataroot_lq: ./inputs/train/lr
    io_backend:
      type: disk 
  	  # 파일 이름을 메모리에 올리는 방식을 선택하게 된다. 
      # 자세한건 ./realesrgan/data/realesrgan_paired_dataset.py파일 참조

    gt_size: 512 # lr의 이미지 사이즈
    use_hflip: true # Augmentation : hflip 수행
    use_rot: true # Augmentation : rotation 수행

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 4 
    # 가끔 num_worker때문에 오류나는 가상환경들이 있는데 0으로 설정하면 오류 해결
    batch_size_per_gpu: 4
    dataset_enlarge_ratio: 1
    prefetch_mode: ~

# network structures
network_g:
  type: RRDBNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64 # feature 갯수 높일수록 성능 향상 기대
  num_block: 23
  num_grow_ch: 32

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3
  num_feat: 64
  skip_connection: True

# path
path:
  # use the pre-trained Real-ESRNet model
  # 제작자의 pre-trained 모델을 쓰려면 첫번째꺼를 쓰고 
  # 아무것도 없이 학습한다면 pretrain_network_g: 뒤를 비워주자. 
  #pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth
  pretrain_network_g: ./weights/net_g_905000.pth
  param_key_g: params_ema
  strict_load_g: false
  resume_state: ~
  
# training settings
train:
  ema_decay: 0.999
  # Optimizer (Adam Optimizer 사용, Learning rate = 0.0004, beta값은 [0.9, 0.99])
  optim_g:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  # Scheduler (200,000iter마다 gamma 값이 0.5씩 줄어들도록 decay 설정)
  scheduler:
    type: MultiStepLR
    milestones: [200000]
    gamma: 0.5

  total_iter: 999999999
  warmup_iter: -1  # no warm up
  
  # losses
  pixel_opt:
    type: MSELoss
    loss_weight: 1.0
    reduction: mean

# validation settings
val:
  val_freq: !!float 5e3 # 5000iter에 한번씩 validation한다.
  save_img: false
  pbar: False

  metrics:
    psnr:
      type: calculate_psnr # 평가방식 : PSNR (Peak Signal-to-Noise Ratio)
      crop_border: 4
      test_y_channel: true
      better: higher  # the higher, the better. Default: higher
    ssim:
      type: calculate_ssim
      crop_border: 4
      test_y_channel: true
      better: higher  # the higher, the better. Default: higher

# logging settings
logger:
  print_freq: 100
  # 100iter에 한번 진행상황이 프린트 된다. 
  save_checkpoint_freq: !!float 5e3
  # 모델 저장 주기를 수정할 수 있다. 
  # 모델은 ./experiments/[your experiments name(default='finetune_RealESRNetx4plus_400k_pairdata')]/models에 
  # 저장된다.
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500

이 블록 안에도 작성하였지만 실험 상황을 수정할 방법은 다음과 같다.

Model hyper-parameters 조정

  • num_feat : network 전체(Generator 또는 Discriminator)에서 통용되는 feature map의 개수. 개수가 높아질수록 근소하게 성능은 좋아지지만 학습속도가 1.4배 떨어진다.
  • num_grow_ch : dense block 내부에서 증가되는 channel의 개수. num_feat과 마찬가지 결과.

Augmentation : hflip, rotation

  • use_hflip : 상하 반전(수평 뒤집기)
  • use_rot : 이미지 회전
    psnr에 큰 영향을 줌

기타

  • optimizer 의 lr
  • scheduler 조정 - 조정 주기 또는 scheduler 종류 등
  • pre-trained모델 변경 - realESRGAN 깃허브에는 다른 종류의 pre-trained모델도 있다.

3. 학습 시작

python realesrgan/train.py -opt options/finetune_realesrnet_x4plus_pairdata.yml

파일 경로만 다 맞고, 라이브러리가 잘 설치되었다면 잘 돌아갈 것이다. wandb를 설정하면 실시간으로 성능도 확인 가능하다.

모델은 ./experiments/[your experiments name(default='finetune_RealESRNetx4plus_400k_pairdata')]/models에 저장되므로 colab 이나 다른 컴퓨터로 계속 상황을 보면서 잘된 모델을 weights
폴더로 옮겨주면 된다.

이후 inference 하는 과정은 이전 포스트에 적어놨으니 참고바라고 좋은 화질의 사진을 얻길 바란다!!

profile
AI로 데이터를 말하는 개발자

2개의 댓글

comment-user-thumbnail
2023년 10월 27일

큰 도움이 됐습니다 :) 감사합니다 ! 혹시 성능 확인은 어떻게 해야 할까요? wandb 사용하지 않구요 !

1개의 답글