Tinyworker的技术小站

学而不思则罔,思而不学则殆,温故而知新。
知不弃行,行不离思,慎思之,笃行之。

Follow me on GitHub

保存与恢复

模型的训练进度可以在期间和之后保存,意味着可以在上次暂停的地方继续训练,这就避免了一次性训练时间过长的问题。同时,保存也意味着可以分享。

训练期间保存检查点

keras提供了用于保存模型训练节点的功能,使用callbacks.ModelCheckpoint来执行回调。

回调函数样例如下:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

model = create_model()

model.fit(train_images, train_labels,  epochs = 10,
          validation_data = (test_images,test_labels),
          callbacks = [cp_callback])  # pass callback to training

可以看出,在使用检查点时,需要配置保存文件路径,是否仅保存权重,以及存储模式,最后在fit方法中配置callbacks参数即可。

如果想要加载最新检查点,如下:

latest = tf.train.latest_checkpoint(checkpoint_dir)
model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

如果想手动的保存,也是可以的,调用Model.save_weights方法:

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

保存整个模型

模型可以完整的保存,包含其权重值,模型配置甚至优化器配置,这样模型的训练可以无需访问原始代码。keras对模型是生成一个HDF5标准的格式文件,可以视为一个二进制blob。

model = create_model()

model.fit(train_images, train_labels, epochs=5)

# Save entire model to a HDF5 file
model.save('my_model.h5')

new_model = keras.models.load_model('my_model.h5')
new_model.summary()

目前keras无法保存TensorFlow优化器(来自tf.train),在用这一类优化器时,需要加载模型后对其进行重新编译,使优化器状态变松散。

文件信息

keras的保存是将每次训练后的权重信息存储在检查点格式文件中,检查点包括:模型权重的一个或多个分片,指示哪些权重存储在哪些分片中的索引文件。

后续进阶

上述是tf.keras保存和加载模型的基本操作。

  • tf.keras的指南中详细介绍了如何使用该api保存和加载模型。
  • Eager Execution也提供了保存模型操作。
  • 低阶API也具备此功能