기존의 vanilla RNN이 지니던 gradient vanishing 또는 exploding 문제를 해결하고, time step이 먼 경우에도 필요로 하는 정보를 보다 효과적으로 처리하고 학습할 수 있도록 개선한 모델이다.
매 time step마다 변화하는 hidden state vector를 단기 기억을 담당하는 기억 소자로 볼 수 있는데, time step이 진행됨에 따라 단기 기억을 보다 길게 기억하도록 개선한 모델이라는 의미에서 'Long Short-Term Memory'라는 이름을 붙였다.
LSTM의 특징
기존 vanilla RNN에서는 현재 time step에서의 hidden state를 반영할 때 해당 time step에서의 입력과 전 time step에서 오는 hidden state를 입력으로 받았다.
그러나 LSTM에서는 이전 time step에서 두 개의 서로 다른 역할을 하는 입력과 현재 time step에서의 입력을 사용한다.
: cell state
: hidden state
두 state의 의미를 부여하자면 cell state인 가 전반적인 sequence에 관해 좀 더 완전한 정보를 가지고 있다고 볼 수 있고, hidden state인 는 를 한 번 더 가공해서 해당 time step에서 노출할 필요가 있는 정보만을 필터링해서 갖고 있다고 볼 수 있다.
이전에 포스팅했던 글에서는 각 gate별로 state를 도출하는 식을 작성했는데, 이를 통합적으로 바라보면 결국 하나의 파라미터 matrix인 를 가지고 각 gate에 해당되는 부분 행렬을 학습시키는 것으로 볼 수 있다.
위의 그림처럼 LSTM에서 학습되는 파라미터 가 존재할 때, 이를 각 gate별로 업데이트 해서 gate에서의 vector를 구하는 것으로 볼 수 있다.
Gate별 특징과 함께 후술하겠지만 , 를 이은 벡터와 파라미터 를 곱한 결과에 또는 를 취한 결과를 cell state인 또는 hidden state인 와 element-wise multiplication하는데, 이는 계산 결과를 최종적으로 얼만큼의 비율로 반영해줄지를 계산하는 과정으로 볼 수 있다.
또한 는 아직 로 나오기 전의 현재 시간에서의 임시 cell state로 이해하면 된다.
예를 들어, 어떤 한 gate에서 sigmoid를 가지고 구한 벡터의 한 원소 값이 0.3이고 이와 대응되는 cell state의 원소 값이 3이면, 두 원소를 곱하는 경우 cell state의 해당 원소에서 30%만 남겨서 0.9를 만든다고 볼 수 있다.
이처럼 LSTM에서의 gate는 cell state인 에서 얼만큼의 데이터를 가지도록 할지를 조절하는 역할을 한다.
주목해야 할 점은 현재 time step에서의 hidden state인 를 업데이트 하는 과정이 LSTM과 닮아있다는 것이다.
앞서 LSTM에서는 현재 time step의 cell state인 를 구할 때, 이전 cell state인 에 forget gate를 통과한 결과를 곱하고, 임시 현재 cell state인 에 input gate를 통과한 결과를 곱해서 더하는 과정을 거친다.
GRU에서는 update gate의 결과인 만 가지고 임시 hidden state인 에 그대로 곱하고 이전 hidden state인 에는 마치 forget gate를 적용하는 것처럼 를 곱해서 이를 더하는 식을 사용하는데, 이는 와 의 가중 평균을 구하는 것으로 볼 수 있다.
LSTM에서는 input gate와 forget gate의 독립적인 두 개의 gate 결과를 가지고 cell state를 업데이트했다면, GRU에서는 하나의 gate에서 hidden state를 연산하는 것을 볼 수 있다.
이로 인해 구조적으로 GRU는 LSTM에 비해 경량화된 모델로 볼 수 있는 것이다.
LSTM과 GRU의 Backpropagation
정보를 담는 주된 벡터인 LSTM의 cell state 또는 GRU에서의 hidden state를 업데이트 되는 과정이 기존 vanilla RNN처럼 동일한 를 계속 곱하는 연산이 아니라 매 time step마다 값이 다른 forget gate를 곱하고, 필요로 하는 정보를 곱셈 뿐만이 아니라 덧셈을 통해서 만들어 낼 수 있다는 특징으로 인해 gradient vanishing 또는 exploding 문제가 많이 사라지는 것으로 알려져있다.
기본적으로 덧셈 연산은 backpropagation을 수행할 때 gradient를 복사해주는 것처럼 작동하여 멀리 있는 time step에 관해서도 gradient를 큰 변형 없이 전달해줄 수 있어서 long term dependency 문제를 해결할 수 있다.
출처 1. 네이버 커넥트재단 부스트캠프 AI Tech NLP Track 주재걸 교수님 기초 강의