[NIPS 2018] Memory Replay GANs: learning to generate images from new categories without forgetting(MeRGAN) 핵심 리뷰

This entry is part 7 of 22 in the series Incremental Learning

내용 요약

Conditional GAN을 사용한 replay 방식의 새로운 incremental learning 방법론을 제안합니다. GAN이 합성한 이전 Task까지의 데이터와 현재 데이터를 동시에 학습하여 새로운 GAN을 학습합니다. 여기에 이전 Task까지의 GAN과 현재의 GAN의 Align을 위한 L2 Loss를 추가해줍니다.

1. 들어가며

이번 글에서는 NIPS 2018에서 발표된 Memory Replay GANs: learning to generate images from new categories without forgetting(MeRGAN) 논문을 리뷰합니다.
이 논문은 MeRGAN이라 불리며 이번 글에서도 MeRGAN이라 지칭하겠습니다.

전체-흐름-속에서-보기
그림1. Incremental Learning 전체 흐름

Incremental Learning을 방법론에 따라 크게 구분하면 위의 그림과 같이 구분할 수 있습니다.

  • Regularization : 이전 task에서 학습한 네트워크의 파라미터가 최대한 변하지 않으면서 새로운 task를 학습하도록 유도
  • Distillation : 이전 task에서 학습한 파라미터를 새로운 task를 위한 네트워크에 distillation
  • Distillation + Memory : 이전 task의 데이터를 소량 메모리로 두고 새로운 task학습 때 활용
  • Distillation + Memory + Bias correction : 새로운 task에 대한 bias를 주요 문제로 보고, 이에 대한 개선에 집중
  • Distillation + Memory + Dynamic structure : task에 따라 가변적으로 적용할 수 있는 네트워크 구조를 사용
  • Distillation + Memory + Generative model : 이전 task의 데이터를 generative model을 사용하여 replay 하는 방식을 사용
  • Dynamic structure : Pruning / Masking 등을 사용하여 task별로 사용할 파라미터 또는 네트워크 등을 정해줌

MeRGAN은 Distillation + Memory + Generative model에 해당하는 방법 중 하나입니다.

2. 제안 방법

먼저 baseline architecture를 살펴본 뒤 저자들이 제안하는 방법을 살펴보겠습니다.

2-1. Baseline architectures

Baseline이 되는 architecture들을 살펴보겠습니다.

baseline-architectures
그림2. baseline architectures

먼저 joint training 방식입니다.
이는 모든 task의 데이터셋을 합쳐 하나의 데이터셋으로 만든 뒤 하나의 Generator / Discriminator / Classifier를 학습하는 방법입니다. 따라서 loss는 다음과 같습니다.

joint-training-loss
그림3. joint training loss

GAN loss (Generator / Discriminator)와 Classification loss를 같이 학습합니다.

다음은 sequential fine tuning 방식입니다.
이는 각각의 task별로 별도의 Generator / Discriminator를 학습시켜줍니다. 이 전 task까지 학습한 파라미터를 받아 새로운 task의 파라미터를 초기화 한 뒤 새로운 task를 학습합니다. 따라서 loss는 다음과 같습니다.

sequential-fine-tuning-loss
그림4. sequential fine tuning loss

Classifier 없이 Generator와 Discriminator만 학습합니다.

다음은 GAN with EWC방식입니다. 기존의 GAN에 regularization 방식인 EWC를 추가한 방식입니다. 따라서 loss는 다음과 같습니다.

GAN-with-EWC-loss
그림5. GAN with EWC loss

2-2. Memory replay generative adversarial networks

다음은 저자들이 제안하는 방식인 MeRGAN 방법을 살펴보겠습니다.

image 15
그림6. MeRGAN의 제안 방법

먼저 그림 a를 보겠습니다.
먼저 task별 데이터셋을 conditional GAN을 학습시켜 이미지를 생성하는 모습을 볼 수 있습니다. 이는 가장 아래의 Gt-1에 해당합니다. 그리고 이렇게 만든 과거 Task 이미지는 St와 합쳐져 새로운 데이터셋 S’t를 구성합니다. 이때의 St는 현재 Task 이미지를 의미합니다. 따라서 새로 구성한 S’t는 GAN이 만든 과거 Task 이미지와 현재 Task 실제 이미지의 합입니다. 이렇게 구성한 데이터로 다시 새로운 GAN을 학습하는 모습입니다. 이에 대한 Loss 수식은 다음과 같습니다.

