TensorFlow中的模型保存文件
文件构成
由TensorFlow保存的训练模型文件由四个文件组成:
1 | . |
每个文件的内容为:
文件 | 描述 |
---|---|
checkpoint | 指示文件夹中多个不同训练结果的属性,即如果在训练过程中保存了多次相同模型,在checkpoint文件中会保留每次保存的对应文件名 |
data | 保存模型的中参数的值 |
index | 保存模型中参数的名称和维度,相当于将模型中的参数名称和参数值关联起来 |
meta | 保存计算图 |
模型保存与加载
通过以下语句可以实现当前模型的保存:
1 | init = tf.global_variables_initializer() |
需要注意的是,模型保存必须在调用Session()
以及进行模型初始化之后进行,因为计算图是保存在会话中的,save()
函数必须传入相应的会话才能获取需要保存哪些参数;而TensorFlow要求不能保存未初始化的对象,在声明部分实际只对参数进行了定义,但没有run()
就没有赋值,没有实际计算也并不占用内存,所以理论上来讲没有初始化的变量是没有值的。
通过以下语句可以实现模型的加载:
1 | saver = tf.train.Saver() |
模型加载过程可以理解成“查字典”,即以当前需要恢复的变量去模型文件中检索,然后恢复同名项。其中模型文件中多余的变量会被忽略不计,也不会报错。
模型参数使用
TensorFlow在模型中要求每一个变量都有唯一的全局名称,这在一个模型自身的保存于恢复中是方便的,但当我们需要将一个模型中的参数运用到其他模型中时处理起来会相对麻烦些。这里列举一些方便的处理方式。
参数文件检查
想要检查 一个模型中有哪些参数,使用
1 | tf.contrib.framework.list_variables('dir_to_model') |
这个函数返回一个tuple的list,包含所有参数的名称和shape
1 | [('discriminator/discriminator_unit/disblock_6/conv_1/kernel', [3, 3, 256, 512])] |
而通过
1 | var = tf.contrib.framework.load_variable("dir_to_model", var_name) |
可以完成模型中某个参数的加载,此处只有加载,并没有对参数进行定义,即可以通过var去定义模型要使用的其他参数。
参数名称的重映射
TensorFlow中允许将旧的名称重映射到新的名称以便重新加载,使用语句
1 | saver = tf.train.Saver(var_list={'old_name': new_var}) |
可以将救名称与新变量进行关联,因此维护一个字典即可实现模型的迁移加载。但这样做的麻烦之处也是显而易见的——必须列举所有要恢复的变量才能进行正确的参数加载工作。
参数名称的修改
这里放一份代码,可以批量修改模型中参数的名字,这在处理模型加载过程中的参数重名问题时尤为有效。
tensorflow_rename_variables.py
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Flymin's Blog!