The full UserWarning is:
UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
In this tutorial, we will show how to fix this UserWarning when you build a model with TensorFlow.
Why does this UserWarning occur?
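If you use tf.gather() or tf.nn.embedding_lookup() to get a tensor, you may encounter it. The gradient of these ops with respect to their params input is a sparse tf.IndexedSlices. When TensorFlow has to convert that gradient into a dense Tensor while building a graph and cannot determine its shape, it prints this warning.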
How to fix this UserWarning?
Some programmers may tell you to fix this warning by passing a tf.Variable as the params input of tf.gather() or tf.nn.embedding_lookup().
For example:
To fix this problem, you should try to ensure that the params input to tf.gather() (or the params inputs to tf.nn.embedding_lookup()) is a tf.Variable.
https://stackoverflow.com/questions/35892412/tensorflow-dense-gradient-explanation
However, this does not fix it in our case.
In our model, we call tf.nn.embedding_lookup() inside a tf.while_loop().
The code is:
def _g_recurrence(i, x_t, h_tm1, gen_o):
    h_t = self.g_recurrent_unit(x_t, h_tm1)
    o_t = self.g_output_unit(h_t)  # batch x 200
    gen_o = gen_o.write(i, o_t)
    # Clamp the next index so it never goes past the last time step.
    i_next = tf.where(tf.less(i, self.time_step - 1), i + 1, self.time_step - 1)
    # Pick time step i_next from self.inputs.
    x_t_next = tf.nn.embedding_lookup(self.inputs, i_next)  # batch x emb_dim
    return i + 1, x_t_next, h_t, gen_o
Here we get one time step from self.inputs with the index i_next, which is a tensor computed inside the loop. However, we still get this UserWarning.
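A good way to fix this UserWarning is to convert self.inputs to a TensorArray and then read each time step with TensorArray.read(). Reading from a TensorArray does not go through tf.gather(), so its gradient is not an IndexedSlices and there is nothing sparse to convert.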
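The following sketch, assuming TensorFlow 2.x in eager mode and a hypothetical time-major tensor of shape [time_step=3, batch=2, emb_dim=4], compares the two approaches. Both return the same [batch, emb_dim] slice, but only tf.nn.embedding_lookup() yields an IndexedSlices gradient; the TensorArray path yields a dense one. The eager TensorArray is implemented differently from the graph version, so treat this as an illustration rather than a test of the graph code.

```python
import tensorflow as tf

# Hypothetical time-major input: time_step=3, batch=2, emb_dim=4.
inputs = tf.random.normal([3, 2, 4])

with tf.GradientTape(persistent=True) as tape:
    tape.watch(inputs)
    # Original approach: pick time step 1 with embedding_lookup.
    via_lookup = tf.nn.embedding_lookup(inputs, 1)
    # TensorArray approach: unstack along axis 0, then read time step 1.
    inputs_ta = tf.TensorArray(dtype=tf.float32, size=3).unstack(inputs)
    via_ta = inputs_ta.read(1)
    loss_lookup = tf.reduce_sum(via_lookup)
    loss_ta = tf.reduce_sum(via_ta)

grad_lookup = tape.gradient(loss_lookup, inputs)  # sparse IndexedSlices
grad_ta = tape.gradient(loss_ta, inputs)          # dense Tensor of shape [3, 2, 4]
del tape

print(bool(tf.reduce_all(tf.equal(via_lookup, via_ta))))  # True: same slice
print(isinstance(grad_lookup, tf.IndexedSlices))          # True
print(isinstance(grad_ta, tf.IndexedSlices))              # False
```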
To learn how to convert a tensor to a TensorArray, you can read:
Best Practice to Convert a Tensor to TensorArray in TensorFlow
Here is an example:
self.inputs_ta = tf.TensorArray(dtype=tf.float32, size=self.time_step,
                                dynamic_size=False, infer_shape=True)
self.inputs_ta = self.inputs_ta.unstack(self.inputs)
Then we can rewrite the _g_recurrence() function as follows:
def _g_recurrence(i, h_tm1, gen_o):
    # x_t = tf.nn.embedding_lookup(self.inputs, i)  # batch x emb_dim
    x_t = self.inputs_ta.read(i)
    h_t = self.g_recurrent_unit(x_t, h_tm1)
    o_t = self.g_output_unit(h_t)
    gen_o = gen_o.write(i, o_t)
    return i + 1, h_t, gen_o
After this change, the UserWarning disappears.
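For readers who want to try the pattern outside our model, here is a self-contained sketch assuming TensorFlow 2.x. The tanh cell, the variables w_x and w_h, the function name run and all sizes are hypothetical stand-ins for self.g_recurrent_unit, self.g_output_unit and the real model; only the TensorArray unstack/read pattern inside tf.while_loop() follows the code above. The batch dimension is left unknown, and warnings are recorded while the gradient with respect to the inputs is computed.

```python
import warnings
import tensorflow as tf

# Hypothetical sizes; none of these come from the original model.
TIME_STEP, EMB_DIM, HIDDEN = 5, 3, 4

# Hypothetical stand-in for self.g_recurrent_unit: a plain tanh RNN cell.
w_x = tf.Variable(tf.random.normal([EMB_DIM, HIDDEN]))
w_h = tf.Variable(tf.random.normal([HIDDEN, HIDDEN]))

# The batch size is unknown at trace time, like a placeholder in a real model.
@tf.function(input_signature=[tf.TensorSpec([TIME_STEP, None, EMB_DIM], tf.float32)])
def run(inputs):
    # Counterpart of self.inputs_ta: one [batch, EMB_DIM] element per time step.
    inputs_ta = tf.TensorArray(dtype=tf.float32, size=TIME_STEP,
                               dynamic_size=False, infer_shape=True)
    inputs_ta = inputs_ta.unstack(inputs)
    gen_o = tf.TensorArray(dtype=tf.float32, size=TIME_STEP)

    def _g_recurrence(i, h_tm1, gen_o):
        x_t = inputs_ta.read(i)  # replaces tf.nn.embedding_lookup(inputs, i)
        h_t = tf.tanh(tf.matmul(x_t, w_x) + tf.matmul(h_tm1, w_h))
        gen_o = gen_o.write(i, h_t)  # hypothetical output unit: identity
        return i + 1, h_t, gen_o

    h0 = tf.zeros([tf.shape(inputs)[1], HIDDEN])
    _, _, gen_o = tf.while_loop(
        cond=lambda i, h, o: i < TIME_STEP,
        body=_g_recurrence,
        loop_vars=(tf.constant(0), h0, gen_o))
    return gen_o.stack()  # [TIME_STEP, batch, HIDDEN]

inputs = tf.random.normal([TIME_STEP, 2, EMB_DIM])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with tf.GradientTape() as tape:
        tape.watch(inputs)
        outputs = run(inputs)
        loss = tf.reduce_sum(outputs)
    grad = tape.gradient(loss, inputs)

print(outputs.shape)  # (5, 2, 4)
print([str(w.message) for w in caught if "IndexedSlices" in str(w.message)])
```

The gradient with respect to the inputs comes back as a dense Tensor of shape [TIME_STEP, batch, EMB_DIM], because every read goes through the TensorArray rather than tf.gather().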