048创建自定义模型
创建自定义模型¶
🤗 Transformers 库设计得非常易于扩展。每个模型都在仓库的一个子文件夹中完全编码,没有任何抽象,因此你可以轻松地复制一个模型文件并根据需要进行调整。
如果你要编写一个全新的模型,从头开始可能会更容易。在这个教程中,我们将展示如何编写一个自定义模型及其配置,以便在 Transformers 中使用,并说明如何与社区分享这个模型(包括它依赖的代码),即使它不在 🤗 Transformers 库中。我们将基于 Transformers 框架扩展功能,加入你自己的钩子和自定义代码。
为了说明这些步骤,我们将以 ResNet 模型为例,通过包装 timm 库 中的 ResNet 类,使其符合 PreTrainedModel 的规范。
编写自定义配置¶
在深入模型之前,我们先来编写模型的配置。模型的配置对象将包含构建模型所需的所有必要信息。正如我们将在下一节中看到的,模型在初始化时只能接受一个 config 对象,因此我们需要确保这个对象尽可能完整。
transformers 库中的模型通常遵循这样的约定:它们在 __init__ 方法中接受一个 config 对象,然后将整个 config 传递给模型的子层,而不是将 config 对象拆分成多个参数分别传递给子层。以这种方式编写模型可以简化代码,确保任何超参数都有一个明确的“唯一来源”,并且也更容易重用 transformers 中其他模型的代码。
在我们的示例中,我们将取 ResNet 类的一些参数进行调整。不同的配置将给我们带来不同类型的 ResNets。我们检查一些参数的有效性后,将这些参数存储起来。