- [ICCV 2015] Unsupervised Visual Representation Learning by Context Prediction 핵심 리뷰
- [CVPR 2016] Inpainting SSL : Context Encoders: Feature Learning by Inpainting 핵심 리뷰
- [ECCV 2016] Zigsaw Puzzle SSL : Unsupervised Learning of Visual Representations by Solving Jigsaw Puzzles 핵심 리뷰
- [ICLR 2018] UNSUPERVISED REPRESENTATION LEARNING BY PREDICTING IMAGE ROTATIONS (RotNet) 핵심 리뷰
- [CVPR 2018] Unsupervised Feature Learning via Non-Parametric Instance Discrimination (NPID) 핵심 리뷰
- [ECCV 2018] Deep Clustering for Unsupervised Learning of Visual Features (DeepCluster) 핵심 리뷰
- [PMLR 2020] A Simple Framework for Contrastive Learning of Visual Representations (SimCLR) 핵심 리뷰
- [CVPR 2020] Momentum Contrast for Unsupervised Visual Representation Learning (MoCo) 핵심 리뷰
- [NIPS 2020] Supervised Contrastive Learning 핵심 리뷰
- [NIPS 2020] Bootstrap Your Own Latent A New Approach to Self-Supervised Learning (BYOL) 핵심 리뷰
- [ICLR 2021] PROTOTYPICAL CONTRASTIVE LEARNING OF UNSUPERVISED REPRESENTATIONS (PCL) 핵심 리뷰
1. 들어가며
이번 글에서는 2020년 NIPS에 발표된 Supervised Contrastive Learning 논문을 리뷰합니다. 이름에서 알 수 있지만 사실 이 논문은 Self Supervised Learning 방법은 아닙니다. Supervised Learning 방법에 해당하죠. 하지만 Contrastive Learning 을 사용하는 기존 Self Supervised Learning 방법에서 Label 정보를 사용할 수 있다면 어떻게 성능이 달라지는지 확인할 수 있는 재밌는 방법입니다.
이 글에서는 Supervised Contrastive Learning의 기본 원리부터 파이썬 코드로의 구현, 그리고 이 방법이 어떻게 다양한 응용 분야에서 활용될 수 있는지에 대해 자세히 알아보겠습니다. 이를 통해 레이블을 더 효율적으로 활용할 수 있는 새로운 방법론에 대해 이해해 보겠습니다.
2. 제안 방법
제안 방법은 아주 간단합니다.
2-1. 핵심 아이디어
먼저 Supervised Contrastive Learning 의 핵심 아이디어를 살펴볼게요.
![[NIPS 2020] Supervised Contrastive Learning 핵심 리뷰 1 그림1. Self Supervised Contrastive와의 차이](https://ffighting.net/wp-content/uploads/2023/07/image-91-1024x857.png)
기존의 Self Supervised Contrastive Learning 방법을 생각해보죠. 위 그림의 왼쪽에 해당합니다. 이 방법에서는 나의 Augmentation 된 이미지만 끌어당기고 나머지는 모두 밀어내도록 학습했죠. 그런데 이 부분을 한번 더 생각해 볼게요. 이렇게 학습하면 이 모델은 무엇을 학습하게 될까요? 나와 다른 특성은 모두 밀어내도록 학습되겠죠? 그런데 사실 같은 클래스에 속하는 이미지는 같은 특성으로 인식하도록 학습해야 하잖아요? 즉 나와 다른 이미지더라도 같은 강아지 클래스의 이미지라면 당기도록 학습하는게 맞지 않냐는거죠.
오른쪽 그림을 볼게요. 그래서 Supervised Constrastive 에서는 나의 Augmentation 된 이미지뿐만 아니라 동일한 클래스에 속하는 이미지도 끌어당기도록 학습합니다. 그림을 보면 기존 Self Supervised Contrastive 와 달리 강아지끼리는 모두 끌어당기도록 학습하는 모습을 볼 수 있습니다.
2-2. 수식
이제 수식으로 확인해볼까요? 기존 Self Supervised Contrastive Loss 수식은 다음과 같습니다.
![[NIPS 2020] Supervised Contrastive Learning 핵심 리뷰 2 Self-Supervised-Contrasctive-Loss-수식](https://blog.kakaocdn.net/dn/8eblY/btrjjrYWJci/NkXnK4c6LKbLZpDOUTL5Q1/img.png)
오직 나의 Augmentation 된 이미지만 Positive Pair 라고 말하고 있죠.
반면 Supervised Contrastive Loss는 이렇습니다.
![[NIPS 2020] Supervised Contrastive Learning 핵심 리뷰 3 Supervised-Contrastive-Loss-수식](https://blog.kakaocdn.net/dn/nDa8Q/btrjiZWvTuz/sJpCL3WX3iS9Lq8MJ67JBk/img.png)
‘같은 클래스라면 모두 Positive Pair 야’ 라고 말하고 있죠?
3. 파이썬 구현
이번 챕터에서는 파이썬을 사용하여 Supervised Contrastive Learning 방법을 직접 구현해 봅니다. 이 과정을 통해 위에서 살펴본 내용을 정확히 이해해 봅니다.
3-1. Import Module
먼저 필요한 Module들을 Import 해줍니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
3-2. Supervised Contrastive Learning Class 정의하기
다음은 Supervised Contrastive Learning Class를 정의해줍니다. 기존 Contrastive Learning 과는 달리 Label을 보며 Loss를 계산하는 모습을 볼 수 있습니다.
class SupervisedContrastiveLoss(torch.nn.Module):
def __init__(self, temperature=0.1):
super(SupervisedContrastiveLoss, self).__init__()
self.temperature = temperature
def forward(self, embeddings, labels):
# Normalize the embeddings
embeddings = F.normalize(embeddings, p=2, dim=-1)
# Compute the similarity matrix
sim_matrix = torch.matmul(embeddings, embeddings.T)
# Create the positive mask
labels = labels.unsqueeze(1)
positive_mask = torch.eq(labels, labels.T).float()
# Create the negative mask
negative_mask = torch.ne(labels, labels.T).float()
# Compute the positive and negative similarity
sim_ij = torch.exp(sim_matrix / self.temperature)
exp_sim_matrix = torch.exp(sim_matrix / self.temperature)
# Remove diagonal elements
ind = torch.eye(labels.size(0)).bool().to(device)
sim_ij.masked_fill_(ind, 0)
exp_sim_matrix.masked_fill_(ind, 0)
# Compute the loss
sum_exp_sim_matrix = torch.sum(exp_sim_matrix, dim=1)
pos_exp_sim_matrix = torch.sum(sim_ij * positive_mask, dim=1)
loss = -torch.log(pos_exp_sim_matrix / sum_exp_sim_matrix)
return torch. Mean(loss)
4. 실험 결과
Supervised Contrastive Learning 방식의 실험 결과를 살펴보겠습니다.
4-1.Classification 성능 테스트
첫 번째로 classification 성능 테스트 결과를 살펴보겠습니다.
![[NIPS 2020] Supervised Contrastive Learning 핵심 리뷰 4 classification-성능-테스ᄐ](https://blog.kakaocdn.net/dn/cLxHH0/btrj2fErbvC/HVHj43gyHKYH8bMSUIltXK/img.png)
Self Supervised Contrastive 방식인 SimCLR보다 높은 성능을 내고 있네요. 심지어 전통적인 Classification Loss인 Cross Entropy 보다도 좋은 성능을 낸다는 점이 놀랍습니다.
4-2. Robustness
두 번째로, 실험 결과에 따르면 Self Supervised Contrastive 방식은 Corruption에 더 강인하며 데이터가 더 적을 때도 좋은 성능을 냅니다.
![[NIPS 2020] Supervised Contrastive Learning 핵심 리뷰 5 그림5. Robustness Test](https://blog.kakaocdn.net/dn/QAy6I/btrj11Grl7i/nMTi92kqFvUIHSIzLSDPJ0/img.png)
위 실험 결과를 보면 Cross Entropy 방식과 비교했을 때 mCE이 더 낮은 모습을 볼 수 있습니다. 오른쪽 그래프에서도 Corruption 수치가 올라감에 따라 Cross Entropy 모델들보다 성능 저하가 덜한 모습을 보입니다.
5. 의의
다음은 Supervised Contrastive Learning 논문의 의의를 살펴보겠습니다.
첫 번째는 Contrastive Learning과 Supervised Learning을 결합한 학습 방법을 제안했다는 점입니다. Supervised Contrastive Learning은 Contrastive Learning의 장점을 Supervised Learning에 적용합니다. 이로써 레이블이 있는 데이터를 더 효율적으로 활용할 수 있습니다. Contrastive Learning은 일반적으로 레이블이 없는 데이터에서 잘 작동하지만, Supervised Contrastive Learning은 레이블을 활용하여 같은 클래스의 샘플을 가깝게 배치하고 다른 클래스의 샘플을 멀게 배치합니다.
두 번째 의의는 Feature 추출 품질을 향상시켰다는 점입니다. Supervised Contrastive Learning은 레이블 정보를 사용하여 특성 공간을 더욱 의미있게 구성합니다. 이는 분류, 검색, 세분화 등 다양한 Down Stream 작업에서 더 좋은 성능을 달성할 수 있음을 의미합니다. 특히, 레이블이 있는 적은 양의 데이터가 있을 때, Supervised Contrastive Learning은 이를 최대한 활용하여 특성을 더 잘 추출할 수 있습니다.
이 두 가지 의의는 Supervised Contrastive Learning이 단순히 레이블을 활용하는 것을 넘어, Contrastive Learning과 Supervised Learning을 유기적으로 결합하여 더 효율적인 특성 추출과 더 높은 성능을 달성할 수 있게 하는 중요한 기술적 발전이라고 할 수 있습니다.
6. 마치며
Supervised Contrastive Learning은 레이블을 활용하여 특징 추출을 더욱 효율적으로 만드는 방법입니다. 이 방법은 특별한 Loss 함수를 사용하여, 같은 레이블을 가진 데이터끼리는 가까워지고, 다른 레이블을 가진 데이터끼리는 멀어지도록 학습합니다.
이러한 접근법은 모델이 더 강력하고 일반화된 특징을 추출하도록 도와줍니다. 이 글을 통해 이 방법의 중요성과 유용성을 이해하셨기를 바랍니다. Supervised Contrastive Learning은 레이블이 있는 데이터에서 특히 뛰어난 성능을 보이며, 다양한 응용 분야에서 활용될 수 있습니다.