[NIPS 2020] Bootstrap Your Own Latent A New Approach to Self-Supervised Learning (BYOL) 핵심 리뷰

This entry is part 10 of 11 in the series Self Supervised Learning

1. 들어가며

이번 글에서는 2020년 NIPS에 발표된 Bootstrap Your Own Latent A New Approach to Self-Supervised Learning 논문을 리뷰합니다. 이 논문은 BYOL 이라는 명칭으로 불리며, 이번 글에서도 BYOL 이라고 지칭하겠습니다.

BYOL은 최근에 등장한 방법론이지만, 이미지 분류, 객체 인식 등 다양한 비전 문제에서 뛰어난 성능을 보이고 있습니다. BYOL은 라벨이 없는 데이터에서도 효과적인 특징을 추출할 수 있는 방법을 제시합니다. 이 글에서는 BYOL의 기본 아이디어, 그 특징, 그리고 실제 파이썬 코드로 어떻게 구현되는지에 대해 자세히 알아보겠습니다.

딥러닝에 대한 기초 지식이 있는 분들에게는 이 글이 BYOL과 자기지도학습에 대한 깊은 이해를 제공할 것입니다. 상세한 설명과 코드 예시를 통해, BYOL의 작동 원리와 그 의미를 명확하게 이해할 수 있을 것입니다. 그럼 지금부터 시작해보겠습니다!

2. 문제 제기

먼저 기존 방법의 문제를 살펴보죠. 여기서 기존 방법이란 Contrastive Learning 방식의 Self Supervised Learning 방법인데요. 대표적으로 SimCLRMoCo 등이 있습니다. Contrastive Learning 은 Positive Pair 는 끌어당기고, Negative Pair는 밀어내는 방식으로 학습합니다. 따라서 일관된 방향으로 학습하기 위해서는 아주 많은 양의 Negative Sample 들이 필요하죠. 밀어내는 방향이 계속 바뀌면 일관된 학습을 할 수 없으니까요.

이에 SimCLR에서는 아주 큰 Batch Size를 사용합니다. 최대한 많은 Sample 을 하나의 Batch 에 포함시켜 일관된 학습 방향을 유도하겠다는 것이죠. 이로 인해 엄청나게 많은 양의 GPU를 사용해야만 학습이 가능하죠.

반면 MoCo는 다른 방식으로 문제 해결을 시도하는데요. 바로 모델을 천천히 업데이트 해주는겁니다. 그리고 아주 많은 Sample 들을 Memory Bank에 저장해놓고 비교하죠.

BYOL 에서는 SimCLR, MoCo와는 다른 방식으로 문제를 해결하려고 합니다. 애초에 많은 Negative Sample 없이 Self Supervised Learning 방식으로 학습할 수는 없을까요?

3. Motivation

BYOL의 제안 방법을 살펴보기 전에 재미있는 실험 결과를 살펴보고 갈게요.

그림1. Motivation
그림1. Motivation

위와 같이 두개의 Network로부터 ImageNet 성능을 측정하는 실험을 할겁니다. 하나는 Target Network라고 부르고, 다른 하나는 Online Network라고 부르겠습니다. 두 네트워크 모두 Feature Extractor의 가중치는 초기화 상태로 고정합니다. 그리고 Classifier만 학습할거에요. 이때 Target Network는 ImageNet Label을 사용하여 MLP를 Cross Entropy로 학습해줍니다. 이렇게 학습한 Target Network의 최종 성능은 1.4% Accuracy 였습니다. Feature Extractor 없이 Classifier만 학습했기에 낮은 성능을 보였죠.

재미있는건 Online Network의 성능인데요. Online Network는 Target Network와 달리 바로 ImageNet Label을 사용하여 학습하지 않습니다. 그 대신 먼저 Target Network의 Prediction을 따라하도록 학습하죠. 이렇게 1차 학습을 하고 난 뒤에 최종적으로 Target Network와 동일하게 ImageNet Label을 Cross Entropy로 학습했습니다. 이렇게 학습한 Online Network의 최종 성능은 몇이었을까요? 놀랍게도 18.8% Accuracy를 보였습니다.

생각해봐야 할 포인트는 이겁니다. 왜 Online Network의 성능이 Target Network의 성능보다 훨씬 좋았을까요? Target Network와의 차이라고는 본격적인 학습을 하기 전에 Target Network의 Prediction을 따라하도록 Pretrain 한것 뿐인데요. 심지어 Target Network의 Prediction은 1.4% Accuracy 밖에 되지 않는 아주 수준 낮은 대답이었는데도 말이죠.

위의 관찰로부터 이런 가정을 할 수 있겠네요. 완벽한 정보는 아니더라도, 나보다 조금이라도 더 많이 알고 있는 선생님으로부터 배우는건 어쨌든 도움이 될겁니다. 그럼 선생님을 조금씩 업그레이드 하면서 배우면 어떨까요?