joint-training-with-replayed-samples-loss
그림7. joint training with replayed samples loss

다음은 그림 b입니다.
현재 task t에 대해 학습한 Generator인 Gt와 t-1 task까지 학습한 Gt-1은 생성하는 이미지가 약간의 차이가 발생합니다. 이를 pixel level로 보정해주기 위해 pixel level L2 loss를 적용합니다. 이를 수식으로 표현하면 다음과 같습니다.

replay-alignment-loss
그림8. replay alignment loss

3. 실험 결과

다음은 이렇게 제안한 방법의 실험 결과를 살펴보겠습니다.

3-1. Digit generation (MNIST, SVHN)

먼저 MNIST, SVHN 데이터셋으로 학습한 digit generation에 대한 실험 결과입니다.

digit-generation-결과
그림9. digit generation 결과

비교 대상은 다음과 같습니다.

  • JT : Joint Training
  • SFT : Sequential Fine Tuning
  • EWC : conditional GAN + EWC
  • DGR : Deep Generative Replay
  • MeRGAN JTR : MeRGAN + Joint Training Replay
  • MeRGAN RA : MeRGAN + Replay Alignment

Table 1의 결과를 보면, 비교 방법들 중 가장 좋은 성능을 확인할 수 있습니다. 또한 생성된 이미지를 보면 SFT는 전 task를 아예 잊어버리는 모습을 볼 수 있습니다. 반면 MeRGAN은 가장 깔끔하게 잘 생성하는 모습을 볼 수 있습니다. Figure 4의 t-SNE 결과를 보아도 위의 가장 좋은 성능을 이해할 수 있습니다. 다른 방법들이 생성한 0은 진짜 0과 다른 곳에 위치하는 반면, MeRGAN이 생성한 0은 task가 진행되어도 여전히 실제 0과 유사한 곳에 위치한 모습입니다.

3-2. Scene generation (LSUN)

다음은 LSUN 데이터셋으로 학습한 scene generation에 대한 실험 결과입니다.

scene-generation-결과
그림10. scene generation 결과

비교 대상은 다음과 같습니다.

  • SFT : Sequential Fine Tuning
  • EWC : conditional GAN + EWC
  • DGR : Deep Generative Replay
  • MeRGAN JTR : MeRGAN + Joint Training Replay
  • MeRGAN RA : MeRGAN + Replay Alignment

우선 table 2를 보면 다른 방법들에 비해 가장 좋은 성능을 확인할 수 있습니다. 실제로 figure 5와 figure 7의 생성한 이미지를 비교해보겠습니다.
다른 모델들은 task가 진행될수록 이전 task에서 학습한 지식을 잊어버리고 새로운 데이터에 동조해버리는 현상을 볼 수 있습니다. 예를 들어, figure 7의 task2에서 kitchen을 학습했지만 task 3에서 church를 학습하면서 bedroom이 교회 모습으로 변하거나 하늘색이 점점 추가되는 모습으로 확인할 수 있습니다.
반면 MeRGAN-RA 방법은 task가 진행되어도 task 1에서 학습한 bedroom의 지식을 거의 그대로 유지하는 모습입니다.
이에 따라 table 2의 성능도 가장 높게 나왔다고 해석할 수 있습니다.

Figure6의 FID는 Frechet Inception Distance의 약자로서, GAN으로 생성한 이미지의 quality와 diversity를 평가합니다.

Series Navigation<< [PAML 2017] Learning Without Forgetting (LwF) 핵심 리뷰[ECCV 2018] Piggyback: Adapting a single network to multiple tasks by learning to mask weights 핵심 리뷰 >>
0 0 votes
Article Rating
Subscribe
Notify of
guest

0 Comments
Inline Feedbacks
View all comments
0
Would love your thoughts, please comment.x
()
x
Scroll to Top