본문 바로가기
Boostcamp AI Tech/[week 6-14] LEVEL2

[week 7 - day 1,2] Transformer

by newnu 2021. 9. 14.
반응형

Transformer

  • LSTM, GRU 기반 seq2seq 모델이 성능 개선한 모델

  • Attention is all you need

  • No more RNN or CNN modules

RNN

Forward RNN

  • 왼쪽의 정보만 포함

Backward RNN

  • 오른쪽의 정보만 포함
  • Forward RNN과는 별개의 parameter 사용

Bi-directional RNN

  • 항상 모든 단어를 포함할 수 있도록 두 모델을 병렬적으로 만들고 특정 timestep의 hidden state vector concat하여 2배의 차원을 가지는 벡터 생성

Transformer - Swlf Attention module

입력 sequence의 각 단어별로 sequence 전체 내용을 반영한 encoding vector가 output

  • 각 input vector가 seq2seq 모델에서 decoder의 hidden state vector처럼 사용

  • 모든 input vector와 내적 -> 유사도 ->softmax로 얻어진 가중치 -> 입력 벡터에 대한 가중 평균 -> i에 대한 encoding vector

  • 같은 입력 벡터가 각각 다른 재료로 사용

    • query : 어떤 벡터를 선별적으로 가져올지 기준이 되는 벡터 (seq2seq 모델에서 decoder의 hidden state vector 역할)
    • key : query 벡터와 내적이 되는 재료 벡터 - 어느 key 벡터가 높은 유사도를 가지고 있는지 결정
    • value : 유사도를 구한후 softmax를 위한 후 가중평균이 계산되는 재료 벡터
  • 각 역할에 따라 다른 형태로 변환할 수 있도록 하는 linear tranformation matrix 따로 정의

EX) "I go home" 이라는 입력 sequence에서 I 의 예를 보면

  • "I" 에 해당하는 입력 벡터가 Wq matrix에 의해 query 벡터로 변환
  • Wk, Wv 에 의해 각각 key, value 벡터로 변환
    • key ,value 벡터의 개수는 동일해야 한다
  • q, k 내적값 구한 후 softmax 통과시켜 합이 1인 값 얻음
    -> query의 벡터와 다른 key와의 내적값보다 작을 수도 있게 됨
  • 가중치는 value vector에 부여되어 가중 평균 계산 -> I 에 sequence 전체의 벡터들을 고려해 얻은 벡터

Transformer : Scaled Dot-Product Attention

  • Inputs : query 벡터, (key,value) 벡터의 쌍
  • q와 k는 같은 dimension (내적 연산하기 위해)
  • value의 dimension은 같지 않아도 됨 (상수배해서 가중 평균 내는 것)
  • softmax(QKT)를 해줄 때 Q * KT를 상수 sqrt(dk)로 나누어 준 후 softmax 적용한다
    • 차원이 크면 분산은 각 분산의 합으로 분산이 커지게 된다
    • 분산이 클수록 softmax의 확률분포가 큰 값에 몰리는 현상 -> back propagation 시 gradient vanishing 발생 위험
    • 분산이 작으면 비교적 고르게 분포
    • sqrt(dk)로 나눠주어 분산을 1로 유지하여 학습 안정화

Transformer : Multi-head Attention

  • self attention 모듈을 유연하게 확장

  • 동일한 q,k,v 에 대해서 병렬적으로 동시에 attention 여러 버전 수행

  • 여러 버전의 q,k,v 존재

  • 각 i 번쩨 head별로 각각의 Wq, Wk, Wv 존재

  • 서로 다른 버전의 attention 개수만큼 동일한 쿼리 벡터에 대한 인코딩벡터 concat

  • ex) 특정한 query word에 대해 서로 다른 기준으로 여러 차원에서 특정 정보 뽑아야하는 경우

    • "I went to the school" , "I studied hard", "I came back home", "I took the rest"
    • I 라는 주체가 한 행동 중심 : went, study, came back, took
    • I 라는 주체가 존재하는 장소의 변화 : school, home
    • 서로 다른 정보 병렬적으로 뽑고 합치기
  • 각각의 head에서 얻어진 정보를 concat

  • linear transformation 통해 원하는 dimension으로 최종 vector 얻음

attention 모듈의 계산량

forward propagation시 모든 정보 저장해야 back propagation시 사용 가능
Q_KT
Self attention : (n_d) * (d_n) -> n^2 * d
행렬 연산의 병렬화로 한번에 계산 가능
RNN : timestep 개수 n , 각 timestep에서 계산되는 양 (d*d) (d: dimension). --> n_ d^2

d는 hyperparameter로 임의 지정가능
n은 seq 길이(주어진 입력에 따라 가변적) 이므로 n에 의해 영향을 받음

Complecxity per Layer 메모리 요구량

Self attention : O(n^2 * d)
RNN : O(n * d^2)

Sequential Operations

Self attention : O(1) 동시에 수행 가능
RNN : O(n) 각 timestep까지 계산 하고 다음 계산 진행

Maximum Path Length (long term dependancy와 관련)

Self attention : O(1) 동일한 K,V로 보기 떄문에 한번에 가져올 수 있음
RNN : O(n) timestep 차이만큼 layer 통과 (맨 뒤와 맨 앞 차이 n)

Transformer : Block - Based Model

스크린샷 2021-09-15 오후 4 22 31 - multi head attention을 핵심 모듈로 하여 추가적인 후처리 진행하여 하나의 모듈 구성 - Multihead attention - Add & Norm - Feed Forward - Add & Norm - Feed Forward : Fully connected layer 통과하여 각 word의 인코딩 벡터 변환

ADD - Residual Connection

  • 깊은 layer의 neuronnet을 만들때 grandient vanishing 문제 해결 학습 안정화, 레이어를 쌓아감에 따라 높은 성능

  • 각 input vector 와 multi head attention의 output vector의 합 - input vector에 대한 인코딩 벡터 얻음

  • 입력값 대비 만들고자 하는 벡터의 차이값만을 multi head attention 모듈에서 만들어줘야함

    • gradient vanishing 해결, 학습 안정화
  • residual connection 적용하기 위해 입력벡터와 encoding 출력벡터 dimesion 동일하도록 유지

    Layer Normalization

  • Normalization : batch norm, layer norm, instance norm, group norm

  • 주어진 다수의 샘플들에 대해서 값들의 평균을 0, 분산을 1로 만들어 준 후 원하는 평균과 분산을 주입할 수 있도록 하는 선형변환

  • ex) 2x+3 -> 각각의 값들이 x에 들어가서 평균이 3, 분산 2^2=4가 된다

    • 두 변수는 파라미터 , 최적화된 평균과 분산을 가지도록 조절
반응형

'Boostcamp AI Tech > [week 6-14] LEVEL2' 카테고리의 다른 글

[Week 10] PStage 과정 정리  (0) 2021.10.06
[Week 10] SentencePieceTokenizer  (0) 2021.10.05
[Week 9] F1 Score , Stratified K Fold  (0) 2021.10.01
[Week 6 - Day 3 ] seq2seq  (0) 2021.09.09
[Week 6 - Day 2] RNN, LSTM, GRU  (0) 2021.09.07