Reduce inference time by distilling your model

In this tutorial, we will show a method to improve the inference time (the time it takes to predict new data) for an already trained model. We will achieve this via model distillation.

In distillation, an existing, larger model (the teacher model) is used to train a new, smaller model (the student model). A student trained this way generally achieves better results than the same student trained directly on the available dataset. In some sense, the teacher model uses its broader understanding of language to guide the student in solving the task.
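To make the idea of "guiding" the student more concrete, here is a minimal sketch of a common soft-target distillation loss: the student is penalized for diverging from the teacher's temperature-softened class probabilities. This is a generic textbook formulation in plain Python, not necessarily what autonlu does internally; the function names, the temperature value, and the example logits are all assumptions made for illustration.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL divergence KL(teacher || student) between the softened distributions."""
    p = softmax(teacher_logits, temperature)  # teacher's soft targets
    q = softmax(student_logits, temperature)  # student's current prediction
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))

# Made-up logits for the classes (positive, neutral, negative).
teacher_logits = [4.0, 1.0, -2.0]
close_student = [3.5, 1.2, -1.5]   # roughly agrees with the teacher
far_student = [-2.0, 1.0, 4.0]     # disagrees with the teacher

print(distillation_loss(close_student, teacher_logits))
print(distillation_loss(far_student, teacher_logits))
```

The loss is zero when the student reproduces the teacher's distribution exactly and grows as the two diverge. Because it only needs the teacher's outputs, it can be computed on unlabeled text as well.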

Distillation is usually done with the help of a transfer dataset in addition to the normal dataset used for training. The transfer dataset does not have to be labeled, but should ideally be as close to the target domain as possible. For example, when the goal is to distill a model that classifies customer reviews of mobile phones from an online shop, reviews are a better transfer dataset than general English text. Reviews of mobile phones are better still, and ideally the transfer dataset consists of mobile phone reviews from the specific online shop the model will be used for (if enough data is available). Transfer datasets of several hundred megabytes are common.
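One rough way to get a feeling for how "close" a candidate transfer dataset is to the training data is to compare their vocabularies. The sketch below computes the Jaccard overlap of the word sets; this is a crude heuristic chosen purely for illustration, not a method prescribed by this tutorial, and the helper names and example texts are hypothetical.

```python
import re

def vocabulary(texts):
    """Lower-cased set of word tokens appearing in the texts."""
    words = set()
    for text in texts:
        words.update(re.findall(r"[a-z0-9']+", text.lower()))
    return words

def vocabulary_overlap(train_texts, transfer_texts):
    """Jaccard similarity of the two vocabularies (0 = disjoint, 1 = identical)."""
    a, b = vocabulary(train_texts), vocabulary(transfer_texts)
    if not a and not b:
        return 0.0
    return len(a & b) / len(a | b)

# Made-up examples: training data, an in-domain and an out-of-domain candidate.
train = ["The battery of this phone lasts two days", "Screen cracked after a week"]
phone_reviews = ["Great phone, the battery lasts long", "The screen is bright"]
news = ["Parliament voted on the budget today", "Stocks fell sharply"]

print(vocabulary_overlap(train, phone_reviews))
print(vocabulary_overlap(train, news))
```

A higher overlap suggests, but does not prove, that a candidate is closer to the target domain. On real data, the effect of a transfer dataset on validation accuracy is the more reliable signal.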

Be aware that distillation can take many hours, or even multiple days, depending on the size of the transfer dataset, the teacher, and the student model. This method is therefore best suited for models that will be put into day-to-day production use, where a lot of data has to be processed.

Teachers and Students

As a general rule of thumb, the teacher should be the best and biggest model one can train, and the student should be the smallest model which still gives acceptable results. Even though the student will perform better on the task than if it were simply trained directly, some degradation in prediction quality is usually unavoidable.

For this tutorial, we will use the previously trained bert-base-uncased model, which was trained for the second tutorial (“Train a model to label reviews from the GooglePlay store”), as a teacher. Generally, a bigger model than that would be used as a teacher (e.g. a roberta-base or roberta-large) but since we already trained this teacher in a previous tutorial, we will re-use it.

As a student, we will use an albert-base-v2#cnn model. CNN models were introduced specifically for distillation purposes and are much faster (a speedup of 20x, depending on the hardware and the dataset). The albert-base-v2 part before the #cnn specifies that the token embeddings of albert-base-v2 should be used for the CNN model.


%load_ext autoreload
%autoreload 2
%load_ext tensorboard

!pip install gdown -q
!pip install pandas -q
!pip install google-play-scraper -q

import autonlu
from autonlu import Model
import pandas as pd
import numpy as np
import gdown
import pickle
from http.client import RemoteDisconnected
from time import sleep

from google_play_scraper import reviews
from tqdm import tqdm

We will continue with the model from tutorial 02 (training a label task) as a teacher. You need to have run that tutorial first so that the model is available for loading.

teacher = Model("googleplay_labeling")
Model googleplay_labeling loaded from local path successfully.

Getting labeled data

gdown.download("…", ".cache/data/googleplay/")
gdown.download("…", ".cache/data/googleplay/")

df = pd.read_csv(".cache/data/googleplay/reviews.csv")

def to_label(score):
    return "negative" if score <= 2 else \
           "neutral" if score == 3 else "positive"

X = [x for x in df.content]
Y = [to_label(score) for score in df.score]

#X, Y, valX, valY = autonlu.split_dataset(X, Y, split_at=0.1)
To: /home/paethon/git/autonlu/tutorials/.cache/data/googleplay/apps.csv
100%|██████████| 134k/134k [00:00<00:00, 2.61MB/s]
To: /home/paethon/git/autonlu/tutorials/.cache/data/googleplay/reviews.csv
7.17MB [00:00, 22.5MB/s]
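As a quick sanity check of the score-to-label mapping used above, the following sketch applies to_label to a few made-up star ratings and counts the resulting labels. The sample_scores list is hypothetical; on the real data the same count can be taken over Y to see how balanced the classes are.

```python
from collections import Counter

def to_label(score):
    # Same mapping as in the tutorial: 1-2 stars negative, 3 neutral, 4-5 positive
    return "negative" if score <= 2 else \
           "neutral" if score == 3 else "positive"

# Hypothetical star ratings, standing in for df.score
sample_scores = [1, 2, 3, 4, 5, 5]
labels = [to_label(score) for score in sample_scores]

# Count how often each label occurs
counts = Counter(labels)
print(counts)
```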

Getting unlabeled data

Let’s start by downloading unlabeled data that can be used as a transfer dataset. This is rather slow, so for demonstration purposes, the transfer dataset will be quite small. In practice, much larger transfer datasets should be used if possible.


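transfer_dataset = []

for app in tqdm(apps):
    try:
        # Fetch up to 50000 English reviews for this app
        result, continuation_token = reviews(app, count=50000, lang="en")
    except RemoteDisconnected:
        print("Remote Disconnected. Sleeping and continuing to next app")
        sleep(10)  # pause length is an assumption; the original value is not shown
        continue
    for r in result:
        # Keep only reviews that have text and are longer than 30 characters
        if r["content"] is not None and len(r["content"]) > 30:
            transfer_dataset.append(r["content"])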
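Since the connection can drop while scraping, it can be worth retrying a failed request a few times before moving on to the next app. The helper below is a minimal sketch of that idea; fetch_with_retries, its parameters, and the simulated flaky function are hypothetical names introduced here for illustration and are not part of autonlu or google_play_scraper.

```python
from http.client import RemoteDisconnected
from time import sleep

def fetch_with_retries(fetch, retries=3, wait_seconds=5.0, pause=sleep):
    """Call fetch() and retry on RemoteDisconnected; return None if all attempts fail."""
    for attempt in range(retries):
        try:
            return fetch()
        except RemoteDisconnected:
            if attempt < retries - 1:
                # Wait a little longer after each failed attempt
                pause(wait_seconds * (attempt + 1))
    return None

# Simulated request that fails twice before succeeding
calls = {"n": 0}

def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RemoteDisconnected("simulated drop")
    return ["review text"]

# Record the waits instead of actually sleeping, to keep the demo fast
waits = []
result = fetch_with_retries(flaky, pause=waits.append)
print(result, waits)
```

In the download loop, the call to reviews could be wrapped in a lambda and passed as fetch, with a None result meaning that the app should be skipped.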

To be able to follow the training, we will start a tensorboard instance

%tensorboard --logdir tensorboard_logs

Distill into a CNN model with albert-base-v2 token embeddings

One note: depending on the transfer dataset used, model accuracy can actually go down when the transfer dataset is added. This usually means that the transfer data is not close enough to the training dataset to be of much use. Depending on the downloaded reviews, this might be the case in this tutorial, and the rather small transfer dataset makes it more likely. This does not negatively impact the end state of the distilled model; it just means that in this case the transfer dataset is not improving the model further and could be left out to speed up training. Ideally, a better transfer dataset should be used.

student = teacher.distill("albert-base-v2#cnn", X=X, Y=Y, unlabelledX=transfer_dataset)
student.save("googleplay_labeling_distill_cnn")
Model albert-base-v2 loaded from Huggingface successfully.
/home/paethon/git/py39env/lib/python3.9/site-packages/torch/cuda/ FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved

Test the resulting model

During training, one can generally achieve ~86% accuracy with the new CNN model.

ret = student.predict([
    "This app is really cool and helpful.",
    "The app is quite ok, some things could be improved.",
    "The app does not work at all."])
['positive', 'neutral', 'negative']
%timeit student.predict(X)
556 ms ± 4.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit teacher.predict(X)
7.18 s ± 27.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

As you can see, we created a model that is almost 13x faster, with an accuracy of ~86% compared to the 87.62% of the original teacher model!
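The speedup figure follows directly from the two %timeit measurements above. A short sketch of the arithmetic, with the mean timings copied from the outputs and variable names chosen here for illustration:

```python
# Mean wall-clock times reported by %timeit above, in seconds
teacher_seconds = 7.18
student_seconds = 0.556

# How many times faster the student predicts the same inputs
speedup = teacher_seconds / student_seconds
print(f"speedup: {speedup:.1f}x")

# Share of the teacher's prediction time that is saved
time_saved = 1 - student_seconds / teacher_seconds
print(f"time saved: {time_saved:.0%}")
```

These numbers depend on the hardware and the dataset, so they should be measured again on the target system.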