My Understanding of Knowledge Distilling

Introduction

Knowledge distilling (more commonly called knowledge distillation) is a widely used technique in machine learning for compressing a large model into a smaller one.

The large model, called the teacher model, is usually a deep neural network with many parameters. The smaller model, called the student model, is a shallower network with fewer parameters. The motivation behind knowledge distilling is to transfer the knowledge learned by the teacher to the student, so that the student can reach performance close to the teacher's while being far more computationally efficient.

The traditional way to train a neural network is to minimize a loss function between the predicted output and the ground-truth labels. In knowledge distilling, we additionally minimize the difference between the outputs of the teacher model and the student model; this extra term is called the distillation loss. With it, the student learns not only from the ground-truth (hard) labels but also from the soft labels produced by the teacher model.
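As a concrete illustration, the combined objective can be sketched in plain NumPy. The function names, the example logits, and the weighting factor `alpha` are illustrative choices, not from the original post; the `T**2` scaling of the soft term keeps its gradient magnitude comparable to the hard term as the temperature grows.

```python
import numpy as np

def softmax(logits, T=1.0):
    """Temperature-scaled softmax: higher T yields a softer distribution."""
    scaled = logits / T
    e = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    return e / e.sum()

def distillation_loss(student_logits, teacher_logits, true_label, T=4.0, alpha=0.5):
    """Weighted sum of the hard-label loss and the distillation loss.

    alpha weighs the cross-entropy against the ground-truth label;
    (1 - alpha) weighs the cross-entropy against the teacher's soft labels.
    """
    q = softmax(student_logits, T)        # student's soft predictions
    p = softmax(teacher_logits, T)        # teacher's soft labels
    soft_loss = -np.sum(p * np.log(q))    # cross-entropy vs. soft labels
    hard_loss = -np.log(softmax(student_logits)[true_label])
    return alpha * hard_loss + (1 - alpha) * T**2 * soft_loss

student = np.array([2.0, 0.5, -1.0])      # illustrative logits
teacher = np.array([1.8, 0.7, -0.9])
loss = distillation_loss(student, teacher, true_label=0)
```

In practice the same combination is usually written with a framework's built-in cross-entropy and KL-divergence losses; the NumPy version above only makes the arithmetic explicit.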

Distilling Loss Function

Annotations

  • $C$ : the cross-entropy loss
  • $T$ : the temperature
  • $q$ : softened output probabilities of the student model (softmax of $z/T$)
  • $p$ : softened output probabilities of the teacher model (softmax of $v/T$)
  • $z$ : logits (output vector before softmax) of the student model
  • $v$ : logits of the teacher model
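To make the role of the temperature $T$ concrete, here is a minimal temperature-scaled softmax (the example logits are illustrative): raising $T$ flattens the distribution, which is exactly what produces the "soft" labels.

```python
import numpy as np

def softmax(logits, T=1.0):
    """Softmax with temperature T; T > 1 softens the distribution."""
    scaled = logits / T
    e = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    return e / e.sum()

logits = np.array([3.0, 1.0, 0.2])
hard = softmax(logits, T=1.0)  # peaked distribution
soft = softmax(logits, T=5.0)  # flatter distribution, same ranking
```

Both outputs are valid probability distributions; the high-temperature one carries more information about the relative similarity of the non-top classes.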

Gradient of the Distilling Loss

$$\frac{\partial C}{\partial z_{i}} = \frac{1}{T}(q_{i} - p_{i}) = \frac{1}{T}\left(\frac{e^{z_{i}/T}}{\sum_{j=1}^{N} e^{z_{j}/T}} - \frac{e^{v_{i}/T}}{\sum_{j=1}^{N} e^{v_{j}/T}}\right) \quad \text{for } i = 1, 2, \ldots, N$$
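This gradient can be verified numerically: taking $C = -\sum_{i} p_{i} \log q_{i}$ (cross-entropy between the teacher's soft labels and the student's soft predictions), the analytic form $(q_{i} - p_{i})/T$ should agree with a finite-difference estimate. The logits below are arbitrary example values.

```python
import numpy as np

def softmax(x, T):
    e = np.exp(x / T - np.max(x / T))
    return e / e.sum()

def cross_entropy(z, v, T):
    """C = -sum_i p_i log q_i, with p from teacher logits v, q from student logits z."""
    p = softmax(v, T)
    q = softmax(z, T)
    return -np.sum(p * np.log(q))

T = 3.0
z = np.array([1.0, -0.5, 0.2, 0.8])   # student logits (arbitrary example)
v = np.array([1.2, -0.3, 0.1, 0.6])   # teacher logits

# Analytic gradient from the formula above: (q_i - p_i) / T
analytic = (softmax(z, T) - softmax(v, T)) / T

# Central finite-difference estimate of dC/dz_i
eps = 1e-6
numeric = np.zeros_like(z)
for i in range(len(z)):
    zp, zm = z.copy(), z.copy()
    zp[i] += eps
    zm[i] -= eps
    numeric[i] = (cross_entropy(zp, v, T) - cross_entropy(zm, v, T)) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)
```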

When $\sum_{j=1}^{N} v_{j} = 0$ and $\sum_{j=1}^{N} z_{j} = 0$ (i.e., the logits are zero-mean) and the temperature $T$ is large enough, we can use the first-order approximation $e^{x} \approx 1 + x$ for small $x$, so the gradient of the distillation loss simplifies to:

$$\frac{\partial C}{\partial z_{i}} \approx \frac{1}{T}\left(\frac{1 + z_{i}/T}{N + \sum_{j=1}^{N} z_{j}/T} - \frac{1 + v_{i}/T}{N + \sum_{j=1}^{N} v_{j}/T}\right) = \frac{z_{i} - v_{i}}{NT^{2}} \quad \text{for } i = 1, 2, \ldots, N$$
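In this high-temperature limit, distillation therefore behaves like directly matching the student's logits to the teacher's, since $(z_{i} - v_{i})/(NT^{2})$ is the gradient of $\frac{1}{2NT^{2}}\sum_{i}(z_{i} - v_{i})^{2}$. The approximation can be checked numerically with zero-mean example logits (the values below are illustrative):

```python
import numpy as np

def softmax(x, T):
    e = np.exp(x / T - np.max(x / T))
    return e / e.sum()

# Zero-mean logits, as the approximation above assumes
z = np.array([0.9, -0.4, -0.5])   # student logits, sum to 0
v = np.array([0.7, 0.1, -0.8])    # teacher logits, sum to 0
N = len(z)
T = 100.0                         # "big enough" temperature

exact = (softmax(z, T) - softmax(v, T)) / T   # exact gradient (q_i - p_i) / T
approx = (z - v) / (N * T**2)                 # high-temperature approximation

# At high temperature the two agree closely; the relative error shrinks as T grows.
assert np.allclose(exact, approx, rtol=2e-2)
```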

Published Mar 8, 2016

Flying code monkey