본문 바로가기
AI · 인공지능/알기쉬운 AI

[알기쉬운 AI - 30] LSTM (Long Short-Term Memory)

by 두우우부 2020. 5. 29.
반응형

 

시작하기

이전에는 시계열로 가변 데이터를 처리할 수 있는 리커런트 뉴럴 네트워크(Recurrent Neural Network)에 대해 알아봤습니다. RNN이 긴 데이터를 처리할 때, 계산량이 폭발하기 때문에 기억 중단형 통시적 역전파(Truncated  Back propagation Through time)는 단기기억 정보만을 처리합니다. 그러나 역시, 이 경우에는 AI의 정확도에 한계가 있습니다. 그래서 이 단순 RNN(Simple Recurrent NN)의 장기 의존성 문제를 해결하는 구조를 가진 장 · 단기 기억 장치(Long Short-Term Memory)라는 모델이 나타났고, 이것이 현재 RNN의 주류를 이루고 있습니다. 

 

단순 RNN의 장기 의존성 문제

LSTM의 설명에 들어가기 전에 간단한 리커런트 신경망의 장기 의존성 문제에 대하여 복습해 봅시다. 단순 RNN은 초능력자가 아니기 때문에, "나의", "영희가", "사랑해 마지않는", "케이크는" 이라는 문장의 마지막 단어를 맞출 수 없습니다. 하지만 기억 범위가 더 넓어져서 "영희가 순자와 너무나 행복한 표정으로 생크림 케이크를 먹는 모습을 봤습니다"라는 글의 바로 뒤에, "영희가", "사랑해 마지않는", "케이크"가 오면 어떨까요. 이번에는 '생크림'이라는 말이 바로 떠오르네요.

 

우리 인간은 이처럼 바로 직전의 정보뿐만 아니라 필요에 따라서 더 이전의 문장 정보도 이용하고 있는 것입니다. 이러한 장기 기억이 RNN에서도 구조적으로 불가능한 것은 아닙니다. 그러나 "영희가", "순자", "매우", "행복해", "생크림", "먹고 있다", "모습", "봤습니다"처럼 정보량이 많아지게 되면, 어떤 가중치로 어떻게 연관이 될지 상당히 복잡해집니다. 수십 단계라면 대응할 수 있지만, 100단계 이상 된다면 계산량이 폭발해 버립니다. 이것이 단순 RNN의 장기 의존성 문제입니다. 

 

이 문제를 해결하기 위해 등장한 것이 LSTM입니다(Long Short-Term Memory). 그림 1과 같이 단순 RNN이 단기 기억만 이용하는 반면, LSTM은 장기 의존(long-term dependencies)을 학습할 수 있도록 개량한 모델입니다. 단순 RNN은 계산량이 폭발하는데 LSTM은 괜찮다니... 어떻게 이런 일이 가능한 것일까요?

 

그림 1 : 신경망의 기억 영역

 

RNN 구조

단순 RNN은 그림 2와 같은 리커런트(순환) 구조를 가지고 있습니다. 이전 셀의 출력(Recurerent)과 입력(Input)이 합쳐져 출력(output)이 나오는 간단한 모델이지만, 전에 설명을 생략했던 tanh이라는 묘한 녀석이 있습니다. 이것은 하이볼릭 탄젠트라는 함수로 통계에 자주 등장하는 로지스틱 시그모이드 함수입니다.

 

그림 2 : RNN의 리커런트 구조

 

하이볼릭의 의미는 '쌍곡선'입니다. 시그모이드라는 말은 이전의 '회식 참석 예시' 때 시그모이드 뉴런이란 말로 한번 등장했었지요. 퍼셉트론의 출력이 1 또는 0의 두 값인데 반해, 시그모이드 뉴런은 0~1까지의 실수 모델이었습니다. 그리고, 로지스틱이라는 말도 앞에서 영희의 '정시 귀가를 예측하는 로지스틱 회귀'에서 등장했습니다. "발생 확률을 예측하여 확률에 따라 Yes/No로 분류하는 것"이었다는 것, 기억나시나요?

 

 

[알기쉬운 AI - 26] 분류 (Classification)

이번에는 감독 학습의 '분류' 중에서 '로지스틱 회귀'와 'K근접법'에 대해서 설명합니다. 딥러닝을 빨리 배우고 싶은데, 기계 학습 알고리즘에 계속 발이 묶여있는 것 같네요... 그러나 딥러닝을 �

doooob.tistory.com

 

