화요일, 9월 10, 2024
HomeDLRNN을 구현해보자!

RNN을 구현해보자!

RNN 이란?

CNN이나 Fc layer 신경망은 보통 feed forward라고해서 앞먹임. 즉 흐름이 앞으로 단방향인 신경망입니다.

피드포워드 신경망은 구조를 이해하기 쉽고, 다양하게 응용되지만 시계열 데이터에서는 젬병입니다.

그래서 순환신경망인 RNN이 나왔습니다.

피드포워드 구조의 문제점과 그걸 어떻게 해결했는지를 알아보면서 RNN을 공부해봅시다.

언어모델은 단어 나열에 확률을 부여합니다.
you say 다음에는 “goodbye”가 올 확률은 높게 “good die”는 낮게 출력하는 것이 언어모델입니다.

W1 ~ Wm 라는 m개의 단어로 된 문장이 있을 때 이때 단어가 W1 ~ Wm 라는 순서로 출현할 확률을 P(W1, …, Wm)으로 표현합니다.
이것을 수식화 할 수 있습니다.

식의 기호는

파이라는 것으로 총곱을 뜻합니다. 잘아시는 시그마는 총합이니까. 파이도 이해 가시죠?

BPTT(Backpropagation Through Time)

위와 같이 순환된 구조에서 backpropagation으로 학습하려면 문제점이 있습니다.

바로 데이터가 커질수록 메모리가 매우 커야한다는 것입니다.

그래서 위와같은 구조를 BPTT 시간방향으로의 오차역전파법 이라합니다.

이러한 문제를 해결하기 위해 Truncated BPTT가 나오게 됩니다.

Truncated BPTT

데이터가 큰 시계열일 경우 자릅니다. 그래서 Truncated 입니다.

위와 같이 잘라서 학습합니다. 마치 미니배치학습과 비슷합니다.

하지만 데이터를 이렇게 자르면 미니배치와는 다르게, 순서는 정확하게 주어져야 합니다. 달랐지만 통으로 1개의 데이터이기 때문이겠죠?

이를 바탕으로 구현에 적용해봅시다.

RNN 구현

Truncated BPTT의 원리대로 T개의 시계열 데이터를 한번에 처리하도록 구현해야 합니다.

먼저 기본적인 RNN을 구현합니다.

Python
class RNN:
    def __init__(self, Wx, Wh, b):
        self.params = [Wx, Wh, b] #가중치 2개와 편향 1개 초기화
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.cache = None  # 역전파에 사용할 중간 데이터

    def forward(self, x, h_prev):
        Wx, Wh, b = self.params
        t = np.matmul(h_prev, Wh) + np.matmul(x, Wx) + b
        h_next = np.tanh(t)

        self.cache = (x, h_prev, h_next)
        return h_next
        
    def backward(self, dh_next):
        Wx, Wh, b = self.params
        x, h_prev, h_next = self.cache

        dt = dh_next * (1 - h_next ** 2)  # tanh 미분
        db = np.sum(dt, axis=0)
        dWh = np.dot(h_prev.T, dt)  # shape: (H, N) x (N, H) = (H, H)
        dh_prev = np.dot(dt, Wh.T)  # shape: (N, H) x (H, H) = (N, H)
        dWx = np.dot(x.T, dt)  # shape: (D, N) x (N, H) = (D, H)
        dx = np.dot(dt, Wx.T)  # shape: (N, H) x (H, D) = (N, D)

        self.grads[0][...] = dWx
        self.grads[1][...] = dWh
        self.grads[2][...] = db
      

        return dx, dh_prev

tanh 함수의 미분

ShellScript
dt = dh_next * (1 - h_next ** 2)  # tanh 미분

tanh 함수의 미분값은 아래와 같습니다.

그렇게 때문에 이전 노드의 미분값(dh_next) 과 체인룰로 곱해줘서 위와 같은 값이 나옵니다.

Time RNN 계층구현

본격적인 구현입니다. RNN계층이 T개 만큼 있는 것이 Time RNN입니다.

계층이 StateFul 상태인지 state None 상태인지가 중요한 키포인트입니다.

Python
class TimeRNN:
    def __init__(self, Wx, Wh, b, stateful=False):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.layers = None  # RNN 계층을 리스트로 저장
        
        self.h, self.dh = None, None
        self.stateful = stateful
        
    def set_state(self, h):
        '''hidden state(h)를 설정하는 메서드'''
        self.h = h
    
    def reset_state(self):
        '''hidden state(h)를 초기화하는 메서드'''
        self.h = None
        
    def forward(self, xs):
        Wx, Wh, b = self.params
        N, T, D = xs.shape  # N(batch), T(time steps), D(input size)
        D, H = Wx.shape
        
        self.layers = []
        hs = np.empty((N, T, H), dtype='f')
        
        if not self.stateful or self.h is None:
            self.h = np.zeros((N, H), dtype='f')
            
        for t in range(T):
            layer = RNN(*self.params)
            self.h = layer.forward(xs[:, t, :], self.h)
            hs[:, t, :] = self.h
            self.layers.append(layer)
            
        return hs
    
    def backward(self, dhs):
        Wx, Wh, b = self.params
        N, T, H = dhs.shape
        D, H = Wx.shape
        
        dxs = np.empty((N, T, D), dtype='f')
        dh = 0
        grads = [0, 0, 0]
        for t in reversed(range(T)):
            layer = self.layers[t]
            dx, dh = layer.backward(dhs[:, t, :] + dh)  # 합산된 기울기
            dxs[:, t, :] = dx
            
            for i, grad in enumerate(layer.grads):
                grads[i] += grad
                
        for i, grad in enumerate(grads):
            self.grads[i][...] = grad
        self.dh = dh
        
        return dxs

다음 포스팅에서는 RNN을 이용해 언어모델을 만드는 RNNLM(language model)을 구현해보겠습니다.

RELATED ARTICLES

Leave a reply

Please enter your comment!
Please enter your name here

Most Popular

Recent Comments