본문 바로가기

Analysis

[논문리뷰] FixMatch : Simplifying Semi-Supervised Learning with Consistency and Confidence

728x90
반응형

작년 KAMP 경진대회를 준비하며 라벨 데이터가 적은 한계를 극복하기 위한 Semi-supervised Learning(이하 SSL)에 관심이 생겼다. 관련하여 어떤 논문을 리뷰할까 서치하다 FixMatch라는 구글에서 발표한 연구를 알게 되었다. 

최근 딥러닝 업계에서 많이 연구되는 SSL분야는 연구가 많이 진행될수록 모델의 성능이 좋아졌지만 그와 함께 모델의 구조 및 학습 방법 또한 복잡해졌다. FixMatch는 이전에 진행된 최신 SSL 모델 대비 비슷하거나 더 좋은 성능을 가지지만 모델을 단순화 하고 비용을 줄임으로 많은 각광을 받았다는 점이 흥미를 끌었다. 


문제 인식

컴퓨터 비젼에 딥러닝을 사용하며 점점 더 많은 데이터가 필요하게 되었다. 하지만 데이터에 라벨을 붙이는 것은 비용과 공수가 많이 드는 작업이고 특히 전문적인 지식이 필요한 이미지(ex : 의학, 제조 등)에 라벨링 작업을 하는것은 그 비용이 기하급수적으로 증가하게 된다.  이에 SSL(Semi-supervised Learning) 방식에 대한 관심과 인기가 증가하게 되었고 비라벨 데이터가 지도 학습 모델의 성능을 높여주면서 SSL 모델은 점점 더 복잡하게 되었다.

최근 인기있는 SSL방식은 비라벨 데이터를 사용해 라벨을 생성하고 모델을 학습해서 인공 라벨을 예측하는 방식이다. pseudo-labeling과 consistency regulation 방식 등이 있는데 FixMatch에서는 이런 방식을 이용하면서도 점점 증가하는 알고리즘의 복잡성을 타파하고자 하는 명확한 목적이 있다.


FixMatch의 배경이 되는 개념

Consistency Regularization

이미지를 변형하더라도 동일한 이미지에 대한 모델의 결과는 유사하다 

최근 SSL 알고리즘의 중요한 개념이다. 위의 가정을 기본으로 기존 데이터와 augmentation이 적용된 데이터에 대한 모델의 출력값을 차이가 최소가 되도록 학습이 진행된다. 

목적함수는 다음과 같다

$\alpha$와 $p$는 임의의 확률 함수를 뜻한다. 두 변형된 이미지에 대한 모델의 확률값의 차이를 줄이면서 임시 label의 신뢰성을 높인다. 
위의 아이디어를 확장해서 적대적 변환(adversarial transformation)을 사용한다. 첫번째 p는 과거 모델 예측값의 cross-engropy loss를 적용하고 두 번째 p의 input은 강력한 augmentation을 적용한다.  

 

Pseudo-labeling

모델 자신을 이용해서 비라벨 데이터의 인공 라벨을 획득한다

 라벨 데이터로 모델을 학습 후 해당 모델로 비라벨 데이터를 예측한다. 가장 높은 확률을 얻은 클래스를 pseudo label로 지정한다. 즉, 모델의 output인 클래스의 확률값의 argmax를 이용해 'one-hot' 확률 분포를 만드는 것이다. 
 이후 비라벨 데이터를 변형해서 동일 모델로 예측했을 때 예측결과와 pseudo label의 cross-entropy loss를 계산한다. 

목적 함수는 다음과 같다. 비라벨 데이터의 Pseudo label과 예측 결과의 cross-entropy loss를 줄이는 방향으로 학습을 진행한다. 

 

Augmentation 

FixMatch에서 사용되는 augmentation 전략에 대해 미리 간단히 얘기하자면 'weak'와 'strong' augmentation이 있다. 

  • Weak 전략
    • 일반적인 flip-and-shift augmentation을 사용했다. 
    • FixMatch에서는 랜덤하게 전체 데이터셋의 50%를 수평하게 뒤집고(SVHN제외-digit 정보를 포함하기 떄문) 12.5%까지 수평, 수직으로 변환했다.

  • Strong 전략
종류 설명 참고
RandAugment  모든 왜곡의 심각도를 제어하는 크기(magnitude)와 적용할 augmentation개수를 사전에 정한 후 그 범위에서 랜덤하게 샘플링하여 적용

CTAugment 개별적인 변환 크기를 즉시 학습한다. 
ReMixMatch 연구에 소개되었다.
 

 


