知识蒸馏在计算机视觉中的应用¶
知识蒸馏是一种将大型复杂模型(教师模型)的知识转移到小型简单模型(学生模型)的技术。为了从一个模型中提取知识并传递给另一个模型,我们首先使用一个预训练的教师模型(例如图像分类任务),然后随机初始化一个学生模型来学习相同的任务。接下来,我们训练学生模型,使其输出与教师模型的输出尽可能接近,从而模拟教师模型的行为。这一技术最早由 Hinton 等人在论文 《神经网络中的知识蒸馏》 中提出。在本指南中,我们将进行特定任务的知识蒸馏。我们将使用 beans 数据集。
本指南展示了如何使用 🤗 Transformers 的 Trainer API 将一个经过微调的 ViT 模型(教师模型)蒸馏到一个 MobileNet(学生模型)。
让我们安装蒸馏和评估过程中需要的库。
pip install transformers datasets accelerate tensorboard evaluate --upgrade
在这个例子中,我们将使用 merve/beans-vit-224 模型作为教师模型。这是一个基于 google/vit-base-patch16-224-in21k 在 beans 数据集上微调的图像分类模型。我们将把这个模型蒸馏到一个随机初始化的 MobileNetV2。
现在我们加载数据集。
from datasets import load_dataset
dataset = load_dataset("beans")
我们可以使用任一模型的图像处理器,因为在这种情况下它们返回相同分辨率的相同输出。我们将使用 dataset 的 map() 方法对数据集的每个分片进行预处理。
from transformers import AutoImageProcessor
teacher_processor = AutoImageProcessor.from_pretrained("merve/beans-vit-224")
def process(examples):
processed_inputs = teacher_processor(examples["image"])
return processed_inputs
processed_datasets = dataset.map(process, batched=True)
我们的目标是让学生模型(随机初始化的 MobileNet)模仿教师模型(微调后的视觉变换器)。为了实现这一点,我们首先获取教师模型和学生模型的 logits 输出。然后,我们将每种输出除以参数 temperature,该参数控制每个软目标的重要性。参数 lambda 用于权衡蒸馏损失的重要性。在本例中,我们将使用 temperature=5 和 lambda=0.5。我们将使用 Kullback-Leibler 散度损失来计算学生模型和教师模型之间的差异。给定两个数据 P 和 Q,KL 散度解释了用 Q 表示 P 所需的额外信息量。如果两者完全相同,它们的 KL 散度为零,因为不需要其他信息来解释 P。因此,在知识蒸馏的背景下,KL 散度是有用的。
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F
class ImageDistilTrainer(Trainer):
def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
super().__init__(model=student_model, *args, **kwargs)
self.teacher = teacher_model
self.student = student_model
self.loss_function = nn.KLDivLoss(reduction="batchmean")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.teacher.to(device)
self.teacher.eval()
self.temperature = temperature
self.lambda_param = lambda_param
def compute_loss(self, student, inputs, return_outputs=False):
student_output = self.student(**inputs)
with torch.no_grad():
teacher_output = self.teacher(**inputs)
# 计算教师和学生的软目标
soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)
# 计算损失
distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)
# 计算真实标签的损失
student_target_loss = student_output.loss
# 计算最终损失
loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
return (loss, student_output) if return_outputs else loss
现在我们登录 Hugging Face Hub,以便通过 Trainer 将模型推送到 Hugging Face Hub。
from huggingface_hub import notebook_login
notebook_login()
设置 TrainingArguments、教师模型和学生模型。
from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification
training_args = TrainingArguments(
output_dir="my-awesome-model",
num_train_epochs=30,
fp16=True,
logging_dir=f"{repo_name}/logs",
logging_strategy="epoch",
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
report_to="tensorboard",
push_to_hub=True,
hub_strategy="every_save",
hub_model_id=repo_name,
)
num_labels = len(processed_datasets["train"].features["labels"].names)
# 初始化模型
teacher_model = AutoModelForImageClassification.from_pretrained(
"merve/beans-vit-224",
num_labels=num_labels,
ignore_mismatched_sizes=True
)
# 从头开始训练 MobileNetV2
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
student_model = MobileNetV2ForImageClassification(student_config)
我们可以使用 compute_metrics 函数在测试集上评估模型。此函数将在训练过程中计算模型的 accuracy 和 f1。
import evaluate
import numpy as np
accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
predictions, labels = eval_pred
acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
return {"accuracy": acc["accuracy"]}
现在我们使用定义的训练参数初始化 Trainer。我们还将初始化数据收集器。
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
student_model=student_model,
teacher_model=teacher_model,
training_args=training_args,
train_dataset=processed_datasets["train"],
eval_dataset=processed_datasets["validation"],
data_collator=data_collator,
processing_class=teacher_processor,
compute_metrics=compute_metrics,
temperature=5,
lambda_param=0.5
)
现在我们可以训练模型。
trainer.train()
我们可以在测试集上评估模型。
trainer.evaluate(processed_datasets["test"])