Keras 模型的保存与载入
Keras使用h5py进行模型的保存和载入
保存
Keras内部Model继承Container
调用Container.save方法进行保存模型
载入
使用如下代码载入模型
1 | model = keras.models.load_model(model_path) |
如果模型存在自定义的函数(有的激活函数Keras没有内置,如atan)和层Layer,则使用CustomObjectScope,具体代码如下:
1 | from keras.utils import CustomObjectScope |
保存自定义的层
一般自定义的层会继承Layer,如果自定义的层存在额外的参数,如下代码:
1 | class MyLayer(layers.Layer): |
在存在额外参数的情况下,如果需要对自定义的层进行保存,则需要重写get_config方法,重写后的代码如下:
1 | def get_config(self): |
这样在模型保存的时候便可以把自定义层的参数保存下来啦,也可以正常载入~