Distilling the Knowledge in a Neural Network

SoonChang
Oct 24, 2021

In this post, we will go through “Distilling the Knowledge in a Neural Network” by Hinton et al. In this paper, the authors look for a way to transfer the knowledge learned by a large neural network to a smaller neural network without significantly compromising performance.

There is a trend in Deep Learning for researchers to train ever larger models in order to achieve and publish state-of-the-art results. However, these trained models are often too large to deploy in practice on different devices. The authors refer to such large models as cumbersome models. A cumbersome model can be either an ensemble of many networks whose outputs are averaged to form the final prediction, or a single very large model. Knowledge distillation transfers the knowledge obtained by a cumbersome model to a smaller model in order to save computation and memory. This smaller model is known as the distilled model.

What is knowledge?

The knowledge learned by a neural network is usually associated with the trained parameters of the network. In this work, the authors offer an alternative perspective: the knowledge learned by a neural network can be viewed as a learned mapping from the inputs of the network to its outputs. What does this mean? A neural network classifier usually converts the output of its final layer (a vector called the logits) into class probabilities using a softmax function. The resulting prediction vector lists the probability the network assigns to each class, and these probabilities sum to one. Figure 1 shows a made-up example of such a prediction. The prediction vector carries knowledge beyond the top class: the relative probabilities of the wrong classes show, for example, that a dog image is more likely to be misclassified as a wolf or a cat than as a fish.

Figure 1: Predicted class probabilities for a dog image (artificial)
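To make the logits-to-probabilities step concrete, here is a minimal Python sketch of a standard softmax. The class names and logit values are hypothetical, chosen only to mimic the kind of prediction shown in Figure 1; they are not taken from the paper.

```python
import math

def softmax(logits):
    """Convert a list of logits into class probabilities that sum to one."""
    # Subtracting the largest logit avoids overflow and does not change the result.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for a dog image over the classes (dog, wolf, cat, fish).
classes = ["dog", "wolf", "cat", "fish"]
logits = [6.0, 3.0, 2.5, -2.0]
probs = softmax(logits)
for name, p in zip(classes, probs):
    print(f"{name}: {p:.4f}")
```

The wrong classes still get ordered probabilities (wolf and cat well above fish), which is the hidden knowledge the paragraph refers to.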

The authors found that by using the class probabilities predicted by a cumbersome model as learning targets, it is possible to train a smaller and simpler network without a significant loss in accuracy. These learning targets (the predictions of the cumbersome model) are called soft targets, whereas the actual labels are called hard targets.
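One way to see why soft targets carry more information is to compare a cross entropy against each kind of target. In this sketch (all probabilities are made up for illustration), two hypothetical student predictions put the same probability on the correct class but disagree about the wrong classes. Against the one-hot hard target they score identically; against the soft target, the student that mimics the cumbersome model's ranking of the wrong classes scores better.

```python
import math

def cross_entropy(target, predicted):
    """H(target, predicted) = -sum_i target_i * log(predicted_i); zero-weight terms are skipped."""
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)

# Classes: (dog, wolf, cat, fish). All numbers below are hypothetical.
hard_target = [1.0, 0.0, 0.0, 0.0]          # one-hot actual label
soft_target = [0.90, 0.06, 0.035, 0.005]    # a cumbersome model's prediction

# Same probability on "dog", different beliefs about the wrong classes.
student_a = [0.90, 0.06, 0.035, 0.005]      # wolf/cat more likely than fish
student_b = [0.90, 0.005, 0.035, 0.06]      # fish more likely than wolf

for name, s in [("A", student_a), ("B", student_b)]:
    print(name,
          "hard:", round(cross_entropy(hard_target, s), 4),
          "soft:", round(cross_entropy(soft_target, s), 4))
```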

Softmax with temperature

The authors propose a softmax with temperature T: q_i = exp(z_i / T) / Σ_j exp(z_j / T), where z_i is the logit for class i and q_i is the resulting probability. The standard softmax corresponds to T = 1. Its output is “sharp”: the probabilities of all classes other than the top one are pushed very close to zero. As a result, learning from these sharp predictions as soft targets does not differ much from learning from the actual targets. Increasing T makes the softmax produce a softer probability distribution over the classes.

Figure 2: Softmax with temperature
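A minimal Python sketch of the softmax with temperature defined above. The logits are the same hypothetical (dog, wolf, cat, fish) values used for illustration earlier, not values from the paper.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """q_i = exp(z_i / T) / sum_j exp(z_j / T); T = 1 gives the standard softmax."""
    scaled = [z / T for z in logits]
    m = max(scaled)                      # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for (dog, wolf, cat, fish).
logits = [6.0, 3.0, 2.5, -2.0]
for T in (1, 2, 5, 20):
    q = softmax_with_temperature(logits, T)
    print(f"T={T:>2}: " + ", ".join(f"{p:.3f}" for p in q))
```

Raising T never changes the ranking of the classes; it only flattens the distribution, so the small probabilities become large enough to matter in a loss. As T grows without bound the output approaches a uniform distribution.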

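A better way to train the distilled model is to use both the actual targets and the soft targets. The loss function is a weighted average of two cross entropies. The first is the cross entropy with the soft targets, computed with the same high temperature T in the softmax of both the cumbersome model and the distilled model. The second is the cross entropy with the actual targets, computed from the same logits of the distilled model but with T = 1.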
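The combined loss can be sketched in Python for a single example as follows. The weight `alpha` and its default of 0.9 are assumptions for illustration (the paper reports that a considerably lower weight on the hard-target term generally worked best, but this exact value is not from the post). The factor `T ** 2` on the soft term also comes from the paper rather than this post: the gradients of the soft term scale as 1/T², so the paper multiplies by T² to keep the two terms balanced. The function name `distillation_loss` is hypothetical.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, predicted):
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)

def distillation_loss(student_logits, teacher_logits, label, T=20.0, alpha=0.9):
    """Weighted average of the soft-target and hard-target cross entropies.

    alpha weights the soft-target term; 0.9 is an assumed, illustrative value.
    """
    # Soft term: teacher and student both use the same high temperature T.
    soft_targets = softmax_with_temperature(teacher_logits, T)
    student_soft = softmax_with_temperature(student_logits, T)
    # Multiply by T**2 so the soft term's gradients keep a comparable scale.
    soft_loss = cross_entropy(soft_targets, student_soft) * T ** 2
    # Hard term: same student logits, standard softmax (T = 1), one-hot label.
    hard_target = [1.0 if i == label else 0.0 for i in range(len(student_logits))]
    hard_loss = cross_entropy(hard_target, softmax_with_temperature(student_logits, 1.0))
    return alpha * soft_loss + (1 - alpha) * hard_loss

teacher = [6.0, 3.0, 2.5, -2.0]   # hypothetical cumbersome-model logits
student = [2.0, 1.0, 0.5, 0.0]    # hypothetical distilled-model logits
print(distillation_loss(student, teacher, label=0))
```

In a real training loop this would be averaged over a batch and minimized with respect to the student's parameters only; the teacher's logits are treated as fixed.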

Experiments on MNIST

MNIST is a dataset of 28×28 grayscale images of handwritten digits (0–9). For the big net, the training images are augmented by jittering them by up to 2 pixels in any direction. The following network architectures are used:

  1. Cumbersome model / Big net: A single large neural network with 2 hidden layers of 1200 rectified linear units regularized using dropout.
  2. Small net: A single neural network with 2 hidden layers of 800 rectified linear units.

We can calculate the number of parameters of each net (counting weights only and ignoring biases):

Big net: 28×28×1200 + 1200×1200 + 1200×10 = 2392800

Small net: 28×28×800 + 800×800 + 800×10 = 1275200

Counting weights only, the small net is about 53.29% the size of the big net.
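The arithmetic above can be checked with a few lines of Python. The helper `dense_params` is hypothetical; it counts the weights between consecutive fully connected layers and, optionally, one bias per non-input unit.

```python
def dense_params(layer_sizes, include_biases=False):
    """Count parameters of a fully connected net given its layer widths."""
    weights = sum(a * b for a, b in zip(layer_sizes, layer_sizes[1:]))
    biases = sum(layer_sizes[1:]) if include_biases else 0
    return weights + biases

big = [28 * 28, 1200, 1200, 10]    # input pixels, 2 hidden layers, 10 classes
small = [28 * 28, 800, 800, 10]

print(dense_params(big), dense_params(small))            # 2392800 1275200
print(f"{dense_params(small) / dense_params(big):.2%}")  # 53.29%
print(dense_params(big, True), dense_params(small, True))  # 2395210 1276810
```

Including biases adds only a few thousand parameters to each net, so the size ratio barely changes.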

The results on the test set are as follows:

  • Big net: 67 test errors
  • Small net without regularization: 146 test errors
  • Small net trained with soft targets (T=20): 74 test errors

This experiment shows that the small net trained directly on the training set with the actual targets performs much worse (146 errors) than the same small net trained with soft targets (74 errors), which comes close to the big net (67 errors).
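Since the standard MNIST test set contains 10,000 images, the error counts translate directly into error rates. This short Python sketch does that conversion and also computes what fraction of the gap between the small and big nets is closed by training on soft targets (a derived figure, not one reported in the post).

```python
TEST_SET_SIZE = 10_000  # size of the standard MNIST test set

results = {
    "Big net": 67,
    "Small net, no regularization": 146,
    "Small net, soft targets (T=20)": 74,
}
for name, errors in results.items():
    print(f"{name}: {errors / TEST_SET_SIZE:.2%} error")

# Fraction of the small-vs-big error gap closed by distillation.
gap_closed = (146 - 74) / (146 - 67)
print(f"gap closed: {gap_closed:.0%}")  # 91%
```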

The authors also ran an experiment in which all examples of the digit 3 were omitted from the distilled model's training set, so the distilled model never saw a 3. Even so, after increasing the learned bias for class 3 by 3.5, the model classifies 98.6% of the 3s in the test set correctly. This means that, even without seeing a 3 during training, the distilled net learns a lot about 3s from the soft targets, because the big net's predictions on other digits still assign some probability to class 3.
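Why does a single bias shift help so much? The bias of an output unit is added to that class's logit for every input, so raising it by 3.5 multiplies the unnormalized probability exp(z_3) by exp(3.5), roughly 33. The logits below are hypothetical values for one test image of a 3, chosen so that class 8 wins before the shift; they are not from the paper.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits (classes 0-9) of a distilled net for an image of a 3.
# The class-3 bias learned without any 3s is too low, so class 8 wins.
logits = [0.0, -1.0, 1.0, 4.0, -2.0, 2.0, -1.0, 0.5, 6.0, 1.5]
print("before:", max(range(10), key=lambda k: logits[k]))  # 8

# Raising the class-3 bias adds the same constant to its logit for every input.
BIAS_SHIFT = 3.5
shifted = list(logits)
shifted[3] += BIAS_SHIFT
print("after:", max(range(10), key=lambda k: shifted[k]))  # 3
print("p(3):", round(softmax(logits)[3], 3), "->", round(softmax(shifted)[3], 3))
```

The shift only works because the distilled net already ranks class 3 fairly high on images of 3s; the soft targets taught it that much.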

Experiments on speech recognition

The authors also conducted experiments on speech recognition. In this experiment, the authors show that distilling an ensemble of models into one model works significantly better than a model of the same size trained directly using the same training set.

The deep neural network (DNN) maps features computed from the speech waveform to a probability distribution over the discrete states of a Hidden Markov Model (HMM). The decoder then finds a path through the HMM states that is the best compromise between using high-probability states and producing a transcription that is plausible under the language model. Don't worry if this part is unclear; you can think of the DNN as a classifier that labels each short frame of audio with probabilities over HMM states.

Network architecture of the DNN: 8 hidden layers of 2560 rectified linear units each, followed by a final softmax layer with 14,000 labels.

Data: 2000 hours of English speech data, about 700M training examples.

Table 1: Frame classification accuracy and WER

Table 1, taken from the paper, shows the performance of the baseline model, the ensemble and the distilled single model. WER is the word error rate of the predicted transcriptions.

Baseline: One model trained using training data directly

10xEnsemble: 10 separate models (same network architecture as baseline)

Distilled single model: One model trained using 10xEnsemble predictions as soft targets
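The soft targets for the distilled single model come from combining the ensemble members' predictions. This Python sketch assumes an arithmetic mean of the members' tempered probability distributions (the paper mentions an arithmetic or geometric mean as options). The three members, their logits and T = 2.0 are hypothetical, kept small for readability; the paper's ensemble has 10 members and 14,000 classes.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_soft_targets(member_logits, T=1.0):
    """Arithmetic mean of each ensemble member's tempered class distribution."""
    dists = [softmax_with_temperature(z, T) for z in member_logits]
    n = len(dists)
    return [sum(d[i] for d in dists) / n for i in range(len(dists[0]))]

# Hypothetical logits from 3 ensemble members for one input over 4 classes.
members = [
    [4.0, 2.0, 0.0, -1.0],
    [3.5, 2.5, 0.5, -1.0],
    [4.5, 1.5, 1.0, -0.5],
]
targets = ensemble_soft_targets(members, T=2.0)
print([round(p, 3) for p in targets])
```

Averaging probabilities (rather than logits) keeps each target a valid distribution, and the distilled model is then trained against these targets exactly as with a single cumbersome model.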

The results in Table 1 show that distillation extracts more useful information from the training set than simply training a single model on the actual targets: the distilled model outperforms the baseline single model in both test frame accuracy and WER.

Reference:

[1] G. Hinton, O. Vinyals and J. Dean, Distilling the Knowledge in a Neural Network [https://arxiv.org/abs/1503.02531]


SoonChang

PhD in AI. Interested in Computer Vision, Deep Learning and Network Pruning