084使用 TensorFlow 在 TPU 上进行训练
使用 TensorFlow 在 TPU 上进行训练¶
如果你不想阅读长篇大论,只想直接查看 TPU 的代码示例,可以查看我们的 TPU 示例笔记本!
什么是 TPU?¶
TPU 是 张量处理单元(Tensor Processing Unit)。它们是由谷歌设计的硬件,用于大幅加速神经网络中的张量计算,类似于 GPU。TPU 可以用于网络训练和推理。它们通常通过谷歌的云服务访问,但在 Google Colab 和 Kaggle Kernels 中也可以免费访问小型 TPU。
因为 所有 🤗 Transformers 中的 TensorFlow 模型都是 Keras 模型,本文档中的大多数方法也适用于任何 Keras 模型的 TPU 训练!不过,有些内容是特定于 HuggingFace 的 Transformers 和 Datasets 生态系统的,我们会在需要时特别指出这些内容。
可以使用哪些类型的 TPU?¶
新用户经常对不同类型的 TPU 和访问方式感到困惑。首先需要了解的关键区别是 TPU 节点(TPU Node) 和 TPU 虚拟机(TPU VM) 之间的区别。
当你使用 TPU 节点 时,实际上是间接访问远程 TPU。你需要一个单独的虚拟机来初始化网络和数据管道,然后将它们转发到远程节点。在 Google Colab 中使用 TPU 时,实际上是在使用 TPU 节点。
使用 TPU 节点可能会有一些令人意外的行为,特别是对于不熟悉它的人!具体来说,由于 TPU 位于与运行 Python 代码的机器物理上不同的系统中,因此你的数据不能保存在本地机器上——从本地机器内部存储加载数据的数据管道将完全失败!相反,数据必须存储在 Google Cloud Storage 中,以便数据管道在远程 TPU 节点上运行时仍能访问这些数据。
如果你可以将所有数据存储在内存中作为 np.ndarray 或 tf.Tensor,则可以在使用 Colab 或 TPU 节点时直接调用 fit() 而不需要上传数据到 Google Cloud Storage。
🤗 Hugging Face 小贴士 🤗: 方法 Dataset.to_tf_dataset() 及其高级包装器 model.prepare_tf_dataset() 在 TPU 节点上会失败。原因是虽然它们创建了一个 tf.data.Dataset,但它不是一个“纯” tf.data 管道,而是使用 tf.numpy_function 或 Dataset.from_generator() 从底层的 HuggingFace Dataset 流式传输数据。这个 HuggingFace Dataset 存储在本地磁盘上,远程 TPU 节点无法读取。
第二种访问 TPU 的方式是通过 TPU 虚拟机。使用 TPU 虚拟机时,你可以直接连接到 TPU 附带的机器,就像在 GPU 虚拟机上训练一样。TPU 虚拟机通常更容易使用,尤其是在处理数据管道方面。上述所有警告都不适用于 TPU 虚拟机!
这是一个有偏见的文档,所以这里是我们的一点意见:如果可能,尽量避免使用 TPU 节点。 它比 TPU 虚拟机更令人困惑且更难调试。未来 TPU 节点也很可能不再被支持——谷歌最新的 TPU,TPUv4,只能作为 TPU 虚拟机访问,这表明 TPU 节点越来越成为一个“遗留”访问方法。然而,我们理解唯一免费的 TPU 访问是在 Colab 和 Kaggle Kernels 上,这些平台使用 TPU 节点——所以我们将会解释如何处理这个问题!请查看 TPU 示例笔记本 获取更详细的代码示例。
有哪些大小的 TPU 可供使用?¶
单个 TPU(如 v2-8/v3-8/v4-8)运行 8 个副本。TPU 以 吊舱(pod) 形式存在,可以同时运行数百或数千个副本。当使用的 TPU 数量超过单个 TPU 但少于整个吊舱时(例如 v3-32),你的 TPU 集群称为 吊舱切片(pod slice)。
通过 Colab 免费访问 TPU 时,通常会获得一个单个 v2-8 TPU。
我经常听说 XLA 这个东西。什么是 XLA?它与 TPU 有何关系?¶
XLA 是一种优化编译器,由 TensorFlow 和 JAX 使用。在 JAX 中它是唯一的编译器,而在 TensorFlow 中是可选的(但在 TPU 上是强制的)。最简单的方式来启用 XLA 是在调用 model.compile() 时传递参数 jit_compile=True。如果你没有遇到错误且性能良好,那就说明你已经准备好在 TPU 上运行了!
在 TPU 上调试通常比在 CPU/GPU 上更困难,因此我们建议先在 CPU/GPU 上使用 XLA 运行代码,然后再尝试在 TPU 上运行。当然,你不必训练很长时间——只需跑几步以确保模型和数据管道按预期工作即可。
XLA 编译后的代码通常更快——即使你不打算在 TPU 上运行,添加 jit_compile=True 也可以提高性能。不过,请注意以下关于 XLA 兼容性的注意事项!
经验之谈: 尽管使用 jit_compile=True 是提高速度并测试 CPU/GPU 代码是否 XLA 兼容的好方法,但实际在 TPU 上训练时将其保留可能会导致很多问题。XLA 编译将在 TPU 上隐式进行,因此记得在实际运行代码时删除这一行!
如何使我的模型 XLA 兼容?¶
在许多情况下,你的代码可能已经是 XLA 兼容的!但是,有些在普通 TensorFlow 中有效的方法在 XLA 中无效。我们将其归纳为以下三条核心规则:
🤗 Hugging Face 小贴士 🤗: 我们投入了很多努力重写了 TensorFlow 模型和损失函数,使其 XLA 兼容。我们的模型和损失函数默认遵循规则 #1 和 #2,因此如果你使用的是 transformers 模型,可以跳过这两条规则。但在编写自己的模型和损失函数时不要忘记这些规则!
XLA 规则 #1:代码中不能有“数据依赖的条件语句”¶
这意味着任何 if 语句都不能依赖于 tf.Tensor 内部的值。例如,以下代码块无法用 XLA 编译: