Distilling GPT with PyTorch and Transformers

This tutorial explains how to distill knowledge from a GPT-2 teacher into the smaller DistilGPT-2 student model. The entire pipeline includes:

  1. Setting up a stable training environment.
  2. Loading teacher (gpt2) and student (distilgpt2) models.
  3. Using data augmentation (if desired).
  4. Applying improved preprocessing steps.
  5. Implementing label smoothing.
  6. Using a custom distillation loss function.
  7. Training and evaluating the distilled model.

Throughout the tutorial, you'll see how to train the student model using teacher outputs (knowledge distillation) while also retaining some direct supervision from ground-truth labels (hard loss). This balance is configured with a parameter alpha.

Prerequisites

  • PyTorch (for building and training neural networks)
  • Hugging Face Transformers (for pretrained transformer models)
  • Datasets (to load and process datasets)
  • Basic familiarity with Python and deep learning concepts

Ensure you have installed these packages before proceeding:

pip install torch transformers datasets

Full Code

Below is the complete code for the distillation workflow. You can save it in a file named Distill_GPT.py (or any name you prefer) and run it. Comments in the original code are partly in Chinese, but each step is explained in English in the sections that follow.


Explanation of Key Steps

1. Device Setup

We detect whether a GPU is available using torch.cuda.is_available():

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

This ensures the models train on the GPU when one is available and fall back to the CPU otherwise.

2. Choosing Stable Base Models

We define the teacher and student models:

teacher_model_name = "gpt2"
student_model_name = "distilgpt2"

Feel free to swap in other models (such as GPT-Neo or GPT-J), as long as both teacher and student are causal language models.

3. Loading Models in FP32 & Adjusting Dropout

We load both teacher and student in FP32 for more stable training. We also increase the dropout for the student model to reduce overfitting:

teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_name, 
    device_map="auto", 
    torch_dtype=torch.float32
).eval()

student_model = AutoModelForCausalLM.from_pretrained(
    student_model_name, 
    device_map="auto", 
    torch_dtype=torch.float32
)
# Increase dropout
student_model.config.attn_pdrop = 0.1
...

4. Loading the Tokenizer

We use the same tokenizer as our teacher model for consistency:

tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
tokenizer.pad_token = tokenizer.eos_token

GPT-2 does not define a padding token, so we reuse the eos_token as the pad_token to ensure all sequences are padded consistently.

5. Data Augmentation

Here, you can optionally include text augmentation techniques (synonym replacement, random insertion, etc.). In the current script, the augmentation function simply returns the original text unchanged.
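
If you want to experiment, a minimal sketch of such a hook is shown below; the function name augment_text and the random word-dropping strategy are illustrative assumptions, not part of the original script.

import random

def augment_text(text, p_drop=0.05):
    # Illustrative example: randomly drop a small fraction of words.
    # The tutorial's augmentation function simply returns the original text.
    words = text.split()
    if len(words) < 5:
        return text
    kept = [w for w in words if random.random() > p_drop]
    return " ".join(kept) if kept else text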

6. Improved Data Preprocessing

We perform minimal text cleaning and tokenization:

def tokenize_function(examples):
    ...
    tokenized = tokenizer(
        processed_texts, 
        truncation=True, 
        max_length=128, 
        padding="max_length", 
        return_tensors="pt"
    )
    ...

We remove texts that are too short (less than 10 characters). This helps avoid training with trivial examples.
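
The elided parts of tokenize_function are not shown in the original; a rough sketch, under the assumptions that the dataset exposes a "text" field and that the labels are simply a copy of the input ids, could look like this:

def tokenize_function(examples):
    # Light cleaning: strip whitespace and drop texts shorter than 10 characters
    processed_texts = [t.strip() for t in examples["text"] if len(t.strip()) >= 10]
    if not processed_texts:
        processed_texts = [" "]  # guard against an entirely filtered-out batch
    tokenized = tokenizer(
        processed_texts,
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt",
    )
    # For causal language modeling, the labels are the input ids themselves
    tokenized["labels"] = tokenized["input_ids"].clone()
    return tokenized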

7. Loading & Splitting the OpenWebText Dataset

dataset = load_dataset("openwebtext")
dataset = dataset["train"].train_test_split(test_size=0.1)

The dataset is split into training (90%) and testing (10%). We then map our tokenize_function to tokenize the entire dataset.
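
The mapping call itself is not shown above; assuming tokenize_function is applied with the datasets library's batched map, it would look roughly like this. Passing remove_columns drops the raw "text" column and also keeps map consistent if the filtering inside tokenize_function changes the number of rows in a batch.

tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,  # drop the raw "text" column
)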

8. DataLoader Configuration

We create PyTorch DataLoaders for both training and testing, defining a custom collate_fn that moves tensors to the correct device:

def collate_fn(batch):
    ...
train_loader = DataLoader(..., collate_fn=collate_fn)
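
The body of collate_fn is elided; a minimal version consistent with the description (stacking the tokenized fields into batch tensors and moving them to the device) might look like the following. The batch size of 8 and the variable name tokenized_dataset are assumptions.

from torch.utils.data import DataLoader

def collate_fn(batch):
    # Stack per-example fields into batch tensors and move them to the target device
    input_ids = torch.stack([torch.tensor(item["input_ids"]) for item in batch]).to(device)
    attention_mask = torch.stack([torch.tensor(item["attention_mask"]) for item in batch]).to(device)
    labels = torch.stack([torch.tensor(item["labels"]) for item in batch]).to(device)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

train_loader = DataLoader(tokenized_dataset["train"], batch_size=8, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(tokenized_dataset["test"], batch_size=8, collate_fn=collate_fn)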

9. Label Smoothing

A custom class LabelSmoothingCrossEntropy is implemented to mitigate overconfidence in predictions. We replace the typical cross-entropy loss with label smoothing:

class LabelSmoothingCrossEntropy(nn.Module):
    ...
label_smoothing_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.1)
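
The class body is not shown in the original; one standard way to implement label-smoothing cross-entropy, matching smoothing=0.1 above, is sketched here (padding positions are not masked, for simplicity):

from torch import nn
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, targets):
        # logits: (N, vocab_size), targets: (N,)
        log_probs = F.log_softmax(logits, dim=-1)
        nll_loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(1)).squeeze(1)
        smooth_loss = -log_probs.mean(dim=-1)  # uniform distribution over the vocabulary
        loss = (1.0 - self.smoothing) * nll_loss + self.smoothing * smooth_loss
        return loss.mean()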

10. Improved Distillation Loss Function

The distillation_loss function calculates:

  1. Hard Loss: computed via label smoothing (label_smoothing_loss_fn).
  2. Soft Loss: a Kullback-Leibler divergence between teacher and student logits, scaled by a temperature.

def distillation_loss(student_outputs, teacher_outputs, labels, temperature=3.0, alpha=0.7):
    ...
    return alpha * hard_loss + (1 - alpha) * soft_loss

The parameter alpha balances between hard and soft losses.
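
The elided body could be filled in as follows; this sketch assumes the standard formulation (hard loss on shifted tokens, KL divergence on temperature-scaled distributions multiplied by temperature squared), and the exact details in the original script may differ:

def distillation_loss(student_outputs, teacher_outputs, labels, temperature=3.0, alpha=0.7):
    student_logits = student_outputs.logits
    teacher_logits = teacher_outputs.logits

    # Hard loss: label-smoothed cross-entropy against ground-truth tokens,
    # shifted so that position t predicts token t+1 (causal LM objective)
    shift_logits = student_logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    hard_loss = label_smoothing_loss_fn(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )

    # Soft loss: KL divergence between temperature-scaled teacher and student
    # distributions, scaled by temperature**2 as in standard distillation
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)

    return alpha * hard_loss + (1 - alpha) * soft_loss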

11. Training Loop

We define train_student to run through multiple epochs of training:

def train_student(student_model, teacher_model, train_loader, epochs=3, lr=3e-6, output_dir="distilled_model"):
    ...

During each batch (see the sketch after this list):

  1. The teacher model produces logits for the batch (the teacher's weights are never updated).
  2. The student model produces logits for the same batch.
  3. We compute the distillation loss and backpropagate.
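
Putting these steps together, the core of train_student might look roughly like this; the AdamW optimizer, the per-epoch logging, and saving with save_pretrained are assumptions, since the original body is not shown:

import os

def train_student(student_model, teacher_model, train_loader, epochs=3, lr=3e-6, output_dir="distilled_model"):
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=lr)
    student_model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in train_loader:
            with torch.no_grad():  # the teacher is frozen; no gradients needed
                teacher_outputs = teacher_model(
                    input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
                )
            student_outputs = student_model(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )
            loss = distillation_loss(student_outputs, teacher_outputs, batch["labels"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}: average loss {total_loss / len(train_loader):.4f}")
    os.makedirs(output_dir, exist_ok=True)
    student_model.save_pretrained(output_dir)
    return student_model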

12. Model Evaluation

We generate text from both teacher and student models on the test set using:

model.generate(...)

We collect a few samples to compare their outputs.
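
For example, a side-by-side comparison on a single prompt could look like this (the prompt and generation settings are illustrative):

prompt = "The future of artificial intelligence"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

teacher_text = tokenizer.decode(
    teacher_model.generate(**inputs, max_new_tokens=50, do_sample=False, pad_token_id=tokenizer.eos_token_id)[0],
    skip_special_tokens=True,
)
student_text = tokenizer.decode(
    student_model.generate(**inputs, max_new_tokens=50, do_sample=False, pad_token_id=tokenizer.eos_token_id)[0],
    skip_special_tokens=True,
)
print("Teacher:", teacher_text)
print("Student:", student_text)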

13. Running the Training

Here, we call train_student(...) with the desired parameters:

trained_model = train_student(
    student_model,
    teacher_model,
    train_loader,
    epochs=2, 
    lr=3e-6,
    output_dir=output_model_dir
)

Adjust epochs or learning rate as needed.

14. Saving the Final Model

We save the student model's final state dictionary:

torch.save(trained_model.state_dict(), os.path.join(output_model_dir, "final_distilled_model.pt"))

15-16. Comparison and Saving Outputs

We generate a few samples from both teacher and student and write them to a text file to visually compare performance.
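
A minimal sketch of this step (the prompts and the file name comparison_samples.txt are illustrative choices) could be:

sample_prompts = ["The weather today is", "In recent news,"]

with open(os.path.join(output_model_dir, "comparison_samples.txt"), "w", encoding="utf-8") as f:
    for prompt in sample_prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        teacher_out = tokenizer.decode(
            teacher_model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)[0],
            skip_special_tokens=True,
        )
        student_out = tokenizer.decode(
            trained_model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)[0],
            skip_special_tokens=True,
        )
        f.write(f"PROMPT: {prompt}\nTEACHER: {teacher_out}\nSTUDENT: {student_out}\n\n")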

Conclusion

You have now completed a step-by-step guide for distilling a GPT model (GPT2) into a smaller student (DistilGPT2). This approach balances teacher-derived knowledge (soft labels) with the original ground truth (hard labels) to produce a more compact yet capable model.

Feel free to modify:

  • The data augmentation logic.
  • The temperature and alpha parameters in the distillation loss.
  • The dropout rates.
  • Other hyperparameters (batch size, learning rate, epochs, etc.).

Experiment with these settings to find the best trade-off between performance and model size. Good luck distilling!