In TensorFlow, we can use saver.save() to save a model and saver.restore() to load it. How do we restore a TensorFlow model correctly? In this tutorial, we will show you how.
An example of a TensorFlow model
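As an illustration, here is a minimal sketch that builds a toy graph and saves it with saver.save(), producing a checkpoint that saver.restore() can load. It assumes TensorFlow 1.15 or 2.x, imported as tensorflow.compat.v1 so that the same TF1-style calls used in this tutorial work. The tensor names w, b, x and y are hypothetical.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style API; available in TF 1.15 and 2.x


def save_example_model(model_dir):
    """Build a tiny graph (hypothetical names w, b, x, y) and save it with saver.save()."""
    g = tf.Graph()
    with g.as_default():
        w = tf.Variable([[1.0, 2.0], [3.0, 4.0]], name="w")
        b = tf.Variable([0.5, -0.5], name="b")
        x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
        tf.add(tf.matmul(x, w), b, name="y")
        saver = tf.train.Saver()
        with tf.Session(graph=g) as sess:
            sess.run(tf.global_variables_initializer())
            # Writes model.ckpt-1.index / .data-* (variable values),
            # model.ckpt-1.meta (graph structure) and a "checkpoint" file.
            return saver.save(sess, os.path.join(model_dir, "model.ckpt"), global_step=1)


model_dir = tempfile.mkdtemp()
ckpt_path = save_example_model(model_dir)
print(ckpt_path)
```

saver.save() returns the checkpoint prefix (ending in model.ckpt-1 here). The .meta file next to it is what tf.train.import_meta_graph() reads when the model is loaded.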
How to load a TensorFlow model using saver.restore()?
Loading a TensorFlow model is easy. Here is an example:
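```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.ConfigProto)
import seed_util  # project-specific helper module, not part of TensorFlow

model_path = "./model_finetune/pretrained_model/"

g = tf.Graph()
g.seed = 1  # graph-level random seed
with g.as_default():
    seed_util.set_global_determinism()
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False
    )
    # Allocate GPU memory on demand instead of reserving it all up front
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config, graph=g)
    with sess.as_default():
        # load model: find the newest checkpoint prefix in model_path
        checkpoint_file = tf.train.latest_checkpoint(model_path)
        # Rebuild the graph structure from the .meta file
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        # Copy the saved variable values into the session
        saver.restore(sess, checkpoint_file)
```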
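The snippet above depends on a project directory and a seed_util helper, so it cannot run on its own. The sketch below follows the same latest_checkpoint / import_meta_graph / restore sequence end to end: it first saves a toy model into a temporary directory, then loads it into a fresh graph and fetches tensors by name with get_tensor_by_name(). The names x, w and y are hypothetical, and TensorFlow 1.15 or 2.x (via tensorflow.compat.v1) is assumed.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style API; available in TF 1.15 and 2.x


def save_tiny_model(model_dir):
    """Save a toy model (hypothetical names x, w, y) so there is something to load."""
    g = tf.Graph()
    with g.as_default():
        w = tf.Variable([[2.0], [3.0]], name="w")
        x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
        tf.matmul(x, w, name="y")
        saver = tf.train.Saver()
        with tf.Session(graph=g) as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, os.path.join(model_dir, "model.ckpt"))


def load_and_predict(model_dir, inputs):
    """Restore the latest checkpoint in model_dir and run the saved "y" tensor."""
    g = tf.Graph()
    with g.as_default():
        with tf.Session(graph=g) as sess:
            checkpoint_file = tf.train.latest_checkpoint(model_dir)
            # Rebuild the saved graph structure from the .meta file...
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            # ...then copy the saved variable values into this session.
            saver.restore(sess, checkpoint_file)
            # Tensors are looked up as "<op name>:<output index>".
            x = g.get_tensor_by_name("x:0")
            y = g.get_tensor_by_name("y:0")
            return sess.run(y, feed_dict={x: inputs})


model_dir = tempfile.mkdtemp()
save_tiny_model(model_dir)
print(load_and_predict(model_dir, np.array([[1.0, 1.0]], dtype=np.float32)))  # [[5.]]
```

No global_variables_initializer() is called in load_and_predict(): saver.restore() alone gives w its saved value.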
If tf.train.latest_checkpoint() raises a FindFirstFile failed error, you can read this solution:
Fix TensorFlow tf.train.latest_checkpoint() FindFirstFile failed Error – TensorFlow Tutorial
Note that we often call sess.run(tf.global_variables_initializer()) to initialize all global variables in TensorFlow. If you use saver.restore() to load a model, you must call saver.restore() after sess.run(tf.global_variables_initializer()).
This is because saver.restore() assigns the saved values to the model's variables in your session. If you restore a model first and then call sess.run(tf.global_variables_initializer()), the initializer runs again for every variable, and the loaded values are replaced by each variable's initial value (for example, random values or zeros). As a result, your loaded model will not work correctly.
You should:
1. sess.run(tf.global_variables_initializer())
2. Load model using saver.restore()
You should not:
1. Load model using saver.restore()
2. sess.run(tf.global_variables_initializer())
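To see why the order matters, the following sketch saves a checkpoint in which a variable v holds [1, 2, 3], then restores it into a fresh graph whose initializer sets v to zeros, once in each order. The name v is hypothetical, and TensorFlow 1.15 or 2.x (via tensorflow.compat.v1) is assumed.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style API; available in TF 1.15 and 2.x


def make_graph():
    """Graph with one variable whose initial value (zeros) differs from the saved one."""
    g = tf.Graph()
    with g.as_default():
        v = tf.Variable(tf.zeros([3]), name="v")  # hypothetical variable
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
    return g, v, saver, init


def restore_value(ckpt, init_first):
    """Restore ckpt into a fresh graph, running the initializer before or after."""
    g, v, saver, init = make_graph()
    with tf.Session(graph=g) as sess:
        if init_first:
            sess.run(init)             # 1. initialize
            saver.restore(sess, ckpt)  # 2. restore: the saved values win
        else:
            saver.restore(sess, ckpt)  # 1. restore
            sess.run(init)             # 2. initialize: overwrites the restored values
        return sess.run(v).tolist()


# Save a checkpoint in which v = [1, 2, 3].
g, v, saver, init = make_graph()
with g.as_default(), tf.Session(graph=g) as sess:
    sess.run(init)
    sess.run(v.assign([1.0, 2.0, 3.0]))
    ckpt = saver.save(sess, os.path.join(tempfile.mkdtemp(), "model.ckpt"))

print(restore_value(ckpt, init_first=True))   # [1.0, 2.0, 3.0]
print(restore_value(ckpt, init_first=False))  # [0.0, 0.0, 0.0]
```

Only the first order keeps the saved values; in the second, the initializer silently resets them.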
If you only want to load a model and do not create any new variables, you do not need to call sess.run(tf.global_variables_initializer()).
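One way to check this is tf.report_uninitialized_variables(), which returns the names of variables that have no value yet in the current session. In this sketch (hypothetical variable w, TensorFlow 1.15 or 2.x via tensorflow.compat.v1 assumed), that list is non-empty before saver.restore() and empty afterwards, without any call to the initializer.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style API; available in TF 1.15 and 2.x


def build_graph():
    g = tf.Graph()
    with g.as_default():
        w = tf.Variable([1.0, 2.0], name="w")  # hypothetical variable
        saver = tf.train.Saver()
        # String tensor: names of variables that have no value in the session yet.
        uninitialized = tf.report_uninitialized_variables()
    return g, w, saver, uninitialized


# Save a checkpoint first so there is something to load.
g, w, saver, _ = build_graph()
with g.as_default(), tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(w.assign([5.0, 6.0]))  # save values that differ from the initial ones
    ckpt = saver.save(sess, os.path.join(tempfile.mkdtemp(), "model.ckpt"))

# Load into a fresh graph: restore only, no global_variables_initializer().
g2, w2, saver2, uninitialized2 = build_graph()
with tf.Session(graph=g2) as sess:
    before = sess.run(uninitialized2)  # holds the name of w: nothing restored yet
    saver2.restore(sess, ckpt)
    after = sess.run(uninitialized2)   # empty: every variable now has a value
    value = sess.run(w2).tolist()

print(len(before), len(after), value)  # 1 0 [5.0, 6.0]
```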