[21′ Machine Learning] Density-based weighting for imbalanced regression

This entry is 1의 5 in the series Data Imbalance

1. 들어가며

이번 글에서는 21년 Machine Learning 저널에 발표된 Density-based weighting for imbalanced regression 논문의 핵심 내용을 정리해보겠습니다. 이 논문을 부르는 공식적인 별칭은 없는데, 핵심 제안 방법이 loss 가중치를 빈도수에 맞게 재조정하는 방식이므로 이번 글에서는 이 논문의 제안 방법을 reweighting 이라고 부르도록 하겠습니다.

Reweighting은 long tail regression 문제를 다루고 있습니다. long tail regression 문제란, long tail classification 문제와 마찬가지로, 학습 데이터 분포가 고르지 않은 regression 문제를 의미합니다. 예를 들어 사람 얼굴을 입력 받아 나이를 맞추는 문제라고 가정했을때, 20~50대의 데이터는 많은 반면 70~80대 데이터가 굉장히 적은 경우라고 할 수 있습니다.

이러한 long tail 데이터 분포는 딥러닝 모델이 학습하기가 어렵습니다. 가장 직관적으로 이해하기 쉬운 원인은, 학습 loss를 줄여야 하는 모델 입장에서 생각해보면, 빈도수가 많은 데이터에 대해서만 잘 맞추면 되기 때문이죠. 예를 들어 100개 문제 중 1번 정답은 99개, 2번 정답은 1개 밖에 없는 상황을 생각해보면, 모델 입장에서는 괜히 어렵게 공부를 하느니 1번으로 모두 찍는게 loss를 낮추기에 가장 수월한 방법입니다.

2. 기존 방법들

따라서 이러한 long tail 분포를 잘 학습하기 위한 다양한 방법들이 연구되어 왔습니다. 특히 class imbalance classification 문제가 활발히 연구되어 왔는데요. 대표적으로 focal loss등 빈도수에 따른 가중치 조절 방식을 들 수 있습니다.

반면 regression에서는 long tail 문제에 대한 연구가 classification 만큼 활발히 진행되지 않았는데요. 이는 regression과 classification의 문제 특성 차이에서 기인한다고 할 수 있습니다. classification은 말 그대로 class를 구분하도록 학습해야 하는 문제로서, 클래스 별 빈도수를 파악하기 쉽습니다. 반면 예측 값이 class가 아닌 연속적인 형태의 값을 갖는 regression 문제의 경우 무엇이 희귀하고, 무엇이 빈번한지를 정의하기가 쉽지 않죠. 이러한 특성 차이로 인해 classification과 달리 long tail regression 문제는 활발한 연구가 진행되지 않았습니다.

그래도 대표적인 long tail regression 방법은 SMOGN을 들 수 있습니다. SMOGN은 sampling 방식에 변형을 주는 방법을 제안했는데요. 앞서 설명했듯 long tail 문제의 핵심 원인은 ‘빈도수 차이’ 에 있습니다. 빈번한 정보는 모델이 빠르게 학습하는 반면, 드문 정보는 모델이 잘 학습하지 못한다는 문제가 있죠. 따라서 SMOGN은 빈도수가 작은 정보들끼리의 inter/outerpolation을 통해 새로운 데이터를 만들어 빈도수를 맞춰주는 아이디어를 제안합니다. 문제의 원인을 직접적으로 해결하기 위한 시도지만, 매 학습 epoch마다 kNN을 통해 낮은 빈도수 데이터들끼리 inter/outerpolation을 수행함으로 인해 계산량이 매우 많아진다는 한계가 존재했죠.

3. 제안 방법

이에 Reweighting에서는 Sampling은 건들지 않고, loss의 가중치만 빈도수에 따라 다르게 주는 방법을 제안합니다.

그림1. Reweighting 핵심 방법
그림1. Reweighting 핵심 방법

SMOGN이 빈도수에 따른 데이터 sampling을 조작했다면, Reweighting에서는 CPU 연산량의 제약이 없는 loss 가중치를 조작하는 방법을 선택한거죠. 높은 빈도의 데이터는 모델이 학습하기 쉬우니 작은 가중치로, 낮은 빈도의 데이터는 모델에게 더 큰 임팩트를 주기 위해 큰 가중치로 학습해주자는 아이디어 입니다.

그림2. Reweighting 전체 프로세스
그림2. Reweighting 전체 프로세스

전체 프로세스는 위 그림과 같습니다. 먼저 학습 데이터로 Kernel Density Estimation을 수행해서 전체적인 데이터 분포를 추정합니다. 이를 통해 y값들의 빈도수를 알 수 있죠. 그리고 빈도수에 반비례하는 가중치를 계산해줄겁니다. 이렇게 계산된 가중치는 기존 MSE 등의 loss에 곱해져서 최종 loss sum으로 계산해주는 방식입니다.

그럼 KDE를 통해 y값별 분포를 알았으니, loss에 가중치로 반영해주기 위해서는 정규화를 해줘야 하는데요. 다음과 같이 정규화 해줍니다.

그림3. KDE로 정규화해준 빈도수
그림3. KDE로 정규화해준 빈도수

KDE 결과값이 p(Y)일때, minmax scaling을 적용하여 0~1 값으로 정규화해줍니다.