FixMatch Process

1. Batch 준비

B: 라벨 이미지 데이터 개수

$\mu$B :  비라벨 이미지의 배치 크기 ($\mu$ : 라벨 데이터의 개수와 관련한 hyperparameter)

 

2. 지도학습

라벨 데이터를 이용해 지도학습을 수행한다. 여기서 loss는 일반적인 cross-entropy loss H()를 사용한다. 

$p_b$ : 실제 label
$p_m(y|\alpha (x_b))$ : 전처리가 적용된 이미지를 이용한 예측값
$H( )$ : cross-entropy loss
$l_s$ : 각 배치의 평균  cross-entropy loss

 

3. Pseudo-labeling

 먼저 weak augmentation을 적용한 비라벨 이미지를 2번에서 학습한 모델을 이용해 예측한다. 예측 결과에서 가장 높은 확률을 가진 클래스를 해당 이미지의 pseudo label로 적용한다. 
 동일한 모델에 strong augmentation을 적용하여 동일 모델로 예측한다. 이 예측값과 pseudo label의 차이를 비교한다.

 $q_b$ : weak augmentation을 적용한 비라벨 데이터의 예측값
$\hat{q}_b$ : 예측값에서 가장 높은 확률을 가진 label

 

4. Consistency Regularization

strong augmentation을 적용하여 예측한 결과와 pseudo label의 cross-entropy loss를 계산한다.
기존의 pseudo-labeling과 비슷해 보이지만 차이가 있다. FixMatch에서는 weakly-augmented 이미지를 기반으로 pseudo-label을 계산하고 strongly-augmented 이미지를 기반을 loss를 계산하는 점이다. 

$\tau$ : pseudo-label을 선택할지 결정하는 threshold

 

5. Curriculum Learning

위의 절차를 결합한 FixMatch의 최종 loss는 다음과 같다

$loss = l_s + \lambda_{u} l_u $

$\lambda_u$ : 비라벨 데이터의 가중치와 연관된 고정 hyperparameter

여기서 흥미로운 부분은 $\lambda_{u}$이다. 이전 연구에서는 학습이 진행됨에 따라 이 weight를 증가시켰다. 하지만 FixMatch에서는 weight 증가가 불필요하다는 것을 확인했다. 왜냐면 학습 초반에는 라벨 데이터의 부족으로 모델의 신뢰성이 낮고 $max(q_b)$가 일반적으로 $\tau$보다 작다. 하지만 학습이 진행되면서  $max(q_b) > \tau$인 경우가 많아진다. 이는 비라벨 데이터를 학습에 포함하며 라벨 데이터를 이용한 모델의 신뢰도가 높아지기 때문이다. 

일련의 과정은 자유로운 형태의 커리큘럼 학습(curriculum learning)으로 볼 수 있다. 이것은 마치 아기의 학습에서 간단한 컨셉의 알파벳부터 학습하며 성장하면서 점점 복잡한 형태의 단어, 문장, 에세이를 학습하는것과 같은 개념으로 볼 수 있다. 


실험

비교 알고리즘의 데이터 처리

 

데이터 설명

데이터셋 설명 샘플
CIFAR-10 사이즈 32*32, RGB
class 개수 10개
각 class당 6,000개의 이미지 
(학습 데이터 5,000 / 테스트 데이터 1,000)
CIFAR-100 사이즈 32*32의 컬러 이미지
class 개수 100개 -> 20개의 수퍼 클래스로 분류
각 class당 600개의 이미지(학습 500개 / 테스트 100개)
전체 60,000개

 
SVHN 0~9로 이루어진 집주소 이미지
32*32 RGB
60만장
STL-10 96*96, RGB
class 10개
학습 데이터 5,000(class당 500개)
테스트 데이터 8,000 (class당 800개)
 

 

결과 요약

데이터셋 설명 참고
CIFAR-10 & SVHN 5folds
최신 결과 달성
CIFAR-100 ReMixMatch가 좀 더 나은 결과 달성.
Distribution Alignment(DA, 모델이 모든 클래스를 동일한 분포로 나오게 하는것)이 원임임을 확인, FixMatch와 DA를 결합하여 40.14%라는 더 낮은 에러 확률을 달성함. 

STL-10 STL-10은 비라벨 데이터가 훨씬 많기 때문에 SSL의 평가에 적합한 데이터.
 CTAugmentation 적용 한 모델이 가장 좋은 성능을 보임

 

728x90
반응형