그림2. BYOL 아이디어
그림2. BYOL 아이디어

이렇게 선생님으로부터 학습해서 더 많이 알게 된 내가 오늘의 선생님이 되어 오늘의 나를 가르쳐 주는거죠. 그럼 선생님은 계속해서 나보다 수준이 높은 상태 일거고, 나는 그럼 계속해서 선생님으로부터 배우며 성적이 오를 수 있을겁니다.

4. BYOL

그럼 이제 이러한 아이디어를 어떻게 구현했는지 BYOL의 제안 방법을 살펴보겠습니다.

4-1. 핵심 아이디어

핵심 아이디어는 이렇습니다. 전체 구조는 Target Network와 Online Network를 구성해주는데요. Target Network는 Online Network를 가르쳐주는 역할을 할 겁니다. 따라서 Online Network는 Target Network를 따라하도록 학습할 겁니다. 그리고 이렇게 학습한 지식을 Target Network에게 업데이트 해주어 Target Network의 수준을 더 올려줄겁니다. 이 과정을 반복하는거죠.

4-2. Architecture

위의 아이디어를 그대로 구현한 BYOL의 Architecture는 다음과 같습니다.

그림3. BYOL Architecture
그림3. BYOL Architecture

Online Network가 Target Network의 대답을 따라하도록 학습하는 모습을 표현하고 있습니다. 이를 더 구체적인 Network 형태로 표현하면 이렇습니다.

그림4. BYOL Network Architecture
그림4. BYOL Network Architecture

Online Network는 Target Network 대답을 따라하도록 학습하고요. 이렇게 학습한 지식을 Target Network에 전달해주는 모습을 볼 수 있죠. 이렇게 Target Network를 업데이트 하는 방법은 MoCo와 동일한 컨셉입니다.

5. 파이썬 구현

이번 챕터에서는 파이썬을 사용하여 BYOL을 직접 구현해봅니다. 이를 통해 위에서 살펴본 BYOL의 방법을 정확하게 이해해봅니다.

5-1. Import Module

먼저 필요한 Module들을 Import 해줍니다.

import torch
import torch.nn as nn
import torch.nn.functional as F

5-2. MLP Head Class 정의하기

이어서 MLP Head를 정의해줍니다. MLP Head는 BYOL의 Encoder들에 각각 들어가 Feature를 Projection 해주는 역할을 수행합니다.

class MLPHead(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=4096):
        super(MLPHead, self).__init__()
        self.block = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
        
    def forward(self, x):
        return self.Block(x)

5-3. BYOL Class 정의하기

다음은 최종적으로 BYOL Class를 정의해줍니다. Online Encoder와 Target Encoder를 각각 정의해주고요. 입력값에 대해 각각의 Encoder로 부터 추출한 Feature를 사용하여 Loss를 추출하는 모습을 볼 수 있습니다.

class BYOL(nn.Module):
    def __init__(self, base_encoder, projection_dim=256):
        super(BYOL, self).__init__()
        self.online_encoder = nn.Sequential(
            base_encoder,
            MLPHead(base_encoder.fc.in_features, projection_dim)
        )
        self.target_encoder = nn.Sequential(
            base_encoder,
            MLPHead(base_encoder.fc.in_features, projection_dim)
        )
        
        # Initialize target encoder with online encoder weights
        self.target_encoder.load_state_dict(self.online_encoder.state_dict())
        for param in self.target_encoder.parameters():
            param.requires_grad = False
            
    def forward(self, x1, x2):
        z1_online, z2_online = self.online_encoder(x1), self.online_encoder(x2)
        with torch.no_grad():
            z1_target, z2_target = self.target_encoder(x1), self.target_encoder(x2)
        
        loss = self.loss_fn(z1_online, z2_target.detach()) + self.loss_fn(z2_online, z1_target.detach())
        return loss
    
    def loss_fn(self, x, y):
        x = F.normalize(x, dim=-1, p=2)
        y = F.normalize(y, dim=-1, p=2)
        return 2 - 2 * (x * y).sum(dim=-1)
    
    def update_moving_average(self, beta=0.99):
        for online_params, target_params in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            target_params.data = beta * target_params.data + (1.0 - beta) * online_params.data

6. 실험 결과

이번에는 실험 결과를 통해 BYOL의 효과를 살펴보겠습니다.

6-1. Linear Evaluation on ImageNet

먼저 ImageNet에 대한 Linear Evaluation 성능을 보겠습니다.

그림5. ImageNet Linear Evaluation
그림5. ImageNet Linear Evaluation

BYOL은 SimCLR, MoCo 등 기존 방법들보다 좋은 성능을 보여줍니다.

6-2. Transfer to Other Classification Tasks

