오늘은 현재는 다양한 연구 방향성에서 사용되고 있는 Knowledge Distillation 기법에 대해서 시작과 함께 알아보도록 하겠습니다. 그러기 위해서 우선 Knowledge Distillation이 처음 제시될 당시의 목적이었던 Model Compression에 대해서 먼저 간단히 살펴보겠습니다.
Model Compression
- 일반적으로 ML/DL에서 모델이 Performance와 Computational Cost 및 Time은 Trade-off관계인 경우가 많습니다. 그래서 아무리 좋은 performance를 가진 모델이더라도 현실적으로 사용하기에 너무 무겁거나, 학습 및 추론 과정에서 시간이 너무 오래 걸린다면 사용되지 않습니다. 이를 위한 다양한 해결 방법들이 존재하지만, 이번에는 모델을 압축하는 방법에 대해서 이야기해보겠습니다.
- 다양한 Model 압축 방법들( Pruning, Weight Factorizatoin, Weight Sharing, Quantization, Knowledge Distillation 등 … )이 있지만, 목표는 유사합니다. 좋은 성능을 내는 큰 모델을 학습시키고, ( 이때 비용이 많이 들고, 비싼 장비로 학습시키고 하더라도 괜찮다. ) 이러한 큰 모델을 작게 축소해서 작은 모형으로 제품을 만들거나 서비스를 제공함으로써 학습단계에서의 시간적, 비용적 문제는 있겠지만, 추론에서는 시간과 비용이 절약된 모델을 얻어 사용하고자 하는 것입니다.
- 다양한 접근 방법들 중 최근에 부상되는 것은 Quantization입니다. 하지만, 이번에는 바로 이전까지 주목받았고 최근도 꾸준히 주목받고 있는 Knowledge Distillation에 대해서 알아보도록 하겠습니다.
What is Knowledge Distillation?
- Knowledge Distillation은 Model Compression Methods중 하나로 Bucila et al., 2006 에서 처음 제시되었습니다. 이를 좀 더 general하게 ML에서 사용하고자 Hinton et al., 2015 에서 재정립한 것이 흔히 언급되는 Knowledge Distillation입니다.
- Knowledge Distillation은 Teacher-student 구조를 통해서 Model Compression의 목적을 달성합니다. 간단히 정리하자면 작고 단순한 Student Model로 크고 복잡한 Teacher Model과 비슷한 혹은 그 이상의 성능을 내도록 하는 것입니다.
- 즉, 좋은 Performance를 위해 학습된 복잡하고 큰 모델을 지식을 전달해줄 수 있는 Teacher로 정의하고 이러한 Teacher Model의 지식을 전달받은 더 작고 간단한 Student Model을 정의함으로써 Teacher Model이 가지는 복잡한 지식들 중에 현재의 Performance를 달성하기 위한 본질 ( 마치 Deep Learning에서 현재 task 해결을 위한 본질적인 정보인 feature를 추출하는 feature extraction처럼 )적인 지식만을 남기고 나머지는 분리시켜서 날린다는 의미에서 지식 증류 기법 Knowledge Distillation이라고 합니다.
How to Knowledge Distillation Work?
- 이러한 Knowledge Distillation이 실제로 어떻게 동작하는지를 살펴봅시다.
- 먼저 1차적으로 Train data는 Teacher Model을 통해서 학습됩니다. 즉, data는 우선 teacher model에 feeding됩니다.
- 2차적으로는 Train Data를 Teacher Model과 Student Model( distilled model )에 모두 feeding합니다. 다시 말해, Pretrained Teacher Model을 Freeze시키고 Data를 태워 Prediction한 결과(Soft label)와 지식이 증류되기를 원하는 Trainable Student Model의 Prediction 결과(Soft label)가 비슷하도록 Loss를 measure해 학습시키고, 동시에 Student Model 또한 기존대로 Ground Truth에 대해서 학습을 시킵니다.
- 즉, 다음과 같이 Objective Function을 정의할 수 있는 것입니다. $$ \mathcal{L}(x ; W)=\alpha * \mathcal{H}\left(y, \sigma\left(z_s ; T=1\right)\right)+\beta * \mathcal{H}\left(\sigma\left(z_t ; T=\tau\right), \sigma\left(z_s, T=\tau\right)\right. $$
- 각 Notation은 다음과 같은 의미를 담고 있습니다. $$ H(): \:Cross \:entropy \:loss $$ $$ \sigma() : Softmax $$ $$ z_s : \:Output \:logits \:of \:Student \:network $$ $$ z_t : \:Output \:logits \:of \:Teacher \:network $$ $$ y : \:Ground\:truth(one-hot) $$ $$ T: \:Temperature\:hyperparameter$$ $$ \alpha, \beta: Coefficients $$
- Objective Functoin의 2가지 Term, Distillation Loss Term 과 Student Model Loss Term 그리고 각 부분 부분을 위 Figure에 대응시켜봅시다.
- 다음과 같이 Teacher Model이 예측한 결과 logits를 Softmax를 통과시켜 Soft label로 사용하고 이를 Student Model이 예측한 결과 logits를 Softmax를 통과시킨 Soft predictions와 Cross entropy loss로 measure함으로써 distillation loss term으로 사용합니다. 동시에 Student Model 또한 GT에 대해서 학습이 같이 이루어져야 하기에 다음과 같이 Softmax를 거치고 maximum score에 해당하는 label까지 뽑은 hard prediction과 GT인 hard label y와 Cross entropy loss로 measure해 각 loss term을 coefficient로 조절해주도록 loss를 정의해 학습을 하도록 하는 구조를 가지고 있습니다.
Why not Hard Label but Soft Label?
- Soft label & Hard label
- Logit값에다가 Softmax function을 적용해 나온 Probability 값들을 Soft label이라고 하고 이러한 score들의 maximume을 취해서 나온 gt와 같은 label을 Hard label이라고 합니다.
- 저자들을 입력 이미지에 대해 label된 class에도 다른 class에 해당하는 특징 및 정보가 존재할 수 있다고 생각했기에 soft label을 활용함으로써 이러한 정보까지 놓치지 않고 사용하고자 했습니다.
- 예를 들어 다음과 같은 강아지 이미지를 Network에 feeding 했을 때 나온 logit을 softmax를 통과해 soft label로 확인하면 다음과 같은 값들이 나오게 됩니다. 이때 해당 이미지에 고양이가 가질 수 있는 특징들( ex- 입 모양 )과 같은 특징들이 있기에 cat class가 0.06 나올 수 있는 것이고, 이 score들인 soft label에는 이러한 정보들이 포함되어 있기에 이러한 정보가 손실되지 않는 soft label을 사용해 distillation loss를 measure하고자 했습니다.
- 물론 이러한 이야기는 추후 연구들에서 다른 방향으로 접근되기도 합니다.
Deep Thinking with KN work
- 이러한 동작 구조를 예시를 통해서 더 자세히 살펴봅시다.
- Teacher Model이 특정 이미지를 3가지 클래스로 각각 [0.80 0.15 0.05]의 score로 예측했다면, Student Model은 이러한 [0.80 0.15 0.05]를 따라가도록 학습시키는 것입니다. 실제로 논문에서 밝히기로는 잘 학습된 모델을 따라가도록 학습시키는 것이 scratch부터 학습시키는 것에 비해 훨씬 쉽다고 합니다. 어쨌든 이렇게 학습시킴으로써 Student Model은 GT에 대해서도 학습을 하지만, 무작정 GT만을 학습하는 것이 아니라 Teacher Model에서 제시하는 결과에 가까워질 수 있도록 학습되기에 학습 양상이 달라지게 됩니다.
- 이러한 학습 전략은 Teacher Model이 예측한 결과뿐 아니라 Teacher Model이 학습한 정보들 ( ex- inductive bias )도 학습을 하도록 하기에 작고 간단한 Student Model을 그냥 GT만으로 학습을 시킬 때 발생할 수 있는 이러한 부가적인 정보들의 손실 없이 Teacher Model과 비슷한 성능을 낼 수 있고 혹은 더 좋은 성능을 낼 수 있도록 학습이 됩니다.