'[26] 분류 (Classification)'에 나왔던 로지스틱 함수의 그림을 기억해 보십시오. 시그모이드 함수도 로지스틱 함수도 같은 S자형 함수로, X의 값을 0~1의 값으로 변환하는 함수입니다. 한편, tanh는 하이볼릭(쌍곡선)이라는 말이 붙어있듯이, 0~1 대신 -1~1 사이의 값으로 변환합니다(그림 3). 

 

시그모이드와 달리 tanh의 출력은 정과 부의 값을 가질 수 있으므로 상태의 증감이 가능합니다. 따라서 tanh 셀의 반복 연결에 사용되어, 내부에서 추가되는 후보 값을 결정하는데 유용합니다. 2차 미분이 제로가 되기 전에 장기간 값을 유지할 수 있으므로 기울기 손실 문제를 해결하기 좋은 함수입니다. 2차 미분, 경사 손실... 이런 건 너무 어려우니까, 그냥, 정보를 이용하기 좋은 상태로 변환해 주는 것이라고 생각해 주세요. 그림 2로 말하면, 이전 셀로부터의 리커런트 정보(기억 정보)를 그대로 흘려보내는 것이 아니라, tanh이 요점을 잘 정리해준다는 느낌입니다.

 

그림 3 : Tanh 쌍곡선 함수

 

LSTM의 구조

이어서 (그림 4) LSTM의 리커런트 구조를 살펴봅시다. 오호, 훨씬 복잡해졌군요. 하지만 이것도 정상적인 LSTM입니다.

 

그림 4 : LSTM의 리커런트 구조

 

(1) 전 셀의 출력에 기억 라인이 추가됨(ht-1과 Ct-1)

단순 RNN에서는 하나였던 이전 셀의 정보 전달이, 출력(Recurrent) 외에 기억(Memory)이 추가되어 2라인으로 되어 있네요. 이것은, 대략적으로 말해 Recurrent 쪽이 RNN과 같은 단기 기억이고, Memory가 장기 기억이라고 생각하시면 됩니다. LSTM은 그 이름과도 같이, 단기와 장기를 연관시키면서 각각 다른 라인에서 기억을 보존하고 있는 것입니다.

 

(2) 전 셀의 출력(Recurrent)과 입력의 합류(ht-1과 Xt) 

이전 셀의 출력 ht-1(단기 기억)과 지금 셀의 입력 Xt가 합류합니다. 합류된 신호는 4개의 라인에 분기(동일 정보 복사)됩니다. 이 합류 결과는 '영희가 좋아하는'이라는 단기 기억에 입력값 '케이크'가 더해진 것입니다(이것은 이전 RNN과 같습니다). 

 

(3) 망각 게이트(ft의 출력)

가장 윗 라인은 망각 게이트입니다. 이것은 이전 셀에서의 장기 기억 하나하나에 대해 σ(시그모이드 함수)에서 나온 0~1 사이의 값 ft로 정보의 취사선택을 하는 것입니다. 1은 모든 남기고 0은 전부 버려야 합니다. 단기 기억 ht-1와 입력 Xt로 '영희가 사랑해 마지않는 케이크'까지 인식한 시점(t)에서 장기 기억 속의 '순자'는 중요하지 않다고 판단했을 때, σ의 출력 ft는 0 근처의 값이 되어, 이 기억을 망각합니다. 한편, 「생크림」이라는 정보는 중요한 것 같아서 ft는 1로 그대로 남아 있습니다.  

 

RNN이 과거의 모든 정보를 이용하려고 하면 계산량이 폭발하겠지만, 망각 게이트에 의해 원치 않는 정보를 버림으로써 폭발을 방지합니다(필요 없는 정보를 계속 잊어버리는 것은 인간과 동일하군요). 

 

※ 3개의 게이트와 시그모이드 함수 σ
LSTM은 망각 게이트(forget gate)와 입력 게이트(input gate), 출력 게이트(output gate)의 3개의 게이트가 있습니다. 게이트라고 하면 자신의 신호의 출입구가 있는 이미지가 떠오르시겠지만 여기에서는 조금 다릅니다. 시그모이드 함수 σ에 의해 흘러나오는 신호 게이트의 개폐를 실시하고 있는 제어문입니다. 1은 열림, 0이 닫힘, 0.5는 반열림의 게이트 신호 가중치 컨트롤을 하고 있는 것입니다.

 

