053导出到 TorchScript
导出到 TorchScript¶
这是我们对 TorchScript 进行实验的起点,我们仍在探索它在处理可变输入大小模型方面的能力。这是我们感兴趣的重点,未来我们将进一步深入分析,并提供更多代码示例、更灵活的实现以及与基于 Python 的代码相比的性能基准测试。
根据 TorchScript 文档:
TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。
有两个 PyTorch 模块,JIT 和 TRACE,允许开发者导出他们的模型以供其他程序(如效率导向的 C++ 程序)重用。
我们提供了一个接口,允许您将 🤗 Transformers 模型导出到 TorchScript,以便在不同的环境中重用这些模型,而不仅仅是基于 PyTorch 的 Python 程序。在这里,我们将解释如何导出和使用这些模型。
导出模型需要两件事:
- 使用
torchscript标志实例化模型 - 使用虚拟输入进行前向传递
这些要求意味着开发者需要注意以下几点。
使用 torchscript 标志和绑定权重¶
torchscript 标志是必需的,因为大多数 🤗 Transformers 语言模型在其 Embedding 层和 Decoding 层之间有绑定权重。TorchScript 不允许导出具有绑定权重的模型,因此需要预先解开并克隆这些权重。
使用 torchscript 标志实例化的模型,其 Embedding 层和 Decoding 层是分开的,这意味着它们不应再进行训练。继续训练会导致这两个层不同步,从而产生意外结果。
对于没有语言模型头的模型,情况并非如此,因为这些模型没有绑定权重。这些模型可以安全地导出,而无需使用 torchscript 标志。
虚拟输入和标准长度¶
虚拟输入用于模型的前向传递。当输入值通过各层传播时,PyTorch 会记录每个张量上执行的不同操作。这些记录的操作用于创建模型的 trace。
trace 是相对于输入维度创建的,因此受虚拟输入维度的限制,对于任何其他序列长度或批量大小都不起作用。尝试使用不同大小的输入时,会引发以下错误: