The attention mechanism has been widely used in deep learning tasks such as data mining, sentiment analysis and machine translation. Whatever attention strategy you use, visualizing the attention weights is a useful way to compare how different models behave.
In this tutorial, we will show you how to implement attention visualization in Python.
Step 1: Install seaborn
pip install seaborn
Step 2: Implement attention visualization
Suppose you have two models, and each of them produces attention weights over the same sentence.
For example:
For the sentence: shit, this food is very disappointment.
The attention weights of Model A are: 0.3276, 0.0003, 0.0009, 0.0000, 0.0010, 0.0192, 0.6497, 0.0013
The attention weights of Model B are: 0.0184, 0.0000, 0.0005, 0.0000, 0.0000, 0.0000, 0.9810, 0.0000
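Before plotting, a quick numeric comparison can already show where the two models disagree. The sketch below uses only the standard library and the weights listed above; the helper name `top_token` is hypothetical, and the assumption that each row should sum to roughly 1 (as softmax outputs do) is ours, not something stated for these models.

```python
# Tokens and attention weights copied from the example above.
tokens = ['shit', ',', 'this', 'food', 'is', 'very', 'disappointment', '.']
attention = {
    'Model A': [0.3276, 0.0003, 0.0009, 0.0000, 0.0010, 0.0192, 0.6497, 0.0013],
    'Model B': [0.0184, 0.0000, 0.0005, 0.0000, 0.0000, 0.0000, 0.9810, 0.0000],
}


def top_token(weights, tokens):
    """Hypothetical helper: return the most-attended token and its weight."""
    i = max(range(len(weights)), key=lambda k: weights[k])
    return tokens[i], weights[i]


for name, weights in attention.items():
    word, w = top_token(weights, tokens)
    # Assumption: weights come from a softmax, so each row should sum to about 1.
    print(f"{name}: top token = {word!r} ({w:.4f}), sum = {sum(weights):.4f}")

# Per-token difference (Model A minus Model B): positive means A attends more.
diff = [round(a - b, 4) for a, b in zip(attention['Model A'], attention['Model B'])]
print(dict(zip(tokens, diff)))
```

Both models put most of their weight on "disappointment", but Model A also gives a sizeable share to "shit", which is exactly the kind of difference a heatmap makes easy to see.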
To display the difference between them as a heatmap, you can use the example code below:
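import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
sns.set()  # apply seaborn's default theme
# words of the sentence (x-axis) and one row label per model (y-axis)
data_word = ['shit', ',', 'this', 'food', 'is', 'very', 'disappointment', '.']
data_index = ['Model A', 'Model B']
# attention weights: one row per model, one column per word
data_att = [[0.3276, 0.0003, 0.0009, 0.0000, 0.0010, 0.0192, 0.6497, 0.0013],
            [0.0184, 0.0000, 0.0005, 0.0000, 0.0000, 0.0000, 0.9810, 0.0000]]
d = pd.DataFrame(data=data_att, index=data_index, columns=data_word)
f, ax = plt.subplots(figsize=(6, 2))
# a fixed [0, 1] color range makes the two rows directly comparable
sns.heatmap(d, vmin=0, vmax=1.0, ax=ax, cmap="OrRd")
# keep model names horizontal and tilt the words so they do not overlap
label_y = ax.get_yticklabels()
plt.setp(label_y, rotation=0, horizontalalignment='right')
label_x = ax.get_xticklabels()
plt.setp(label_x, rotation=45, horizontalalignment='right')
plt.show()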
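If you compare more than two models, or plot many sentences, it can help to wrap the same steps in a function. The sketch below is one way to do that, not part of the original tutorial: the name `plot_attention` and its parameters are hypothetical. It uses only standard `pandas.DataFrame`, `seaborn.heatmap` and matplotlib calls, passes `xticklabels=True` and `yticklabels=True` so every word and model name is drawn, and returns the figure instead of calling `plt.show()`, so the caller can show it or save it.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def plot_attention(tokens, weights_by_model, figsize=(6, 2), cmap="OrRd"):
    """Hypothetical helper: one heatmap row per model over the same tokens.

    tokens: list of words for the x-axis.
    weights_by_model: dict mapping a model name (y-axis row) to a list of
    attention weights, one weight per token.
    Returns the matplotlib Figure and Axes.
    """
    # every model must provide exactly one weight per token
    for name, weights in weights_by_model.items():
        if len(weights) != len(tokens):
            raise ValueError(f"{name}: {len(weights)} weights for {len(tokens)} tokens")
    d = pd.DataFrame(list(weights_by_model.values()),
                     index=list(weights_by_model.keys()), columns=tokens)
    fig, ax = plt.subplots(figsize=figsize)
    # fixed [0, 1] color range keeps rows comparable; force all tick labels on
    sns.heatmap(d, vmin=0, vmax=1.0, ax=ax, cmap=cmap,
                xticklabels=True, yticklabels=True)
    plt.setp(ax.get_yticklabels(), rotation=0, horizontalalignment='right')
    plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    fig.tight_layout()
    return fig, ax
```

For the two models in this tutorial you would call `fig, ax = plot_attention(data_word, {'Model A': data_att[0], 'Model B': data_att[1]})` and then either `plt.show()` or `fig.savefig("attention.png")`; saving to a file is handy on a server without a display.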
The result looks like this: