保存和恢复模型

tensorflow-header

可以在训练期间和之后保存模型进度。这意味着模型可以从停止的地方恢复,避免长时间的训练。此外,保存还意味着您可以分享您的模型,其他人可以重现您的工作。在发布研究模型和技术时,大多数机器学习从业者会分享:

  • 用于创建模型的代码
  • 模型的训练权重或形参

共享数据有助于其他人了解模型的工作原理,并使用新数据自行尝试。

小心:TensorFlow 模型是代码,对于不受信任的代码,一定要小心。请参阅 安全使用 TensorFlow 以了解详情。

选项

根据您使用的 API,可以通过不同的方式保存 TensorFlow 模型。本指南使用 tf.keras – 一种用于在 TensorFlow 中构建和训练模型的高级 API。建议使用本教程中使用的新的高级 .keras 格式来保存 Keras 对象,因为它提供了强大、高效的基于名称的保存,通常比低级或旧版格式更容易调试。如需更高级的保存或序列化工作流,尤其是那些涉及自定义对象的工作流,请参阅保存和加载 Keras 模型指南。对于其他方式,请参阅使用 SavedModel 格式指南

配置

安装并导入

安装并导入Tensorflow和依赖项:

获取示例数据集

为了演示如何保存和加载权重,您将使用 MNIST 数据集。为了加快运行速度,请使用前 1000 个样本:

定义模型

首先构建一个简单的序列(sequential)模型:

在训练期间保存模型(以 checkpoints 形式保存)

您可以使用经过训练的模型而无需重新训练,或者在训练过程中断的情况下从离开处继续训练。tf.keras.callbacks.ModelCheckpoint 回调允许您在训练期间结束时持续保存模型。

Checkpoint 回调用法

创建一个只在训练期间保存权重的 tf.keras.callbacks.ModelCheckpoint 回调:

这将创建一个 TensorFlow checkpoint 文件集合,这些文件在每个 epoch 结束时更新:

只要两个模型共享相同的架构,您就可以在它们之间共享权重。因此,当从仅权重恢复模型时,创建一个与原始模型具有相同架构的模型,然后设置其权重。

现在,重新构建一个未经训练的全新模型并基于测试集对其进行评估。未经训练的模型将以机会水平执行(约 10% 的准确率):

然后从 checkpoint 加载权重并重新评估:

checkpoint 回调选项

回调提供了几个选项,为 checkpoint 提供唯一名称并调整 checkpoint 频率。

训练一个新模型,每五个 epochs 保存一次唯一命名的 checkpoint :

现在,检查生成的检查点并选择最新检查点:

注:默认 TensorFlow 格式只保存最近的 5 个检查点。

要进行测试,请重置模型并加载最新检查点:

这些文件是什么?

上述代码可将权重存储到检查点格式文件(仅包含二进制格式训练权重) 的合集中。检查点包含:

  • 一个或多个包含模型权重的分片。
  • 一个索引文件,指示哪些权重存储在哪个分片中。

如果您在一台计算机上训练模型,您将获得一个具有如下后缀的分片:.data-00000-of-00001

手动保存权重

要手动保存权重,请使用 tf.keras.Model.save_weights。默认情况下,tf.keras(尤其是 Model.save_weights 方法)使用扩展名为 .ckpt 的 TensorFlow 检查点格式。要以扩展名为 .h5 的 HDF5 格式保存,请参阅保存和加载模型指南。

保存整个模型

调用 tf.keras.Model.save,将模型的架构、权重和训练配置保存在单个 model.keras zip 存档中。

整个模型可以保存为三种不同的文件格式(新的 .keras 格式和两种旧格式:SavedModel 和 HDF5)。将模型保存为 path/to/model.keras 会自动以最新格式保存。

注意:对于 Keras 对象,建议使用新的高级 .keras 格式进行更丰富的基于名称的保存和重新加载,这样更易于调试。现有代码继续支持低级 SavedModel 格式和旧版 H5 格式。

