Steps to Load TensorFlow Model Using saver.restore() Correctly – TensorFlow Tutorial

By | September 1, 2021

In TensorFlow, we can use saver.save() to save a model and saver.restore() to load one. How do we restore a TensorFlow model correctly? In this tutorial, we will show you how.

An example of tensorflow model


How to load a TensorFlow model using saver.restore()?

It is easy to load a TensorFlow model. Here is an example:

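import tensorflow as tf  # TensorFlow 1.x API

model_path = "./model_finetune/pretrained_model/"
g = tf.Graph()
g.seed = 1  # graph-level random seed
with g.as_default():
    # seed_util is a project-specific helper, not part of TensorFlow;
    # it is not shown in this tutorial
    seed_util.set_global_determinism()
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False
    )
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config, graph=g)

    with sess.as_default():
        # find the newest checkpoint in model_path (None if there is none)
        checkpoint_file = tf.train.latest_checkpoint(model_path)
        if checkpoint_file is None:
            raise FileNotFoundError("No checkpoint found in {}".format(model_path))
        # rebuild the graph from the .meta file, then load the variable values
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)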

If tf.train.latest_checkpoint() fails with a FindFirstFile failed error, you can read this solution:

Fix TensorFlow tf.train.latest_checkpoint() FindFirstFile failed Error – TensorFlow Tutorial

We often call sess.run(tf.global_variables_initializer()) to initialize all global variables in TensorFlow. If you also use saver.restore() to load a model, you must call saver.restore() after sess.run(tf.global_variables_initializer()).

This is because saver.restore() loads the saved values of the model's variables into your session. If you call sess.run(tf.global_variables_initializer()) after loading, it re-initializes every global variable, including the ones you have just loaded. The restored values are replaced by the variables' initial values (often random), so your loaded model will not work correctly.

You should:

1. sess.run(tf.global_variables_initializer())

2. Load model using saver.restore()

You should not:

1. Load model using saver.restore()

2. sess.run(tf.global_variables_initializer())
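The effect of the two orders can be checked with a small experiment. The following is only a minimal sketch, not part of the original model: it assumes a TensorFlow version that provides tf.compat.v1 (aliased here as tf1), and the variable name w, the value 42.0, the temporary directory, and the helper functions save_demo_checkpoint() and load_w() are all made up for illustration.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style API (assumption: this alias is available)


def build_model():
    # Hypothetical one-variable "model"; its initial value is 0.0
    return tf1.get_variable("w", shape=[], initializer=tf1.zeros_initializer())


def save_demo_checkpoint(ckpt_dir):
    """Save a checkpoint in which w holds 42.0 instead of its initial 0.0."""
    with tf.Graph().as_default():
        w = build_model()
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            sess.run(w.assign(42.0))  # stands in for a trained value
            return saver.save(sess, os.path.join(ckpt_dir, "model"))


def load_w(ckpt_path, init_first):
    """Restore w and return its value, using one of the two call orders."""
    with tf.Graph().as_default():
        w = build_model()
        saver = tf1.train.Saver()
        init_op = tf1.global_variables_initializer()
        with tf1.Session() as sess:
            if init_first:
                sess.run(init_op)               # 1. initialize
                saver.restore(sess, ckpt_path)  # 2. restore
            else:
                saver.restore(sess, ckpt_path)  # 1. restore
                sess.run(init_op)               # 2. initialize (overwrites w)
            return float(sess.run(w))


ckpt_path = save_demo_checkpoint(tempfile.mkdtemp())
good = load_w(ckpt_path, init_first=True)   # keeps the saved value 42.0
bad = load_w(ckpt_path, init_first=False)   # reset to the initial value 0.0
print(good, bad)
```

In this sketch the initializer is zeros, so the overwritten value is easy to spot. With a random initializer, the second order would leave w with a random value instead.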

If you only want to load a model and do not create any new variables, you do not need to call sess.run(tf.global_variables_initializer()).
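As a rough illustration of this case, the sketch below saves a tiny model and then loads it with import_meta_graph() and saver.restore() only, without running any initializer. It assumes tf.compat.v1 is available (aliased as tf1); the variable w, the op name w_value, and the temporary directory are hypothetical names chosen for this example.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style API (assumption: this alias is available)

ckpt_dir = tempfile.mkdtemp()

# Save a tiny model first so that there is something to load.
with tf.Graph().as_default():
    w = tf1.Variable([1.0, 2.0, 3.0], name="w")
    tf.identity(w, name="w_value")  # named op so the value can be fetched later
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.train.Saver().save(sess, os.path.join(ckpt_dir, "model"))

# Load the model into a fresh graph. No new variables are created here.
with tf.Graph().as_default() as g:
    with tf1.Session(graph=g) as sess:
        checkpoint_file = tf1.train.latest_checkpoint(ckpt_dir)
        saver = tf1.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)
        # Note: global_variables_initializer() is never run in this session.
        uninitialized = sess.run(tf1.report_uninitialized_variables()).tolist()
        restored = sess.run(g.get_tensor_by_name("w_value:0")).tolist()

print(uninitialized)  # names of variables that still need initializing
print(restored)       # value of w read from the checkpoint
```

Here report_uninitialized_variables() is used only as a check: an empty result means that saver.restore() has already given every variable a value.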
