085使用 PyTorch 在 Apple 芯片上训练模型
使用 PyTorch 在 Apple 芯片上训练模型¶
之前,在 Mac 上训练模型只能使用 CPU。随着 PyTorch v1.12 的发布,你现在可以利用 Apple 的硅基 GPU 进行模型训练,从而获得显著更快的性能和训练速度。这是通过在 PyTorch 中集成 Apple 的 Metal Performance Shaders (MPS) 后端实现的。MPS 后端 将 PyTorch 操作实现为自定义 Metal 着色器,并将这些模块放置在 mps 设备上。
目前,一些 PyTorch 操作尚未在 MPS 中实现,可能会抛出错误。为了避免这种情况,你可以设置环境变量 PYTORCH_ENABLE_MPS_FALLBACK=1,这样当遇到不支持的操作时会自动回退到 CPU 内核(你仍然会看到一个 UserWarning)。
如果你遇到其他错误,请在 PyTorch 仓库中提交问题,因为 Trainer 仅集成了 MPS 后端。
设置 mps 设备后,你可以:
- 在本地训练更大的网络或批量大小
- 减少数据检索延迟,因为 GPU 的统一内存架构允许直接访问完整的内存存储
- 节省成本,因为不需要在基于云的 GPU 或添加本地 GPU 上进行训练
首先确保你已经安装了 PyTorch。MPS 加速支持 macOS 12.3+。