Implement Squashing Function in Capsule Network Using TensorFlow – TensorFlow Tutorial

By | February 16, 2021

In this tutorial, we will use TensorFlow to implement the squashing function of a capsule network. You can use the example code to squash the values of a tensor.

Squashing Function in Capsule Network

In a capsule network, the squashing function is defined as:

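$$v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2} \, \frac{s_j}{\|s_j\|}$$

where s_j is the total input vector of capsule j and v_j is its output vector. The function keeps the direction of s_j and maps its length into the range [0, 1).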
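The following minimal NumPy sketch shows what the squashing formula does to vector lengths. Several parts are assumptions for illustration and are not part of the original tutorial:

- the use of NumPy;
- the helper name squash();
- the small eps term, added so that an all-zero vector does not cause a division by zero.

Each row of s is treated as one capsule vector.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-12):
    """Squash vectors along `axis`: v = |s|^2 / (1 + |s|^2) * s / |s|."""
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)  # |s|^2 for each vector
    norm = np.sqrt(sq_norm + eps)                       # |s|; eps avoids 0/0 (assumption)
    return sq_norm / (1.0 + sq_norm) * (s / norm)

# Three 2-D vectors with lengths 0.1, 1 and 10.
s = np.array([[0.1, 0.0],
              [0.0, 1.0],
              [6.0, 8.0]])
v = squash(s)
print(np.linalg.norm(v, axis=-1))  # approx. [0.0099 0.5 0.9901]
```

The effect depends on the input length, and every direction is unchanged:

- A vector of length 0.1 shrinks to a length of about 0.0099.
- A unit vector is scaled to length 0.5.
- A vector of length 10 ends up just below 1 (100/101).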

We will use TensorFlow to implement this function.

How to implement the squashing function in a capsule network?

Here is the TensorFlow example code.

import tensorflow as tf  # written for TensorFlow 1.x (uses tf.Session)

a = tf.constant(list(range(15)), dtype=tf.float32)
c = tf.reshape(a, [3, 5])  # 3 capsule vectors, each of dimension 5

# v is a 2-D tensor of shape [batch_size, dim], such as [64, 200]
def squashing(v):
    u = tf.pow(v, 2)                                  # element-wise v^2
    sum_ = tf.reduce_sum(u, axis=1, keepdims=True)    # squared norm |v|^2, shape [64, 1]

    left_ = sum_ / (sum_ + 1.0)                       # |v|^2 / (1 + |v|^2), shape [64, 1]
    right_ = tf.nn.l2_normalize(v, axis=1)            # v / |v|, shape [64, 200]
    out = left_ * right_                              # broadcast product, shape [64, 200]
    return v, left_, right_, out                      # input, scale, unit vectors, squashed output

o = squashing(c)

# Printing the result (no variables are used, so no initializer is needed)
with tf.Session() as sess:
    print(sess.run(o))

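In this example, squashing() treats each row of the [3, 5] tensor c as one capsule vector and squashes it along axis 1. Besides the squashed output, the function also returns intermediate values so that you can inspect each step.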
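The two factors in the code correspond to the two parts of the formula. Here s stands for one row of the input, that is, the argument v of squashing(). Because right_ is a unit vector (for a nonzero s), the length of each output row equals left_:

$$
\text{left\_} = \frac{\|s\|^2}{1 + \|s\|^2}, \qquad
\text{right\_} = \frac{s}{\|s\|}, \qquad
\|\text{out}\| = \text{left\_} \cdot \left\| \frac{s}{\|s\|} \right\| = \frac{\|s\|^2}{1 + \|s\|^2}.
$$

For example, the first row of c is s = (0, 1, 2, 3, 4), so ||s||^2 = 0 + 1 + 4 + 9 + 16 = 30. The squashed row therefore has length 30/31 ≈ 0.9677.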

To learn how to use tf.pow(), you can read:

Understand TensorFlow tf.pow() with Examples: Compute the Power of the Tensor – TensorFlow Tutorial

To learn how to use tf.nn.l2_normalize(), you can read:

Unit-normalize a TensorFlow Tensor: A Practice Guide – TensorFlow Tips
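According to the TensorFlow documentation, tf.nn.l2_normalize() computes x / sqrt(max(sum(x**2), epsilon)) along the given axis, with epsilon defaulting to 1e-12. As a result, an all-zero row comes back as zeros rather than NaN. A naive v / ||v|| would divide by zero in that case.

The sketch below compares the function with the same formula written by hand. It assumes TensorFlow 2.x with eager execution (needed for .numpy()) and uses a made-up input x.

```python
import numpy as np
import tensorflow as tf

x = tf.constant([[3.0, 4.0],
                 [0.0, 0.0]])  # hypothetical input: one regular row, one all-zero row

# TensorFlow: unit-normalize each row (axis=1).
y = tf.nn.l2_normalize(x, axis=1)

# Same computation by hand: x / sqrt(max(sum(x**2), epsilon)).
eps = 1e-12
sq = tf.reduce_sum(tf.square(x), axis=1, keepdims=True)
y_manual = x / tf.sqrt(tf.maximum(sq, eps))

print(y.numpy())                                  # first row becomes [0.6, 0.8]; zero row stays zero
print(np.allclose(y.numpy(), y_manual.numpy()))   # both computations agree
```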

Run this code and you will get:

(array([[ 0.,  1.,  2.,  3.,  4.],
       [ 5.,  6.,  7.,  8.,  9.],
       [10., 11., 12., 13., 14.]], dtype=float32), array([[0.9677419 ],
       [0.99609375],
       [0.998632  ]], dtype=float32), array([[0.        , 0.18257418, 0.36514837, 0.5477226 , 0.73029673],
       [0.31311214, 0.37573457, 0.438357  , 0.5009794 , 0.56360185],
       [0.3701166 , 0.40712827, 0.44413993, 0.48115158, 0.51816326]],
      dtype=float32), array([[0.        , 0.17668469, 0.35336939, 0.5300541 , 0.70673877],
       [0.31188905, 0.37426686, 0.43664467, 0.49902248, 0.5614003 ],
       [0.36961028, 0.40657133, 0.44353235, 0.48049337, 0.51745445]],
      dtype=float32))
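The example above relies on tf.Session, which exists only in TensorFlow 1.x. In TensorFlow 2.x it is available as tf.compat.v1.Session and requires graph mode.

The sketch below writes the same function for TensorFlow 2.x with eager execution (the default), so no session is needed. Unlike the original function, it returns only the squashed tensor, and it uses tf.square() in place of tf.pow(v, 2). Its printed rows should match the last array shown above.

```python
import tensorflow as tf

def squashing(v):
    """Squash each row of a 2-D tensor v of shape [batch_size, dim]."""
    sum_ = tf.reduce_sum(tf.square(v), axis=1, keepdims=True)  # |v|^2, shape [batch_size, 1]
    left_ = sum_ / (sum_ + 1.0)                                 # scale factor in [0, 1)
    right_ = tf.nn.l2_normalize(v, axis=1)                      # unit vectors, shape [batch_size, dim]
    return left_ * right_

c = tf.reshape(tf.range(15, dtype=tf.float32), [3, 5])  # same [3, 5] input as above
out = squashing(c)
print(out.numpy())
print(tf.norm(out, axis=1).numpy())  # length of each squashed row (equals left_)
```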
