Understand Multi-Head Attention in Deep Learning – Deep Learning Tutorial

By | March 15, 2021

Multi-Head Attention is very popular in NLP. However, there are also some problems with it. In this tutorial, we will discuss how to implement it in TensorFlow.

Multi-Head Attention

If we plan to use 8 heads, Multi-Head Attention can be defined as:

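\(MultiHead(Q,K,V) = Concat(head_1,\dots,head_8)W^O\), where \(head_i = Attention(Q_i,K_i,V_i)\) with \(Q_i = QW_i^Q\), \(K_i = KW_i^K\) and \(V_i = VW_i^V\)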

The attention of each head is computed as:

\(Attention(Q_i,K_i,V_i) = softmax(\frac{Q_iK_i^T}{\sqrt{d}})V_i\)

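where \(d\) is the dimension of each head's query and key vectors (the width of \(Q_i\) and \(K_i\)), not the dimension of the full \(Q\), \(K\) and \(V\).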

For example, if we use 8 heads and \(Q\), \(K\) and \(V\) have dimension 512, each head has dimension 512 / 8 = 64.
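To make the formula easy to trace, here is a minimal pure-Python sketch of a single attention head, written without TensorFlow. It assumes matrices are plain lists of rows (one row per token) and takes \(d\) to be the width of the query rows; matmul, transpose, softmax_rows and attention are illustrative helper names, not TensorFlow APIs.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]

def softmax_rows(m):
    """Numerically stable softmax applied to each row independently."""
    out = []
    for row in m:
        peak = max(row)  # subtracting the max does not change the result
        exps = [math.exp(x - peak) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V, where d is the width of the query/key rows."""
    d = len(q[0])
    scores = matmul(q, transpose(k))             # [len_q, len_k]
    scaled = [[s / math.sqrt(d) for s in row] for row in scores]
    weights = softmax_rows(scaled)               # each row sums to 1
    return matmul(weights, v)                    # [len_q, width of v]

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(attention(q, k, v))
```

With \(Q = K\) equal to the 2x2 identity, each token attends most strongly to itself, so the first output row is pulled toward the first row of v and the second toward the second.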

To implement multi-head attention in TensorFlow, keep in mind that:

  • \(Q\), \(K\) and \(V\) are input tensors; they may be the same tensor (as in self-attention) or different tensors.
  • The weights \(W_i^Q\), \(W_i^K\) and \(W_i^V\) are different for each head. For example, with 8 heads there are 3 * 8 = 24 weight matrices.
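The two points above can be sketched in pure Python: every head owns a separate \(W_i^Q\), \(W_i^K\) and \(W_i^V\), so 8 heads give 24 weight matrices, and the head outputs are concatenated along the feature axis. This is a minimal sketch under stated assumptions: random Gaussian weights stand in for trained values, there are no biases, any output projection after concatenation is omitted, and make_weights and multi_head_attention are hypothetical helpers.

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def attention(q, k, v):
    """One head: softmax(Q K^T / sqrt(d)) V, with d the width of the query rows."""
    d = len(q[0])
    weights = []
    for q_row in q:
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d) for k_row in k]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return matmul(weights, v)

def make_weights(num_heads, d_model, d_head, seed=0):
    """Hypothetical initialiser: a separate (W_Q, W_K, W_V) triple for every head."""
    rng = random.Random(seed)
    def rand_matrix():
        return [[rng.gauss(0.0, 0.1) for _ in range(d_head)] for _ in range(d_model)]
    return [(rand_matrix(), rand_matrix(), rand_matrix()) for _ in range(num_heads)]

def multi_head_attention(q, k, v, weights):
    """Run each head on its own projections, then concatenate along the feature axis."""
    head_outputs = [attention(matmul(q, w_q), matmul(k, w_k), matmul(v, w_v))
                    for w_q, w_k, w_v in weights]
    # Row t of the result is head_1[t] followed by head_2[t], ..., head_n[t].
    return [sum((head[t] for head in head_outputs), []) for t in range(len(q))]

d_model, num_heads = 16, 8
d_head = d_model // num_heads  # 2 here; with 512 and 8 heads it would be 64
weights = make_weights(num_heads, d_model, d_head)
rng = random.Random(1)
x = [[rng.uniform(-1.0, 1.0) for _ in range(d_model)] for _ in range(5)]  # 5 tokens
out = multi_head_attention(x, x, x, weights)  # self-attention: Q = K = V = x
print(3 * len(weights), len(out), len(out[0]))  # 24 5 16
```

The sizes are kept small so the pure-Python loops run quickly; only d_model and num_heads need to change to reproduce the 512-dimensional, 8-head configuration.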

The structure of Multi-Head Attention is:

[Figure: The structure of Multi-Head Attention]

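Here is an example of implementing multi-head self-attention in TensorFlow 1.x. It attends over the concatenation of forward_output and backward_output, and each head uses a single projection (w_p and b_p) for its queries, keys and values instead of three separate weights.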

        #[batch_size,input_length,hidden_size*2], for example:64 * 50 * 200
        outputs = tf.concat([forward_output, backward_output], axis=2)

        # add multihead self-attention
        result_list = []
        factor = tf.sqrt(tf.constant(self.hidden_size,dtype = tf.float32)) # d in multi-head attention
        for k in range(self.head_num):
            tmp_str = 'head_' + str(k+1)
            with tf.variable_scope(tmp_str): # create weights for each head
                # One projection per head, shared by Q, K and V (self-attention)
                w_p = tf.Variable(tf.truncated_normal([self.hidden_size*2, self.hidden_size], stddev=0.1),name = 'w_p')
                b_p = tf.Variable(tf.zeros(self.hidden_size),name = 'b_p')

                # During training, we should calculate the attention for each sample in the batch
                ind = tf.constant(0)
                output_ta = tf.TensorArray(dtype=tf.float32, size=self.vary_batch_size)

                def cond(ind,output_ta):
                    return ind < self.vary_batch_size

                def body(ind,output_ta):
                    #[input_length,hidden_size*2]
                    single  = outputs[ind,:,:]
                    #[input_length,hidden_size]
                    single = tf.matmul(single,w_p) + b_p
                    #[input_length,input_length]
                    #soft_out = tf.nn.softmax( tf.matmul(a = single,b = single,transpose_b=True) / factor, axis = 1 )
                    soft_out = tf.nn.softmax( tf.matmul(a = single,b = single,transpose_b=True) / factor, dim = 1 ) # dim for tf 1.3.0
                    #[input_length,hidden_size]
                    att_out = tf.matmul(soft_out,single)
                    output_ta = output_ta.write(ind,att_out)

                    # increment
                    ind = ind + 1

                    return ind,output_ta

                _,final_output_ta = tf.while_loop(cond,body,[ind,output_ta])
                #[batch_size,input_length,hidden_size]
                single_output = final_output_ta.stack()
            result_list.append(single_output)

        #[batch_size,input_length,hidden_size * head_num]
        new_outputs = tf.concat(result_list,axis = 2)
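To check the shapes without a TensorFlow 1.x session, the pure-Python sketch below mirrors what the snippet computes. It assumes outputs is already a nested list of shape [batch_size][input_length][hidden_size*2] and, like the snippet, reuses one (w_p, b_p) pair per head for queries, keys and values. reference_multihead is a hypothetical helper, and the sizes are tiny so it runs quickly.

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax_rows(m):
    """Softmax over each row, like tf.nn.softmax(..., axis=1) on a 2-D tensor."""
    out = []
    for row in m:
        peak = max(row)
        exps = [math.exp(x - peak) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def reference_multihead(outputs, head_params, hidden_size):
    """Per head: project with (w_p, b_p), then self-attend each sample; concat heads.

    outputs:     [batch_size][input_length][hidden_size * 2]
    head_params: list of (w_p, b_p); w_p is [hidden_size * 2][hidden_size]
    returns:     [batch_size][input_length][hidden_size * len(head_params)]
    """
    factor = math.sqrt(hidden_size)  # same scaling constant as factor in the snippet
    result_list = []
    for w_p, b_p in head_params:
        per_sample = []
        for single in outputs:  # plays the role of the tf.while_loop over the batch
            proj = [[x + b for x, b in zip(row, b_p)] for row in matmul(single, w_p)]
            scores = matmul(proj, [list(c) for c in zip(*proj)])  # [len, len]
            soft_out = softmax_rows([[s / factor for s in row] for row in scores])
            per_sample.append(matmul(soft_out, proj))  # [len, hidden_size]
        result_list.append(per_sample)
    # Concatenate heads along the last axis, like tf.concat(result_list, axis=2).
    return [[sum((head[b][t] for head in result_list), [])
             for t in range(len(outputs[b]))]
            for b in range(len(outputs))]

rng = random.Random(0)
batch_size, input_length, hidden_size, head_num = 2, 3, 4, 2
outputs = [[[rng.uniform(-1, 1) for _ in range(hidden_size * 2)]
            for _ in range(input_length)] for _ in range(batch_size)]
head_params = [([[rng.gauss(0, 0.1) for _ in range(hidden_size)] for _ in range(hidden_size * 2)],
                [0.0] * hidden_size) for _ in range(head_num)]
new_outputs = reference_multihead(outputs, head_params, hidden_size)
print(len(new_outputs), len(new_outputs[0]), len(new_outputs[0][0]))  # 2 3 8
```

Because tf.matmul also accepts batched inputs of rank 3 or higher, the per-sample tf.while_loop in the snippet could probably be replaced by batched matrix products; this sketch keeps the per-sample loop so it follows the original structure step by step.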
