TinyBERT:模型小7倍,速度快8倍,华中科大、华为出品( 五 )

研究者还执行了嵌入层的蒸馏,与基于隐状态的蒸馏类似,定义如下:

其中矩阵 E^S 和 H^T 分别表示 student 和 teacher 网络的嵌入。在论文中,这两种矩阵的形状与隐状态矩阵相同。矩阵 W_e 表示线性变化,它起到与 W_h 类似的作用。

除了模拟中间层的行为之外,研究者还利用知识蒸馏来拟合 teacher 模型的预测结果。具体而言,他们对 student 网络 logits 和 teacher 网络 logits 之间的 soft 交叉熵损失进行惩罚:

其中 z^S 和 z^T 分别表示 student 和 teacher 模型预测的 logits 向量,log_softmax() 表示 log 似然,t 表示温度值。实验表明,t=1 时运行良好。

通过以上几个蒸馏目标函数(即方程式 7、8、9 和 10),可以整合 teacher 和 student 网络之间对应层的蒸馏损失:

在实验中,研究者首先执行的是中间层蒸馏(M ≥ m ≥ 0),其次是预测层蒸馏(m = M + 1)。

两段式学习框架

BERT 的应用通常包含两个学习阶段:预训练和微调。BERT 在预训练阶段学到的大量知识非常重要,并且迁移的时候也应该包含在内。因此,研究者提出了一个两段式学习框架,包含通用蒸馏和特定于任务的蒸馏,如下图 2 所示:

推荐阅读