다음은 다른 Classification 데이터셋에 대한 Linear Evaluation 성능을 비교해 보겠습니다.

그림6. Other Classification Task Result
그림6. Other Classification Task Result

많은 데이터셋에서 BYOL은 SimCLR 뿐만 아니라 Supervised 방식보다도 높은 성능을 보입니다.

6-3. Transfer to Other Vision Tasks

이번에는 Classification이 아닌 다른 Vision Task에 대한 성능을 비교해 보겠습니다.

그림7. Other Vision Task Result
그림7. Other Vision Task Result

BYOL은 SimCLR, Supervised 방식보다 좋은 성능을 보입니다.

7. Ablations

다음은 Ablation 실험 결과를 살펴보겠습니다.

7-1. Batch Size

먼저 Batch Size의 변화에 따른 성능 변화를 살펴보겠습니다.

그림8. Ablation - Batch Size
그림8. Ablation – Batch Size

예상대로 SimCLR는 Batch Size가 줄어듦에 따라 급격한 성능 저하를 보이는데요. BYOL은 상대적으로 Batch Size에 둔감한 모습을 보입니다.

7-2. Augmentations

다음은 Augmentation에 따른 성능 변화를 살펴보겠습니다.

그림9. Ablation - Augmentations
그림9. Ablation – Augmentations

마찬가지로 SimCLR는 Augmentation에 따라 많은 성능 차이를 보입니다. 반면 BYOL은 Augmentation의 변화에 비교적 둔감한 모습을 보여줍니다.

7-3. Bootstrapping

마지막으로 BYOL에 영감을 주었던 Bootstrapping 실험 결과입니다. Target Network를 변화시키는 방법에 따른 성능 변화를 살펴보겠습니다.

그림10. Ablation - Bootstrapping
그림10. Ablation – Bootstrapping

Constant Random Network는 Target Network를 업데이트하지 않는 방법을 의미합니다. Motivation에서 살펴봤듯이 이때의 성능은 18.8%가 나옵니다.

실험 결과를 보면 Moving Average의 정도 (𝜏)가 적당할 때 (0.99) 성능이 가장 좋은 모습을 볼 수 있습니다. 너무 빨리 업데이트해도 (0.9), 너무 천천히 업데이트 해도 (0.999) 성능이 안좋죠.

8. 의의

다음은 BYOL 논문의 의의에 대해 정리해보겠습니다.

첫 번째 의의는 Positive / Negative Pair 없이 사용할 수 있는 Self Supervised Learning 방법을 제안했다는 것입니다. 대부분의 자기지도학습 방법은 양성 쌍과 음성 쌍(negative pair)을 사용하여 모델을 학습시키지만, BYOL은 이러한 쌍을 명시적으로 사용하지 않습니다.

두 번째 의의는 “타겟 네트워크”(target network)의 도입입니다. 이 타겟 네트워크는 주 네트워크(online network)의 가중치를 이용해 업데이트되며, 이 두 네트워크의 출력을 비교하여 손실을 계산합니다. 이 과정은 네트워크가 더 일반화된 특징을 학습하도록 도와줍니다.

세 번째로, BYOL은 다양한 아키텍처와 함께 사용할 수 있습니다. 기존의 자기지도학습 방법은 특정 아키텍처에 의존적인 경우가 많았지만, BYOL은 이러한 제약을 크게 줄여줍니다.

이러한 특징들 덕분에 BYOL은 레이블이 없는 데이터에서도 뛰어난 성능을 보이며, 다양한 응용 분야에서 활용될 수 있습니다. 이는 자기지도학습의 새로운 패러다임을 제시하며, 레이블링 비용이 높은 현실 세계의 문제에 적용할 수 있는 강력한 방법론을 제공합니다.

9. 마치며

이번 글에서는 자기지도학습의 한 분야인 BYOL에 대해 알아봤습니다. BYOL의 핵심 원리와 그 특징, 그리고 실제 파이썬 코드로 어떻게 구현되는지에 대해 알아보았습니다.

BYOL의 의의는 라벨이 없는 데이터에서도 효과적으로 특징을 추출할 수 있다는 점에 있습니다. 이로 인해, BYOL은 다양한 비전 문제에서 뛰어난 성능을 보이고 있으며, 자기지도학습의 새로운 가능성을 열고 있습니다.

Series Navigation<< [NIPS 2020] Supervised Contrastive Learning 핵심 리뷰[ICLR 2021] PROTOTYPICAL CONTRASTIVE LEARNING OF UNSUPERVISED REPRESENTATIONS (PCL) 핵심 리뷰 >>
5 1 vote
Article Rating
Subscribe
Notify of
guest
0 Comments
Inline Feedbacks
View all comments
0
Would love your thoughts, please comment.x
()
x
Scroll to Top