import matplotlib.pyplot as plt
import tensorflow as tf
def plot_attention_weights(attention, key, query):
    """Visualise per-head attention weights from a Transformer.

    Draws one ``matshow`` panel per attention head, laid out in a
    two-column grid, with *key* entries as x-tick labels and *query*
    entries as y-tick labels. Shows the figure via ``plt.show()``.

    Parameters
    ----------
    attention : tensor
        Attention weights of shape
        (1, number of heads, length of key, length of query)
        -- NOTE(review): axis order assumed from the original docstring;
        confirm against the attention computation.
    key : sequence
        Values shown as x-tick labels, one per key position.
    query : sequence
        Values shown as y-tick labels, one per query position.
    """
    fig = plt.figure(figsize=(16, 8))

    # Drop the leading batch dimension -> (heads, len_key, len_query).
    attention = tf.squeeze(attention, axis=0)

    num_heads = attention.shape[0]
    # Ceil division so an odd head count still gets enough rows;
    # the original `num_heads // 2` made add_subplot fail for odd counts.
    n_rows = (num_heads + 1) // 2

    for head in range(num_heads):
        ax = fig.add_subplot(n_rows, 2, head + 1)
        ax.matshow(attention[head], cmap='viridis')

        fontdict = {'fontsize': 12}
        ax.set_xticks(range(len(key)))
        ax.set_yticks(range(len(query)))

        # Rotate x labels so long key tokens do not overlap.
        ax.set_xticklabels([str(i) for i in key],
                           fontdict=fontdict, rotation=90)
        ax.set_yticklabels([str(i) for i in query], fontdict=fontdict)

        ax.set_xlabel('Head {}'.format(head + 1), fontdict=fontdict)

    plt.tight_layout()
    plt.show()