3. nn & nn.functional

ingsol·2023년 1월 7일
0

PyTorch

목록 보기
3/8

1. torch.nn

  • 라이브러리
  • nn.Module을 이용하여 리팩토링 하기
    *여기서 리팩토링(Refactoring)이란 '결과의 변경 없이 코드의 구조를 재조장함'을 뜻한다. 주로 가독성을 높이고 유지보수를 편하게 하기 위해 사용되며, 버그를 없애거나 새로운 기능을 추가하는 것이 아님. 목적은 소프트웨어를 보다 이해하기 쉽고 수정하기 쉽도록 만드는 것.
    1) subclass(하위클래스)를 만들어 forward 단계에 대한 가중치, 절편, 메소드 등을 유지하는 클래스를 만들 수 있음
    2) attribute(속성, 'init')과 method(.parameters(), zero_grad())를 가지고 있음.

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

  • 변수 설명
    1) in_channels: 1x1x28x28 -> 얘를 필터 3개로 찍으면 ->
    2) out_channels: 1x3xHxW(H, W: 필터 크기에 따라 결정, 필터 크기는 kernel_size & stride & padding이 결정)
    3) bias: 필터 하나에 붙어있는 bias를 몇으로 할 것이냐
    4) 특징: weight값으 직접 설정x, 나와있는 변수들을 다 채워주면 convoltion을 하기 위한 weight가 자동 설정됨 --> convolution을 할 수 있는 하나의 layer가 생성됨
import torch.nn as nn
from torch.autograd import Variable

input = torch.ones(1,1,3,3)
input = Variable(input, requires_grad=True)
func = nn.Conv2d(1,1,3) # input, ouput, kernel_size
print(func.weight)
out = func(input)
print(out)
out.backward()
print(input.grad)

2. torch.nn.functional

  • PyTorch의 nn 클래스의 장점을 활용하여 코드를 더 간결하고 유연하게 만들 수 있음

  • 활성화, 손실 함수를 torch.nn.functional의 함수로 대체

  • 함수: 인스턴스화 시킬 필요 없이 사용 가능

    torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1 groups=1)

  • 변수 설명
    1) weight: 외부에서 만든 필터를 넣어줘야함
    2) 나머지는 동일

import torch.nn.functional as F
from torch.autograd import Variable

input = torch.ones(1,1,3,3)
fileter = torch.ones(1,1,3,3)
input = Varialbe(input, requires_grad=True)
fileter = Variable(filter)
out = F.conv2d(input, filter)
out.backward()
print(out_grad_fn) #ConvNdBackward object at~>

0개의 댓글