Implementing LeNet5

Yelim Kim · July 18, 2023
Machine_Learning


Architecture Table
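
The per-layer output sizes can be re-derived from the layer hyperparameters with the usual formula `out = (in + 2*padding - kernel) // stride + 1`. The sketch below is only an illustration: the helper names `conv_out` and `trace_shapes` are made up here, and the 32x32 grayscale input is an assumption, inferred from the `(-1, 120, 1, 1)` shape noted for `conv3` in the code.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Side length after a conv or pooling layer applied to a square input."""
    return (size + 2 * padding - kernel) // stride + 1


# (layer name, output channels, kernel size, stride), copied from the LeNet5 class.
# Pooling layers keep the channel count of the preceding convolution.
LAYERS = [
    ("conv1", 6, 5, 1),
    ("conv1_pool", 6, 2, 2),
    ("conv2", 16, 5, 1),
    ("conv2_pool", 16, 2, 2),
    ("conv3", 120, 5, 1),
]


def trace_shapes(input_size=32):
    """Return (name, channels, height, width) for each layer's output."""
    size = input_size  # assumed input: 1 x input_size x input_size
    rows = []
    for name, channels, kernel, stride in LAYERS:
        size = conv_out(size, kernel, stride)
        rows.append((name, channels, size, size))
    return rows


for name, c, h, w in trace_shapes():
    print(f"{name:10s} -> ({c}, {h}, {w})")
```

With a 32x32 input the spatial size goes 28, 14, 10, 5, 1, so `conv3` yields 120 values per image, which matches `in_features=120` of `fc1`. With a 28x28 input the formula gives 0 at `conv3` (the 5x5 kernel no longer fits a 4x4 map), so 28x28 images such as MNIST would need to be padded or resized to 32x32 first.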

Code

import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv1_act = nn.Tanh()
        self.conv1_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.conv2_act = nn.Tanh()
        self.conv2_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.conv3_act = nn.Tanh()

        self.fc1 = nn.Linear(in_features=120, out_features=84)
        self.fc1_act = nn.Tanh()

        self.fc2 = nn.Linear(in_features=84, out_features=10)
        self.softmax = nn.Softmax(dim=1)  # normalize over the class axis: the 10 in (-1, 10)

    def forward(self, x):
        x = self.conv1_act(self.conv1(x))
        x = self.conv1_pool(x)

        x = self.conv2_act(self.conv2(x))
        x = self.conv2_pool(x)

        x = self.conv3_act(self.conv3(x))  # (-1, 120, 1, 1): cannot be passed to fc1 as-is

        x = x.view(x.shape[0], -1)  # view works like reshape here: flatten to (B, 120)
        x = self.fc1_act(self.fc1(x))
        x = self.fc2(x)

        x = self.softmax(x)
        return x
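
As a rough sanity check on the model size, the number of learnable parameters can be counted by hand from the layer definitions above. This is a sketch, not part of the original code: the helpers `conv_params` and `linear_params` are hypothetical names, and it assumes the PyTorch defaults (`bias=True` for `nn.Conv2d` and `nn.Linear`). `Tanh`, `AvgPool2d`, and `Softmax` have no parameters.

```python
def conv_params(in_channels, out_channels, kernel_size):
    """Weights (out * in * k * k) plus one bias per output channel."""
    return out_channels * in_channels * kernel_size * kernel_size + out_channels


def linear_params(in_features, out_features):
    """Weights (in * out) plus one bias per output feature."""
    return in_features * out_features + out_features


# Layer arguments copied from LeNet5.__init__
counts = {
    "conv1": conv_params(1, 6, 5),
    "conv2": conv_params(6, 16, 5),
    "conv3": conv_params(16, 120, 5),
    "fc1": linear_params(120, 84),
    "fc2": linear_params(84, 10),
}
total = sum(counts.values())

for name, n in counts.items():
    print(f"{name:6s} {n:>6d}")
print(f"total  {total:>6d}")
```

If PyTorch is available, the total should agree with `sum(p.numel() for p in LeNet5().parameters())`.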
