lstm 15. 장단기(LSTM) 메모리와 pytorch train 알아보기

행동하는 개발자·2023년 3월 28일
0

RNN

목록 보기
14/14

1. 바닐라 RNN의 한계

앞에서 바닐라 RNN은 출력 결과가 이전의 계산 결과에 의존한다는 것을 언급한 바 있습니다. 하지만 바닐라 RNN은 비교적 짧은 시퀀스에 대해서만 효과를 보이는 단점이 있습니다. 바닐라 RNN의 시점이 길어질 수록 앞의 정보가 뒤로 충분히 전달되지 못하는 현상이 발생합니다.

색이 연해질수록 첫번째 입력값인 X1의 정보량이 옅어지는 것을 확인할 수 있습니다. 이를 장기 의존성 문제라고 합니다.

2. LSTM

위의 그림은 LSTM의 전체적인 내부의 모습을 보여줍니다. 전통적인 RNN의 단점을 보완한 RNN의 일종을 LSTM이라고 합니다. 은닉층의 메모리 셀에 입력, 망각, 출력 게이트를 추가하여 불필요한 기억을 지우고 기억해야 할 것들을 정합니다. 요약하면 LSTM은 은닉상태를 계산하는 식이 전통적인 RNN보다 조금 더 복잡해졌으며 셀 상태(cell state)라는 값을 추가했습니다.

셀 상태는 위의 그림에서 왼쪽에서 오른쪽으로 가는 굵은 선입니다. 셀 상태 또한 이전에 배운 은닉 상태처럼 이전 시점의 셀 상태가 다음 시점의 셀 상태를 구하기 위한 입력으로서 사용된다.

은닉 상태의 값과 셀 상태의 값을 구하기 위해서 새로 추가된 세개의 게이트를 사용한다. 각 게이트는 삭제, 입력, 출력 게이트이고 공통적으로 시그모이드 함수가 존재한다.

3. train

RNN

import torch
import torch.nn as nn

# Define the RNN model
class RNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(RNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        embedded = self.embedding(x)
        output, hidden = self.rnn(embedded)
        out = self.fc(hidden.squeeze(0))
        out = torch.sigmoid(out)
        return out

LSTM

import torch
import torch.nn as nn

# Define the RNN model
class RNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(RNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        embedded = self.embedding(x)
        output, (hidden, cell) = self.rnn(embedded)
        out = self.fc(hidden.squeeze(0))
        out = torch.sigmoid(out)
        return out

train 하기 위한 예제 코드는 다음과 같다.

GRU

LSTM의 대안으로 개발된 더 단순한 유형의 RNN이다. LSTM과 마찬가지로 GRU는 게이트를 사용하여 정보의 흐름을 제어하지만 업데이트 게이트와 재설정 게이트라는 두 개의 게이트만 있다. 업데이트 게이트는 유지할 이전 은닉 상태의 양과 추가할 새 입력의 양을 결정하는 반면, 재설정 게이트는 잊을 이전 은닉 상태의 양을 결정한다.

GRU는 LSTM보다 더 빠르고 더 적은 매개변수를 필요로 하므로 더 쉽게 훈련하고 더 효율적으로 계산할 수 있다.

출처: https://wikidocs.net/22888

profile
끊임없이 뭔가를 남기는 사람

0개의 댓글