(4) 입력 게이트(Ct'와 it)

단기 기억 ht-1과 입력 Xt로 합산된 입력 데이터를 장기 보존용으로 변환한 후 어떤 신호를 어느 정도의 무게로 장기 기억에 저장할지 제어합니다. 이것은 두 단계로 처리됩니다. 

 

① tanh에 의한 변환(Ct'를 출력)

들어온 정보를 그대로 흘려보내는 것이 아니라 요점을 맞춘 단적인 형태로 만드는 쪽이 정보량을 줄일 수 있고, 사용하기 좋습니다. 아까 tanh은 내부에서 추가되는 후보 값을 결정하는 데 유용한 함수라고 했습니다. 예를 들어, '사랑해 마지않는'을 '좋아하는'이라는 후보로 바꾸는 식입니다. 이렇게 간단히 변환되어 Ct'가 출력됩니다.  

 

② 입력 게이트(it)에 의한 선별

지난번 언급했듯, LSTM은 통시적 오차 역전파(Back propagation Through time)에 의해 가중치를 조정합니다. 보통의 오차 역전파는 입력 Xt의 weight 조절이지만, 통시적 오차 역전파는 이외에도 이전 셀에서의 단기 기억 ht-1의 정보에 영향을 받습니다. 따라서 ht-1에서 들어오는 무관한 정보에 의해 가중치가 잘못 업데이트되는 것을 방지하기 위해 입력 게이트가 필요한 오차 신호만 제대로 전달하도록 제어하고 있습니다.  

 

ht-1 + Xt로 만들어진 '영희가 사랑해 마지않는 케이크'라는 정보 중에서 입력 게이트 σ(시그모이드 함수)가 남겨둘 것과 흘려보낼 것을 선별합니다.

 

(5) 출력 게이트(ot를 출력)

ht는 단기 기억의 출력입니다. 위와 같은 과정을 통해 장기 기억에 단기 기억이 더해져 선별된 값(장기 기억의 출력 Ct)에서 단기 기억에 관한 부분만 출력합니다. 여기서도 아까와 마찬가지로 2단계로 처리됩니다. 

 

① tanh에 의한 변환

tanh의 입력은 이전 셀에서의 장기 기억 Ct-1에 입력 Xt를 변환한 단기 기억 Ct'을 더한 것입니다. 각각 망각 게이트 및 입력 게이트로 취사선택되고 있습니다. 이것을 그대로 장기 기억으로 출력하는 것이 Ct이지만, 거기에 포함된 단기 기억 부분도 장기 기억과 함께 포함시킴으로 인해 단기 기억만 있을 때보다 이용하기 쉽게 변환할 수 있습니다.  

 

예를 들어, 단기 기억이 '나의 그녀가 좋아하는 케이크는'이었다고 합시다. 이 경우 장기 기억에 나의 그녀가 영희라는 중요한 요소가 있다면, 단기 기억을 좀 더 명확하게 "영희가 좋아하는 케이크"로 변환해 버리는 이미지입니다. 

 

② 단기 기억의 취사선택

입력 게이트가 스스로 셀을 보호했듯이 출력 게이트도 다음 셀에 대한 나쁜 정보의 전파를 방지합니다. 다음 셀을 활성화하기 위한 가중치 ht를 업데이트할 때, 관련 정보를 흘려 나쁜 영향을 주지 않도록 해야 합니다. 출력 게이트 σ(시그모이드 함수)에 의해 0~1의 범위에서 Ot가 출력되고, 단기 기억 출력 ht에 필요한 신호만 제대로 전달하도록 제어하고 있습니다. 

 

이번에는 입력 게이트에서 "나의"라는 말이 이미 잘렸기 때문에 출력 게이트에서는 특별히 자를 말이 없습니다. 이 입력에서 출력으로도 이중으로 게이트 체크하여 관련 정보가 흐르지 않도록 철저히 하고 있습니다. 지금까지의 처리 결과 이 셀에서 '영희가 좋아하는 케이크'라는 정보가 ht에 출력된 것입니다. 

 

정리

이번에는 리커런트 뉴럴 네트워크(단순 RNN)의 단점을 보완한, LSTM의 구조에 대해 설명했습니다. 이전 셀의 출력이 단기 기억과 장기 기억으로 나누어져 있고, 정보가 폭발하지 않도록 원치 않는 정보는 망각 게이트에서 삭제하여 불필요한 정보로 잘못된 가중치 갱신을 하지 않도록 입력 게이트와 출력 게이트에서 취사선택을 하는 것입니다. tanh에 의해 정보를 그대로 흘리는 것이 아니라 이용하기 쉬운 형태로 변환하는 것 등을 알 수 있었습니다.

 

또한, 여기에서는 이해하기 쉽도록 문장을 사용하여 설명하고 있지만, 실제로 RNN이 어떤 작업을 수행하는지는 블랙박스이며, 학습 정도에 따라 달라집니다. LSTM의 각 파트에서 이런 방식으로 처리하고 있다는 것만 이해해 주시면 OK입니다.

반응형