Only Initialize New Variables When Using an Existing Model for Fine-tuning – TensorFlow Tutorial

By | October 25, 2021

We often build new neural network layers on top of an existing TensorFlow model. For example, we may add a BiLSTM layer on top of a BERT model.

The structure may look like:

Building a BiLSTM Layer based on Bert Model

We can use TensorFlow's saver.restore() to load an existing model. Here is a tutorial:

Steps to Load TensorFlow Model Using saver.restore() Correctly – TensorFlow Tutorial

However, if you use saver.restore() to load a model, you have to do things in this order:

(1) run sess.run(tf.global_variables_initializer())

(2) use saver.restore() to load

The reason is that sess.run(tf.global_variables_initializer()) initializes all global variables. If it runs after a model has been loaded, every variable of that model is reset to its initial value and the loaded weights are lost.
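To see why the order matters, here is a minimal sketch using the tf.compat.v1 API (assuming TensorFlow 1.14+ or 2.x; the variable name w, its values and the temporary checkpoint path are made up for illustration). It saves w = [1, 2], then restores it with the global initializer run either before or after saver.restore().

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, also available in TensorFlow 2.x


def restored_value(init_after_restore):
    """Save w = [1, 2], restore it, and return w; optionally run the initializer after restoring."""
    ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

    # Save: w starts at zeros, is set to [1, 2], and is written to the checkpoint.
    with tf.Graph().as_default() as g, tf1.Session(graph=g) as sess:
        w = tf1.Variable(tf.zeros([2]), name="w")
        sess.run(tf1.global_variables_initializer())
        sess.run(w.assign([1.0, 2.0]))
        tf1.train.Saver().save(sess, ckpt_prefix)

    # Restore into a fresh graph and session.
    with tf.Graph().as_default() as g, tf1.Session(graph=g) as sess:
        w = tf1.Variable(tf.zeros([2]), name="w")
        saver = tf1.train.Saver()
        if not init_after_restore:
            sess.run(tf1.global_variables_initializer())  # (1) initialize first
        saver.restore(sess, ckpt_prefix)                   # (2) then load the saved values
        if init_after_restore:
            sess.run(tf1.global_variables_initializer())  # wrong order: resets w to zeros
        return sess.run(w)


print(restored_value(init_after_restore=False))  # [1. 2.]
print(restored_value(init_after_restore=True))   # [0. 0.]
```

Only the first call keeps the saved weights; the second one silently replaces them with the initial zeros.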

You can read this tutorial for details:

An Explain to sess.run(tf.global_variables_initializer()) for Beginners – TensorFlow Tutorial

If you only use an existing model without fine-tuning it, you can follow the order above. However, if you plan to fine-tune an existing model, you usually have to do the following:

(1) load an existing tensorflow model

(2) get the output from the model and pass it to your own model

(3) train and fine-tune the whole model.

In the example above, the BiLSTM is our own model, and it takes the output of BERT as its input.

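In this workflow, saver.restore() runs first and our own model is built afterwards, so the new variables still need to be initialized. If we then call sess.run(tf.global_variables_initializer()), all variables in BERT are re-initialized as well, and the pretrained weights are lost.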

To avoid this problem, we should initialize only the variables in our own model.
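A variant of the same idea (not the list-based method this tutorial uses) is to ask the session which variables are still uninitialized with tf.compat.v1.report_uninitialized_variables() and initialize only those. A minimal sketch, assuming TensorFlow 1.14+ or 2.x; bert_w and head_b are hypothetical stand-ins, the helper name initialize_uninitialized is made up, and an assign() call stands in for saver.restore():

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, also available in TensorFlow 2.x


def initialize_uninitialized(sess):
    """Run initializers only for global variables the session has not set yet."""
    # The op returns variable names as bytes, e.g. b"head_b".
    names = {n.decode() for n in sess.run(tf1.report_uninitialized_variables())}
    new_vars = [v for v in tf1.global_variables() if v.op.name in names]
    sess.run(tf1.variables_initializer(new_vars))
    return sorted(names)


with tf.Graph().as_default() as g, tf1.Session(graph=g) as sess:
    bert_w = tf1.Variable([1.0, 2.0], name="bert_w")     # hypothetical pretrained variable
    sess.run(bert_w.assign([3.0, 4.0]))                  # stand-in for saver.restore()
    head_b = tf1.Variable(tf.zeros([2]), name="head_b")  # hypothetical new-layer variable
    initialized = initialize_uninitialized(sess)
    combined = sess.run(bert_w + head_b)

print(initialized)  # ['head_b']
print(combined)     # [3. 4.]
```

One trade-off of this variant: any variable that the restore step did not set, for whatever reason, is initialized from scratch without a warning, whereas an explicit list makes that choice visible.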

How do we initialize only the variables in our own model?

We can do this with the following steps.

Step 1: Load an existing tensorflow model using saver.restore()

Here is an example:

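    import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in TensorFlow 2.x)

    g = tf.Graph()
    sess = tf.Session(graph=g)
    with g.as_default(), sess.as_default():
        # load an existing model; model_path is the directory that holds the checkpoint files
        checkpoint_file = tf.train.latest_checkpoint(model_path)
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)  # overwrite the imported variables with the saved values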

Here model_path is the directory of our existing model, and checkpoint_file is the path prefix of its latest checkpoint.

Step 2: Use a list to record all variables in the loaded model

Here is the example code:

        # snapshot every variable that came from the checkpoint
        variable_saved_value = list(tf.global_variables())
        print("variables in loaded model:", len(variable_saved_value))

variable_saved_value is a Python list that holds all variables of the loaded model.
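Step 2 records what is in the graph after the import. If you also want to know which variables the checkpoint file itself stores, tf.train.list_variables() returns (name, shape) pairs for a checkpoint. A minimal sketch, assuming TensorFlow 1.14+ or 2.x; the helper name checkpoint_variable_names is hypothetical, and the checkpoint is a throwaway one with a single made-up variable bert/kernel:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, also available in TensorFlow 2.x


def checkpoint_variable_names(checkpoint_file):
    """Return the names of the variables stored in a checkpoint (no ':0' suffix)."""
    return {name for name, _shape in tf.train.list_variables(checkpoint_file)}


# Write a throwaway checkpoint containing one made-up variable, bert/kernel.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")
with tf.Graph().as_default() as g, tf1.Session(graph=g) as sess:
    with tf1.variable_scope("bert"):
        tf1.get_variable("kernel", shape=[3], initializer=tf1.zeros_initializer())
    sess.run(tf1.global_variables_initializer())
    tf1.train.Saver().save(sess, ckpt_prefix)

print(sorted(checkpoint_variable_names(ckpt_prefix)))
```

These names have no ":0" suffix, so to match them against graph variables, compare with v.op.name rather than v.name.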

Step 3: Create your own model

Then we can create our own model, which contains new variables that are not in the loaded model.

Here is an example:

        mix_model_fine_tune = ModelMixtureFineTune(
                                    graph = g,
                                    max_frames_num=FLAGS.max_frames_num,
                                    class_num=Dataset.n_class, 
                                    feature_dim=FLAGS.embedding_dim,
                                    hidden_dim=FLAGS.hidden_size)
        mix_model_fine_tune.build() # build your own model

Step 4: Get all global variables not in variable_saved_value

Here is the example:

        init_variables = [x for x in tf.global_variables() if x not in variable_saved_value]

init_variables holds all the new variables that you must initialize.
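Step 4 relies on `in` comparing tf.Variable objects. A small variation is to compare variable names instead: a set gives fast lookups, and it avoids depending on how == behaves for variables (in TensorFlow 2.x eager code it can compare values element-wise). A minimal sketch; the helper name variables_to_initialize is hypothetical, and SimpleNamespace objects stand in for tf.Variable, which also exposes a .name attribute:

```python
from types import SimpleNamespace


def variables_to_initialize(all_variables, restored_variables):
    """Return the variables in all_variables whose names were not restored.

    Works with any objects that expose a unique .name, such as tf.Variable.
    """
    restored_names = {v.name for v in restored_variables}  # set: O(1) membership tests
    return [v for v in all_variables if v.name not in restored_names]


# Stand-in objects with hypothetical names (not real BERT variable names).
restored = [SimpleNamespace(name="bert/embeddings:0"), SimpleNamespace(name="bert/pooler/kernel:0")]
new = [SimpleNamespace(name="bilstm/fw/kernel:0"), SimpleNamespace(name="bilstm/bw/kernel:0")]

print([v.name for v in variables_to_initialize(restored + new, restored)])
# ['bilstm/fw/kernel:0', 'bilstm/bw/kernel:0']
```

In the tutorial's code this would be called as variables_to_initialize(tf.global_variables(), variable_saved_value).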

Step 5: Initialize only the new variables

You can do this with the code below:

        # initialize only the new variables; the restored ones keep their checkpoint values
        sess.run(tf.variables_initializer(init_variables))

Then you can start training or fine-tuning the whole model.
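Putting the five steps together, here is a minimal end-to-end sketch using the tf.compat.v1 API (assuming TensorFlow 1.14+ or 2.x). A one-variable checkpoint stands in for BERT and a single matrix multiplication stands in for the BiLSTM; the names bert_w and head_w and the temporary directory are hypothetical:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, also available in TensorFlow 2.x

# Write a tiny stand-in for the pretrained model: one variable bert_w = [[1, 2]].
model_path = tempfile.mkdtemp()
with tf.Graph().as_default() as g, tf1.Session(graph=g) as sess:
    tf1.Variable([[1.0, 2.0]], name="bert_w")
    sess.run(tf1.global_variables_initializer())
    tf1.train.Saver().save(sess, os.path.join(model_path, "model.ckpt"))

g = tf.Graph()
with g.as_default(), tf1.Session(graph=g) as sess:
    # Step 1: import the saved graph and restore its weights.
    checkpoint_file = tf1.train.latest_checkpoint(model_path)
    saver = tf1.train.import_meta_graph(checkpoint_file + ".meta")
    saver.restore(sess, checkpoint_file)

    # Step 2: remember the names of the restored variables.
    restored_names = {v.name for v in tf1.global_variables()}

    # Step 3: build a new layer on top of the restored variable.
    bert_w = next(v for v in tf1.global_variables() if v.op.name == "bert_w")
    head_w = tf1.Variable(tf.ones([2, 1]), name="head_w")
    output = tf.matmul(bert_w, head_w)

    # Step 4: collect the variables that were not restored.
    init_variables = [v for v in tf1.global_variables() if v.name not in restored_names]

    # Step 5: initialize only those; bert_w keeps its restored value.
    sess.run(tf1.variables_initializer(init_variables))
    new_names = [v.op.name for v in init_variables]
    result = sess.run(output)

print(new_names)  # ['head_w']
print(result)     # [[3.]]
```

The result [[3.]] comes from the restored [[1, 2]] times the freshly initialized ones, which shows that the pretrained value survived Step 5.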