이제 최종적으로 반영할 가중치를 계산해줄건데요. 최종 가중치는 아래 몇가지 조건을 만족해야 합니다. 먼저 당연히 빈도에 반비례하는 값을 가져야 합니다. 그리고 하이퍼파라미터를 통해 세기를 조정할 수 있어야 합니다. 당연히 가중치는 0보다 작거나 같아서는 안됩니다. 학습이 용이하도록 정규화 효과를 위해서는 전체 평균 가중치는 1이 되는게 좋습니다. 이러한 조건들을 만족하도록 구성된 최종 가중치 함수는 아래와 같습니다.

그림4. 최종 가중치 함수
그림4. 최종 가중치 함수

하이퍼파라미터 alpha를 사용하여 전체적인 세기를 조절하는 모습을 볼 수 있고, epsilon을 통해 가중치가 0이 되는 상황을 방지해주고 있습니다. alpha에 따른 전체적인 가중치의 변화는 아래와 같습니다.

그림5. alpha에 따른 가중치 변화
그림5. alpha에 따른 가중치 변화

alpha가 0일때는 빈도수에 따른 가중치 변화가 없는 모습을 볼 수 있습니다. alpha가 커질수록 적은 빈도수에 대해 더 큰 가중치가 반영되는 모습을 볼 수 있습니다. 최종 loss는 아래와 같이 기존 MSE등 loss에 계산된 가중치를 곱해주는 방식으로 적용해줍니다.

그림6. 최종 Dense Loss
그림6. 최종 Dense Loss

4. 실험 결과

이제 이렇게 제안된 방법을 적용한 비교 실험 결과를 살펴보겠습니다.

4.1 Case Study with Synthetic Data

먼저 합성 데이터를 사용해서 제안 방법의 효과를 검증해보겠습니다. 입력은 10차원으로, 각 입력값은 정규분포로 샘플링해줍니다. 출력값은 3 layer MLP로부터 생성해주는데요, 이때 다양한 분포를 갖도록 생성해줍니다.

그림7. 실험에 사용된 다양한 분포들
그림7. 실험에 사용된 다양한 분포들

위 그림과 같이 총 4개의 서로 다른 target 분포를 갖는 합성 데이터를 생성해줍니다. 그리고 빈도수에 따라 bin을 5개로 나누어 성능을 살펴보겠습니다. 이때 빈도수가 가장 낮은 bin이 1번, 빈도수가 가장 높은 bin이 5번입니다. 실험 결과는 아래와 같습니다.

그림8. 분포별/bin 별 성능 비교
그림8. 분포별/bin 별 성능 비교

첫 번째로 발견되는 현상은 alpha=1일때 bin1, bin2 등 rare bin에서의 성능 향상이 두드러진다는 점입니다.

두 번째로 반대로 빈번한 빈 (bin5)에서는 오히려 약간의 성능 저하를 보인다는 점이 눈에 띕니다.

4.2 Comparison with State-of-the-Art

이번에는 당시의 SOTA 모델인 SMOGN과의 성능 비교 실험을 살펴보겠습니다. SMOGN 논문에서 사용한 20개의 real world imbalanced regression 데이터셋을 똑같이 활용하여 성능을 비교해보겠습니다. 이때 dense loss의 alpha는 모두 1로 고정합니다.

그림9. SOTA와 성능 비교
그림9. SOTA와 성능 비교

위 그림은 각 실험에서 bin별 어떤 방법이 가장 좋았는지를 시각화한 표입니다. 마찬가지로 두가지 특성이 눈에 띄는데요.

첫 번째로 희귀 bin에서는 dense loss 방식의 효과가 두드러진다는 점입니다. valilla MSE loss 뿐만 아니라 SOTA 방식인 SMOGN을 크게 앞서는 모습을 보입니다.

두 번째로 마찬가지로 역시 빈번한 bin에서는 기존 방식들보다 오히려 못한 모습을 보여주고 있습니다.

4.3 Statistical Downscaling of Precipitation

이번에는 실제 데이터셋에서의 성능을 살펴보겠습니다. PRISM 데이터셋을 DeepSD 모델로 학습한 결과입니다.

그림10. PRISM 데이터셋 성능 비교
그림10. PRISM 데이터셋 성능 비교

위 실험 결과를 살펴보면 이전 실험들과는 조금 다른 양상을 보이고 있는데요. 이 전 실험들에서는 희귀한 bin에서는 개선되는 성능을, 빈번한 bin에서는 악화되는 성능을 보였는데, 이번 실험에서는 모든 bin에서 성능이 개선되며, 그 개선 폭이 rare bin일수록 더 큰 모습을 보여주고 있습니다.

이는 실제 데이터와 합성 데이터의 분포 차이에서 기인하는것으로 보이는데요. 이번 실험에서 사용한 PRISM 데이터는 강수량 데이터로, 0mm 근처에 대부분의 값이 몰려있는 데이터입니다. 즉 지나치게 long tail 분포를 보이는 상황에서는 모델 성능이 급격하게 나빠지며, dense loss는 이러한 경우에 빈도수가 너무 큰 데이터에 치중되는 형상도 개선해주는 효과가 있는것으로 볼 수 있습니다.

5. 참고자료

Data Imbalance

[21′ ICML] Delving into deep imbalanced regression
0 0 votes
Article Rating
Subscribe
Notify of
guest

0 Comments
Inline Feedbacks
View all comments
0
Would love your thoughts, please comment.x
()
x
Scroll to Top