您可以通过以下方式切换到 SavedModel 格式:

  • 将 save_format='tf' 传递到 save()
  • 传递不带扩展名的文件名

您可以通过以下方式切换到 H5 格式:

  • 将 save_format='h5' 传递到 save()
  • 传递以 .h5 结尾的文件名

Saving a fully-functional model is very useful—you can load them in TensorFlow.js (Saved ModelHDF5) and then train and run them in web browsers, or convert them to run on mobile devices using TensorFlow Lite (Saved ModelHDF5)

*Custom objects (for example, subclassed models or layers) require special attention when saving and loading. Refer to the Saving custom objects section below.

新的高级 .keras 格式

以 .keras 扩展名标记的新 Keras v3 保存格式是一种更简单、更高效的格式,它实现了基于名称的保存,从 Python 的角度确保您加载的内容与您保存的内容完全相同。这使得调试更容易,并且它是 Keras 的推荐格式。

下面的部分说明了如何以 .keras 格式保存和恢复模型。

从 .keras zip 归档重新加载新的 Keras 模型:

尝试使用加载的模型运行评估和预测:

SavedModel 格式

SavedModel 格式是另一种序列化模型的方式。以这种格式保存的模型可以使用 tf.keras.models.load_model 还原,并且与 TensorFlow Serving 兼容。SavedModel 指南详细介绍了如何 serve/inspect SavedModel。以下部分说明了保存和恢复模型的步骤。

SavedModel 格式是一个包含 protobuf 二进制文件和 TensorFlow 检查点的目录。检查保存的模型目录:

从保存的模型重新加载一个新的 Keras 模型:

使用与原始模型相同的实参编译恢复的模型。尝试使用加载的模型运行评估和预测:

HDF5 格式

Keras 使用 HDF5 标准提供基本的旧版高级保存格式。

现在,从该文件重新创建模型:

检查其准确率(accuracy):

Keras 通过检查模型的架构来保存这些模型。这种技术可以保存所有内容:

  • 权重值
  • 模型的架构
  • 模型的训练配置(您传递给 .compile() 方法的内容)
  • 优化器及其状态(如果有)(这样,您便可从中断的地方重新启动训练)

Keras 无法保存 v1.x 优化器(来自 tf.compat.v1.train),因为它们与检查点不兼容。对于 v1.x 优化器,您需要在加载-失去优化器的状态后,重新编译模型。

保存自定义对象

如果您使用的是 SavedModel 格式,则可以跳过此部分。高级 .keras/HDF5 格式与低级 SavedModel 格式之间的主要区别在于 .keras/HDF5 格式使用对象配置来保存模型架构,而 SavedModel 保存执行计算图。因此,SavedModels 能够保存自定义对象,例如子类化模型和自定义层,而无需原始代码。但是,因此调试低级 SavedModels 可能会更加困难,鉴于基于名称并且对于 Keras 是原生的特性,我们建议改用高级 .keras 格式。

要将自定义对象保存到 .keras 和 HDF5,您必须执行以下操作:

  1. 在您的对象中定义一个 get_config 方法,并且可以选择定义一个 from_config 类方法。
    • get_config(self) 返回重新创建对象所需的形参的 JSON 可序列化字典。
    • from_config(cls, config) 使用从 get_config 返回的配置来创建一个新对象。默认情况下,此函数将使用配置作为初始化 kwarg (return cls(**config))。
  2. 通过以下三种方式之一将自定义对象传递给模型:
    • 使用 @tf.keras.utils.register_keras_serializable 装饰器注册自定义对象。(推荐)
    • 加载模型时直接将对象传递给 custom_objects 实参。实参必须是将字符串类名映射到 Python 类的字典。例如 tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
    • 将 tf.keras.utils.custom_object_scope 与 custom_objects 字典实参中包含的对象一起使用,并在作用域内放置一个 tf.keras.models.load_model(path){ /code2} 调用。

有关自定义对象和 get_config 的示例,请参阅从头开始编写层和模型教程。