Distilling GPT with PyTorch and Transformers
This tutorial explains how to distill GPT2 into a smaller distilgpt2
model. The entire pipeline includes:
- Setting up a stable training environment.
- Loading teacher (`gpt2`) and student (`distilgpt2`) models.
- Using data augmentation (if desired).
- Applying improved preprocessing steps.
- Implementing label smoothing.
- Using a custom distillation loss function.
- Training and evaluating the distilled model.
Throughout the tutorial, you'll see how to train the student model using teacher outputs (knowledge distillation) while also retaining some direct supervision from ground-truth labels (hard loss). This balance is configured with a parameter `alpha`.
Prerequisites
- PyTorch (for building and training neural networks)
- Hugging Face Transformers (for pretrained transformer models)
- Datasets (to load and process datasets)
- Basic familiarity with Python and deep learning concepts
Ensure you have installed these packages before proceeding:
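For example, with pip:

```bash
pip install torch transformers datasets
```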
Full Code
Below is the complete code for the distillation workflow. You can save it in a file named `Distill_GPT.py` (or any name you prefer) and run it. Comments in the code are partly in Chinese, but each section is explained in English below the code blocks.
Explanation of Key Steps
1. Device Setup
We detect whether a GPU is available using `torch.cuda.is_available()`:
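A minimal sketch of this step (the `device` variable is reused by the later snippets):

```python
import torch

# Train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```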
This ensures our model will train on GPU if available, otherwise on CPU.
2. Choosing Stable Base Models
We define the teacher and student models:
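For example (the variable names are placeholders used by the later snippets):

```python
# Teacher: full-size GPT-2; student: the smaller DistilGPT-2 checkpoint
teacher_model_name = "gpt2"
student_model_name = "distilgpt2"
```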
Feel free to switch them to other models (like GPT-Neo or GPT-J) if you wish, as long as they are causal language models.
3. Loading Models in FP32 & Adjusting Dropout
We load both teacher and student in FP32 for more stable training. We also increase the dropout for the student model to reduce overfitting:
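A sketch of this step, assuming the model names defined above; the dropout values are illustrative:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Teacher: loaded in FP32 and kept in eval mode (it is only used for inference)
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_name, torch_dtype=torch.float32
).to(device)
teacher_model.eval()

# Student: raise dropout to reduce overfitting during distillation
student_config = AutoConfig.from_pretrained(student_model_name)
student_config.resid_pdrop = 0.2
student_config.embd_pdrop = 0.2
student_config.attn_pdrop = 0.2
student_model = AutoModelForCausalLM.from_pretrained(
    student_model_name, config=student_config, torch_dtype=torch.float32
).to(device)
```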
4. Loading the Tokenizer
We use the same tokenizer as our teacher model for consistency:
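For example:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
# GPT-2 has no dedicated padding token, so reuse the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token
```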
We set `pad_token` to `eos_token` to ensure all sequences are padded consistently.
5. Data Augmentation
Here, you can optionally include text augmentation techniques (synonym replacements, random insertions, etc.). Currently, this function returns the original text.
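A placeholder sketch (the function name `augment_text` is an assumption here):

```python
def augment_text(text: str) -> str:
    # Hook for synonym replacement, random insertion, back-translation, etc.
    # For now the text is returned unchanged.
    return text
```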
6. Improved Data Preprocessing
We perform minimal text cleaning and tokenization:
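A possible version, assuming the `augment_text` placeholder above; the maximum sequence length is illustrative:

```python
def tokenize_function(examples):
    # Strip whitespace, optionally augment, and drop texts shorter than 10 characters
    texts = [augment_text(t.strip()) for t in examples["text"]]
    texts = [t for t in texts if len(t) >= 10]
    return tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=128,  # illustrative sequence length
    )
```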
We remove texts that are too short (fewer than 10 characters). This helps avoid training on trivial examples.
7. Loading & Splitting the OpenWebText Dataset
The dataset is split into training (90%) and testing (10%). We then map our `tokenize_function` to tokenize the entire dataset.
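A sketch of this step; the dataset slice and loading arguments are illustrative and may need adjusting for your `datasets` version:

```python
from datasets import load_dataset

# A small slice keeps the example manageable; use the full split for real training
raw_dataset = load_dataset("openwebtext", split="train[:1%]")
splits = raw_dataset.train_test_split(test_size=0.1, seed=42)

tokenized_train = splits["train"].map(
    tokenize_function, batched=True, remove_columns=splits["train"].column_names
)
tokenized_test = splits["test"].map(
    tokenize_function, batched=True, remove_columns=splits["test"].column_names
)
```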
8. DataLoader Configuration
We create PyTorch DataLoaders for both training and testing, defining a custom `collate_fn` that moves tensors to the correct device:
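One possible implementation (the batch size is illustrative):

```python
from torch.utils.data import DataLoader

def collate_fn(batch):
    # Stack the tokenized fields into tensors and move them to the training device
    input_ids = torch.tensor([item["input_ids"] for item in batch], device=device)
    attention_mask = torch.tensor([item["attention_mask"] for item in batch], device=device)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

train_loader = DataLoader(tokenized_train, batch_size=8, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(tokenized_test, batch_size=8, shuffle=False, collate_fn=collate_fn)
```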
9. Label Smoothing
A custom class `LabelSmoothingCrossEntropy` is implemented to mitigate overconfidence in predictions. We replace the typical cross-entropy loss with label smoothing:
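A standard label-smoothing implementation might look like this:

```python
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy that spreads a small amount of probability mass over all classes."""

    def __init__(self, smoothing: float = 0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=-1)
        # Negative log-likelihood of the true class
        nll = -log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
        # Uniform component over all classes
        smooth = -log_probs.mean(dim=-1)
        loss = (1.0 - self.smoothing) * nll + self.smoothing * smooth
        return loss.mean()
```

Note that recent versions of PyTorch also accept a `label_smoothing` argument directly in `nn.CrossEntropyLoss`, which is an alternative to a custom class.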
10. Improved Distillation Loss Function
The `distillation_loss` function calculates:
- Hard Loss: computed via label smoothing (`label_smoothing_loss_fn`).
- Soft Loss: a Kullback-Leibler divergence between teacher and student logits, scaled by a `temperature`.

The parameter `alpha` balances the hard and soft losses.
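A sketch of such a loss, reusing the `LabelSmoothingCrossEntropy` class from the previous step; the default `temperature` and `alpha` values are illustrative:

```python
label_smoothing_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.1)

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=2.0, alpha=0.5):
    vocab_size = student_logits.size(-1)

    # Hard loss: label-smoothed cross-entropy against the ground-truth tokens
    hard_loss = label_smoothing_loss_fn(
        student_logits.reshape(-1, vocab_size), labels.reshape(-1)
    )

    # Soft loss: KL divergence between temperature-scaled distributions,
    # rescaled by T^2 as in the standard distillation formulation
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)

    # alpha balances direct supervision (hard) against teacher knowledge (soft)
    return alpha * hard_loss + (1.0 - alpha) * soft_loss
```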
11. Training Loop
We define `train_student` to run through multiple epochs of training:
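A simplified version of such a loop might look like this (the optimizer choice and hyperparameter defaults are assumptions):

```python
from torch.optim import AdamW

def train_student(student_model, teacher_model, train_loader,
                  epochs=3, lr=5e-5, temperature=2.0, alpha=0.5):
    optimizer = AdamW(student_model.parameters(), lr=lr)
    student_model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        for batch in train_loader:
            # Teacher logits are computed without gradient tracking
            with torch.no_grad():
                teacher_logits = teacher_model(**batch).logits

            student_logits = student_model(**batch).logits

            # Causal LM objective: predict token t+1 from tokens up to t
            loss = distillation_loss(
                student_logits[:, :-1, :],
                teacher_logits[:, :-1, :],
                batch["input_ids"][:, 1:],
                temperature=temperature,
                alpha=alpha,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch + 1}: average loss = {total_loss / len(train_loader):.4f}")
```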
During each batch:
- The teacher model generates logits.
- The student model generates logits.
- We compute the distillation loss and backpropagate.
12. Model Evaluation
We generate text from both teacher and student models on the test set using:
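For example, a small helper built on `model.generate` (the sampling settings and prompt are illustrative):

```python
def generate_sample(model, prompt, max_new_tokens=50):
    # Sampled generation; the settings here are illustrative, not tuned
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        temperature=0.8,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

prompt = "The future of artificial intelligence"
print("Teacher:", generate_sample(teacher_model, prompt))
print("Student:", generate_sample(student_model, prompt))
```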
We collect a few samples to compare their outputs.
13. Running the Training
Here, we call `train_student(...)` with the desired parameters:
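For example (all values shown are placeholders to tune):

```python
train_student(
    student_model,
    teacher_model,
    train_loader,
    epochs=3,          # placeholder values; tune for your hardware and data size
    lr=5e-5,
    temperature=2.0,
    alpha=0.5,
)
```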
Adjust epochs or learning rate as needed.
14. Saving the Final Model
We save the student model's final state dictionary:
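For example (the output paths are placeholders):

```python
# Plain state dict
torch.save(student_model.state_dict(), "distilled_student_gpt2.pt")

# Hugging Face format, so the model can be reloaded later with from_pretrained()
student_model.save_pretrained("distilled-student-gpt2")
tokenizer.save_pretrained("distilled-student-gpt2")
```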
15-16. Comparison and Saving Outputs
We generate a few samples from both teacher and student and write them to a text file to visually compare performance.
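A sketch of this step, reusing the `generate_sample` helper from step 12; the prompts and output file name are illustrative:

```python
prompts = [
    "Once upon a time",
    "The key idea behind knowledge distillation is",
]

# Write teacher and student completions side by side for manual comparison
with open("comparison_outputs.txt", "w", encoding="utf-8") as f:
    for prompt in prompts:
        f.write(f"PROMPT:  {prompt}\n")
        f.write(f"TEACHER: {generate_sample(teacher_model, prompt)}\n")
        f.write(f"STUDENT: {generate_sample(student_model, prompt)}\n\n")
```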
Conclusion
You have now completed a step-by-step guide for distilling a GPT model (GPT2) into a smaller student (DistilGPT2). This approach balances teacher-derived knowledge (soft labels) with the original ground truth (hard labels) to produce a more compact yet capable model.
Feel free to modify:
- The data augmentation logic.
- The temperature and alpha parameters in the distillation loss.
- The dropout rates.
- Other hyperparameters (batch size, learning rate, epochs, etc.).
Experiment with these settings to find the best trade-off between performance and model size. Good luck distilling!