Solve tf.svd NaN Bug with np.linalg.svd - TensorFlow Example

By | May 29, 2019

If you use the tf.svd() function to compute a singular value decomposition (SVD) in TensorFlow, you may run into a bug where the results contain NaN values. For example:

s, u, v = tf.svd(sen_svd_input_u)

Here s holds the singular values of the matrix sen_svd_input_u. However, some of the values in s are NaN.
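A quick way to confirm the problem is to count the NaN entries after fetching s from the session. The sketch below uses a hypothetical helper, count_nan_singular_values(), written in plain NumPy. The NaN-containing array is made-up illustration data, not real output from tf.svd().

```python
import numpy as np

def count_nan_singular_values(s):
    """Return how many entries of a singular-value array are NaN.

    Hypothetical helper: `s` is assumed to be a NumPy array, for example the
    value of the TensorFlow tensor `s` fetched with sess.run(s).
    """
    return int(np.isnan(np.asarray(s)).sum())

# Illustrative values only; these are not real tf.svd() results.
s_bad = np.array([16.8, np.nan, 0.3], dtype=np.float32)
print(count_nan_singular_values(s_bad))  # 1

# np.linalg.svd on the matrix used later in this post yields no NaN.
A = np.array([[1, 2, 3], [1, 3, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
s_np = np.linalg.svd(A, compute_uv=False)  # singular values only
print(count_nan_singular_values(s_np))  # 0
```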

Here is an example of this bug: https://nbviewer.jupyter.org/github/SciRuby/sciruby-notebooks/blob/master/getting_started.ipynb

To work around this bug, we can compute the SVD with NumPy's np.linalg.svd() function instead of tf.svd().
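Note that np.linalg.svd() is not a drop-in replacement. It returns (u, s, vh), where a = u @ diag(s) @ vh, while tf.svd() returns (s, u, v) with a = u @ diag(s) @ v^T for real-valued input. NumPy also defaults to full_matrices=True, whereas tf.svd() defaults to full_matrices=False. Below is a minimal NumPy-only sketch of the conversion, using a hypothetical helper named tf_style_svd() and assuming real-valued input (so the conjugate transpose is a plain transpose):

```python
import numpy as np

def tf_style_svd(a, full_matrices=False):
    """Compute an SVD with NumPy but return it in tf.svd()'s (s, u, v) order.

    Hypothetical helper; assumes a real-valued array of shape [..., M, N].
    """
    # NumPy returns (u, s, vh) where a == u @ diag(s) @ vh.
    u, s, vh = np.linalg.svd(a, full_matrices=full_matrices)
    # tf.svd() returns v itself, so swap the last two axes of vh.
    v = np.swapaxes(vh, -1, -2)
    return s, u, v

A = np.array([[1, 2, 3], [1, 3, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
s, u, v = tf_style_svd(A)

# Rebuild A from the TF-ordered factors: u @ diag(s) @ v^T.
A_rebuilt = u @ np.diag(s) @ v.T
print(np.allclose(A, A_rebuilt, atol=1e-4))  # True
print(u.shape, s.shape, v.shape)  # (4, 3) (3,) (3, 3)
```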

Here is the full example code. It uses the TensorFlow 1.x API (tf.svd, tf.py_func, tf.Session):

import tensorflow as tf  # TensorFlow 1.x API
import numpy as np

def replace_tf_svd_with_np_svd():
    """Replaces tf.svd with np.linalg.svd. Slow, but a workaround for tf.svd bugs.
    For details see
    https://github.com/tensorflow/tensorflow/issues/8905
    """
    if hasattr(tf, 'original_svd'):
        # This function has already been called, so tf.svd is already replaced.
        return
    tf.original_svd = tf.svd

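    def my_svd(tensor, full_matrices=False, compute_uv=True):
        # Only compute_uv=True is supported: both SVD calls must return 3 values.
        dtype = tensor.dtype
        # NumPy returns (u, s, vh) with tensor == u @ diag(s) @ vh.
        u, s, vh = tf.py_func(np.linalg.svd, [tensor, full_matrices, compute_uv],
                              [dtype, dtype, dtype])
        # tf.py_func drops static shape information, so borrow the output
        # shapes from the original tf.svd op.
        s_, u_, v_ = tf.original_svd(tensor, full_matrices, compute_uv)
        s = tf.reshape(s, tf.shape(s_))
        u = tf.reshape(u, tf.shape(u_))
        # Converting numpy order of v dims to TF order: swap the last two axes
        # of vh first, then reshape, so a non-square v (wide inputs) keeps its
        # layout. list() is needed for the in-place swap on Python 3.
        order = list(range(tensor.get_shape().ndims))
        order[-2], order[-1] = order[-1], order[-2]
        v = tf.reshape(tf.transpose(vh, order), tf.shape(v_))
        return s, u, v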
    tf.svd = my_svd

A = tf.constant([[1, 2, 3], [1, 3, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.float32)

# Use TensorFlow's original tf.svd() to compute the SVD.
s, u, v = tf.svd(A)
replace_tf_svd_with_np_svd()
# tf.svd now calls np.linalg.svd to compute the SVD.
s1, u1, v1 = tf.svd(A)

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)

    s, u, v, s1, u1, v1 = sess.run([s, u, v, s1, u1, v1])

    print('s=')
    print(s)
    print('u=')
    print(u)
    print('v=')
    print(v)
    print('s1=')
    print(s1)
    print('u1=')
    print(u1)
    print('v1=')
    print(v1)
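The order list in my_svd() is the identity permutation with its last two entries exchanged, so the wrapper also works for a batch of matrices with shape [..., M, N]. On Python 3, range() has to be wrapped in list() before that in-place swap. The NumPy-only sketch below uses a hypothetical helper, swap_last_two_axes_order(), and randomly generated test data (an assumption, not from the original post) to check that this permutation matches np.swapaxes() and that each matrix in the batch is rebuilt from the TF-ordered factors.

```python
import numpy as np

def swap_last_two_axes_order(ndims):
    """Hypothetical helper: build the same permutation as `order` in my_svd(),
    i.e. the identity except that the last two axes are exchanged."""
    order = list(range(ndims))
    order[-2], order[-1] = order[-1], order[-2]
    return order

rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 4, 3))  # assumed test data: two 4x3 matrices

u, s, vh = np.linalg.svd(batch, full_matrices=False)
order = swap_last_two_axes_order(batch.ndim)
v = np.transpose(vh, order)  # TF-style v for every matrix in the batch

print(order)  # [0, 2, 1]
print(np.array_equal(v, np.swapaxes(vh, -1, -2)))  # True

# Rebuild each matrix as u @ diag(s) @ v^T; scaling the rows of v^T by s
# is equivalent to multiplying by diag(s).
rebuilt = u @ (s[..., :, None] * np.swapaxes(v, -1, -2))
print(np.allclose(batch, rebuilt))  # True
```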

The result is:
