In this tutorial, we will use TensorFlow to implement the squashing function of a capsule network. You can use this example code to squash the values of a tensor.
Squashing Function in Capsule Network
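In a capsule network, the squashing function is defined as:

$$v_j = \frac{\|s_j\|^2}{1 + \|s_j\|^2}\,\frac{s_j}{\|s_j\|}$$

where s_j is the total input vector of capsule j and v_j is its output vector. The first factor scales the length of the vector into [0, 1), and the second factor is the unit vector in the direction of s_j.
We will implement this function with TensorFlow.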
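To make the formula concrete, here is a minimal sketch in plain Python (no TensorFlow) that squashes a single vector given as a list of floats. The helpers squash() and length() are hypothetical names used only for this illustration. It shows that the length of the output equals ||s||^2 / (1 + ||s||^2): short vectors shrink toward 0, and long vectors approach, but never reach, length 1.

```python
import math

def squash(s):
    """Hypothetical helper: squash one vector s (a list of numbers)."""
    squared_norm = sum(x * x for x in s)          # ||s||^2
    if squared_norm == 0.0:
        # Assumption: a zero vector stays zero (the formula's limit as ||s|| -> 0).
        return [0.0 for _ in s]
    norm = math.sqrt(squared_norm)                # ||s||
    scale = squared_norm / (1.0 + squared_norm)   # ||s||^2 / (1 + ||s||^2)
    return [scale * x / norm for x in s]          # scale * s / ||s||

def length(v):
    """Hypothetical helper: Euclidean length of a list of numbers."""
    return math.sqrt(sum(x * x for x in v))

# Short vectors shrink toward 0, long vectors approach length 1.
for s in ([0.1, 0.0], [1.0, 0.0], [10.0, 0.0]):
    print(s, "->", round(length(squash(s)), 4))  # prints 0.0099, 0.5 and 0.9901
```

Handling the zero vector explicitly avoids a division by zero; the TensorFlow code relies on the epsilon inside tf.nn.l2_normalize for the same purpose.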
How to implement the squashing function in a capsule network?
Here is an example in TensorFlow:
```python
# This code uses the TensorFlow 1.x API (tf.Session).
import tensorflow as tf

a = tf.constant(list(range(15)), dtype=tf.float32)
c = tf.reshape(a, [3, 5])

# v is a 2-D tensor, for example of shape 64*200
def squashing(v):
    u = tf.pow(v, 2)
    sum_ = tf.reduce_sum(u, axis=1, keepdims=True)  # ||s||^2, 64*1
    left_ = sum_ / (sum_ + 1.0)                      # ||s||^2 / (1 + ||s||^2), 64*1
    right_ = tf.nn.l2_normalize(v, axis=1)           # s / ||s||, 64*200
    out = left_ * right_                             # squashed vectors, 64*200
    return v, left_, right_, out

o = squashing(c)

# Print the result
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run(o))
```
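If you want to check the TensorFlow result without a TensorFlow 1.x session (tf.Session is not available at the top level in TensorFlow 2.x), here is a minimal NumPy port of squashing(). It is a sketch, not the author's code: the name squash and its axis and epsilon parameters are assumptions. The normalization uses the formula that tf.nn.l2_normalize documents, x / sqrt(max(sum(x**2), epsilon)), with the same default epsilon of 1e-12.

```python
import numpy as np

def squash(v, axis=1, epsilon=1e-12):
    """NumPy sketch of squashing(): squash each vector along `axis`."""
    sum_ = np.sum(np.square(v), axis=axis, keepdims=True)  # ||s||^2, e.g. 64*1
    left_ = sum_ / (sum_ + 1.0)                            # ||s||^2 / (1 + ||s||^2)
    # Mirrors tf.nn.l2_normalize: x / sqrt(max(sum(x**2), epsilon))
    right_ = v / np.sqrt(np.maximum(sum_, epsilon))        # s / ||s||
    return left_ * right_

c = np.arange(15, dtype=np.float32).reshape(3, 5)
out = squash(c)
print(out)
print(np.linalg.norm(out, axis=1))  # each row's length equals that row's left_ value
```

Because right_ has unit length, the length of each output row is exactly the scale factor left_, which is always below 1.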
In this example, the squashing() function squashes each row of the 3*5 tensor c.
To learn how tf.pow() works, you can read:
Understand TensorFlow tf.pow() with Examples: Compute the Power of the Tensor – TensorFlow Tutorial
To learn how to use tf.nn.l2_normalize(), you can read:
Unit-normalize a TensorFlow Tensor: A Practice Guide – TensorFlow Tips
Running this code prints:
(array([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.]], dtype=float32), array([[0.9677419 ],
[0.99609375],
[0.998632 ]], dtype=float32), array([[0. , 0.18257418, 0.36514837, 0.5477226 , 0.73029673],
[0.31311214, 0.37573457, 0.438357 , 0.5009794 , 0.56360185],
[0.3701166 , 0.40712827, 0.44413993, 0.48115158, 0.51816326]],
dtype=float32), array([[0. , 0.17668469, 0.35336939, 0.5300541 , 0.70673877],
[0.31188905, 0.37426686, 0.43664467, 0.49902248, 0.5614003 ],
[0.36961028, 0.40657133, 0.44353235, 0.48049337, 0.51745445]],
dtype=float32))
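To see where these numbers come from, take the first row of c, s = (0, 1, 2, 3, 4). The second printed array holds the scale factor for each row, the third holds the unit vectors, and the last holds the squashed vectors:

$$
\begin{aligned}
\|s\|^2 &= 0^2 + 1^2 + 2^2 + 3^2 + 4^2 = 30\\
\frac{\|s\|^2}{1 + \|s\|^2} &= \frac{30}{31} \approx 0.9677419\\
\frac{s}{\|s\|} &= \frac{(0, 1, 2, 3, 4)}{\sqrt{30}} \approx (0,\ 0.1825742,\ 0.3651484,\ 0.5477226,\ 0.7302967)\\
v &= \frac{30}{31}\cdot\frac{s}{\|s\|} \approx (0,\ 0.1766847,\ 0.3533694,\ 0.5300541,\ 0.7067388)
\end{aligned}
$$

So the first squashed row has length 30/31, slightly below 1.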