We make up a toy problem to gain intuition about attention heads¶

Let's use the following list of words (plus a special token): <ST>, the, man, chicken, ordered, woman, beef. That is 7 tokens, so V=7, and the token ids run from 1 to 7.

This is our corpus:

man ordered the chicken

woman ordered the beef

Exercise Goal: Examine the Attention Wt matrix and the Output predictions (see the end of the notebook) to see how dependencies might be encoded.

Look at the output prediction of "the -> chicken" versus "the -> beef". Look at the 4th row (because 'the' is the 4th input token) of the TxT attention weights. Look at the Q, K matrices to see which values differ (are asymmetric) between the two input cases.
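A quick, self-contained sketch of the id-to-word mapping (it mirrors the colnames list defined below, which abbreviates some words; ids are 1-based, so Python indexing subtracts 1):

In [ ]:
# Hedged sketch: decode the two training sequences back to words
vocab = ["<ST>", "the", "man", "chkn", "ordrd", "woman", "beef"]   # token ids 1..7
for seq in ([1, 3, 5, 2, 4], [1, 6, 5, 2, 7]):
    print(' '.join(vocab[tok - 1] for tok in seq))
# expected: "<ST> man ordrd the chkn" and "<ST> woman ordrd the beef"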

In [1]:
# ----------- IMPORT STATEMENTS ---------------
import argparse
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
import os
import sys
import numpy as np
import time

#---------------------------------------------
print('import done')
#import pandas as pd
np.set_printoptions(precision=4)

#set the embedding/attention head size parameters
E = 7  #size of embedding layer, which in this notebook is same as V, size of vocabulary

H = 20  #size of attention head    #<<<<<<<<<---- try 10,20,40 ---
                                     # How does H parameter affect the Attent. Wts?  hint: look at model summary output
import done
In [2]:
epochs         = 5000
lrate          = 0.0001
torch.manual_seed(77)
Out[2]:
<torch._C.Generator at 0x155408249c10>

Set up sequence input¶

In [3]:
# Set up sequence input
if 1:
    colnames  =  ["<ST>", "the", "man","chkn","ordrd","woman","beef"]
    V    = len(colnames)
    #------ make a sequence of length T --------------
    sequence2use   = np.asarray([[1,3,5,2,4],[1,6,5,2,7]])  #<ST> (id 1) is the first token of each sequence; ids are 1-based
    B,T            = sequence2use.shape
    sequence2pred  = np.zeros((B,T),dtype=int)
    for bi in range(B):
      sequence2pred[bi,0:-1] = sequence2use[bi,1:]
      sequence2pred[bi,-1]   = sequence2use[bi,0]

#set up row labels (position + word) for nice printouts later
rownamesxb = []
for bi in range(B):
    rownames=list()
    for i in range(T):
        rownames.append(str(i)+' '+colnames[sequence2use[bi,i]-1])
    rownamesxb.append(rownames)

print('--- Input Sequences (start at 1..T): -----')
print(sequence2use)
print('--- Target Sequences: -----')
print(sequence2pred)
print(sequence2pred.shape)
--- Input Sequences (start at 1..T): -----
[[1 3 5 2 4]
 [1 6 5 2 7]]
--- Target Sequences: -----
[[3 5 2 4 1]
 [6 5 2 7 1]]
(2, 5)

Now set up training data (X input and Y targets, and also positions)¶

Training data are sequences of token ids, shifted to start at 0 so they can be used directly as indices.
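For one sequence, the shift and the next-token targets amount to the following (a minimal sketch; the cell below does the same thing with explicit loops over B and T):

In [ ]:
# Hedged sketch: 0-based inputs and next-token targets for the first sequence
import numpy as np   # already imported above; repeated so this cell stands alone
seq = np.array([1, 3, 5, 2, 4])   # <ST> man ordrd the chkn (1-based ids)
x   = seq - 1                     # [0 2 4 1 3]  input ids, now 0-based
y   = np.roll(x, -1)              # [2 4 1 3 0]  targets: the next token, wrapping back to <ST>
print(x, y)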

In [4]:
#Now set up training data.  Batch size is just 1 sequence at a time;
#B is 2 (because there are 2 sequences in the training data).
#Token sequences are id numbers; each id indicates which unit would be
#on in a one-hot vector.

#Xtrain holds the input sequences of ids (shifted to start at 0)
#Xtrain  = np.zeros((B,T,1)) #sequence2use.copy()
Xtrain  = np.zeros((B,T)) #sequence2use.copy()
for bi in range(B):
    for ti in range(T):
       Xtrain[bi,ti]=sequence2use[bi,ti]-1  #Xtrain_ids[bi,ti]-1  #index starts at 0 so subtract1

#make position information same size as X train
#Postrain=np.zeros((B,T,1))
Postrain=np.zeros((B,T))
for bi in range(B):
  #Postrain[bi,:,0] = np.arange(T)
  Postrain[bi,:] = np.arange(T)  #positions are the integers 0..T-1
print(Postrain.shape)

# make target values as the index
Ytrain=sequence2pred.copy()
for bi in range(B):
   for ti in range(T):
      Ytrain[bi,ti]=Ytrain[bi,ti]-1  #index starts at 0 so subtract1
#Ytrain = tf.expand_dims(Ytrain,axis=2)
#Ytrain = Ytrain[:,:,np.newaxis] #(Ytrain,axis=2)
print(Ytrain.shape)
print(Ytrain)

#make a one-hot vector for each id (unused below; kept for experiments)
def make1hot(M):
  Xtmp = np.zeros((B,T,V))
  for i in range(B):
    for j in range(T):
      Xtmp[i,j,int(M[i,j])] = 1
  return Xtmp
if 0:
  Xtrain=make1hot(Xtrain)
  Ytrain=make1hot(Ytrain)
(2, 5)
(2, 5)
[[2 4 1 3 0]
 [5 4 1 6 0]]
In [5]:
Postrain
#Xtrain[0,]
Out[5]:
array([[0., 1., 2., 3., 4.],
       [0., 1., 2., 3., 4.]])

Now set up scaled-dot product constants and attention mask¶
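For reference, the attention computed by the model below is the usual scaled dot-product form, where $H$ is the Q/K feature dimension and $M$ is the causal mask (zero on and below the diagonal, $-\infty$ above it):

$$\mathrm{AttnWts} = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{H}} + M\right), \qquad \mathrm{Vout} = \mathrm{AttnWts}\,V$$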

In [6]:
#Now set up model related values
scale_value    = np.divide(1,np.sqrt(H)) #use H b/c it's dimension of Qmat, Kmat

#Make a mask
Mskl=torch.tril(torch.ones(T,T))
Msklbool=Mskl.bool()
print('scaling setup, causal Mask set up')
scaling setup, causal Mask set up
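If you want to see the causal mask itself, a quick check (for T=5 it is just a lower-triangular matrix of ones; a 1 in row t means that key position is visible to query position t):

In [ ]:
# Hedged sketch: inspect the causal mask (relies on torch and T from the cells above)
print(torch.tril(torch.ones(T, T)))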
In [ ]:
 

Now build model layers that will learn the transformations for the Q, K, V matrices¶

In [18]:
#Now build model to learn transformation for Q,K,V matrices
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.Xembed    =  torch.nn.Embedding(V,E)
        self.Pos_Input =  torch.arange(T) #position indices are just 0..T-1
        self.Pos_Embed =  torch.nn.Embedding(T,E)

        #now feed to Q,K,V transformations
        #<<<<------------------- H embedding size is used here Q,K
        self.Qmat       = torch.nn.Linear(E,H,bias=False)
        self.Kmat       = torch.nn.Linear(E,H,bias=False)
        self.Vmat       = torch.nn.Linear(E,V,bias=False)
        self.Smax       = torch.nn.Softmax(dim=-1) #apply softmax functional
        self.Smax2      = torch.nn.Softmax(dim=-1) #apply softmax functional
        self.Sigout     = torch.nn.Sigmoid()     #final output between 0,1
    def forward(self, x):
        xembed   = self.Xembed(x)
        posembed = self.Pos_Embed(self.Pos_Input.to(device))
        Xinputs  = xembed+posembed       #<<<--- comment out +Posembed see if it learns
        #print(' Xinptsshp:',Xinputs.shape)

        Q  = self.Qmat(Xinputs)
        K  = self.Kmat(Xinputs)
        V  = self.Vmat(Xinputs)
        QK = Q @ K.transpose(-2, -1) * scale_value #scaled dot product (matrix multiplication)
        QKmasked = QK.masked_fill(~Msklbool.to(device), float('-inf')) # mask out future (upper-triangular) positions with -inf

        Attn_Wts_smx = self.Smax(QKmasked)                   
        Attn_Wts     = torch.mul(Attn_Wts_smx,Mskl.to(device))         #element-wise multiply by the mask (re-zeroes the masked positions)
        Vout         = Attn_Wts @ V
        
        #<<< A full transformer would have an MLP block here before the final softmax; we just apply the softmax directly
        if 1: #use the skip connection (attention output + input/position embeddings)
          VoutandXinput  = Vout + Xinputs #skip connection (residual add)
          Prob_output    = self.Smax2(VoutandXinput)
        if 0: #experiment: no skip connection, attention output only
          Prob_output    = self.Smax2(Vout)
        return Prob_output 
        
In [ ]:
 
In [19]:
# --------------------------------------------------------
#   Define training function
# --------------------------------------------------------

def train(model, device, optimizer, epoch):
    ''' This is called once per epoch.
        Arguments: the model, the device to run on, the optimizer, and the current epoch
    '''
    model.train()
    totloss=0
    for i in range(2):
        data,target = Xtrain[i],Ytrain[i]
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()                 #reset optimizer state
        output = model(data)                  #get predictions
        
        #the model output is already a softmax, so take its log before nll_loss
        loss = F.nll_loss(torch.log(output), target)     #get loss (nll_loss expects log-probabilities)

        #if the model returned raw logits instead, cross entropy would apply log-softmax internally:
        #loss = F.cross_entropy(output, target)
        loss.backward()                       #backprop loss
        optimizer.step()                      #update weights
    totloss += loss   #note: this sits outside the loop above, so it only reflects the last sequence's loss
    if (epoch % 100 ==0):
          print('.... loss:',totloss)
        
def get_activation(name, activation):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
    
print('Train,test, support functions defined ')
Train,test, support functions defined 
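The commented-out alternative in the training function is equivalent: nll_loss on log(softmax(z)) matches cross_entropy on the raw scores z. A minimal sketch with made-up tensors, just to illustrate the equivalence:

In [ ]:
# Hedged sketch: F.nll_loss(log(softmax(z))) == F.cross_entropy(z)
z = torch.randn(5, 7)                 # T x V made-up scores
t = torch.tensor([2, 4, 1, 3, 0])     # targets for one sequence
loss_a = F.nll_loss(torch.log(F.softmax(z, dim=-1)), t)
loss_b = F.cross_entropy(z, t)
print(loss_a.item(), loss_b.item())   # should agree up to floating-point error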
In [20]:
# -------------------------------------------------
#  Get device  
#  (note, this is set up for 1 GPU device
#    if this were to run on a full GPU node with >1 gpu device, you would
#     want to get rank, world size info and set device id 
#     as in:   torch.cuda.set_device(local_rank) 
#     and then also run distributed initialization )
# -------------------------------------------------
use_cuda = torch.cuda.is_available() 
if use_cuda:
        num_gpu = torch.cuda.device_count()
        print('INFO,  cuda, num gpu:',num_gpu)
        device     = torch.cuda.current_device()
        print('environ visdevs:',os.environ["CUDA_VISIBLE_DEVICES"])
else:
        num_gpu = 0
        print('INFO, cuda not available')
        device  = torch.device("cpu")   
print('INFO, device is:', device)
INFO,  cuda, num gpu: 1
environ visdevs: 0
INFO, device is: 0
In [21]:
mymodel = MyNet().to(device)
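The comment near the top of the notebook hints at looking at a model summary to see how H affects things; a minimal sketch that just lists the parameter shapes (an inspection aid, not part of training):

In [ ]:
# Hedged sketch: parameter shapes, so the effect of H on Qmat/Kmat is visible
for name, p in mymodel.named_parameters():
    print(f'{name:20s} {tuple(p.shape)}')
# Qmat.weight and Kmat.weight should be (H, E); Vmat.weight is (V, E);
# the token and position embeddings are (V, E) and (T, E)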
In [22]:
# -------------------------------------------
#  Do training loop
# -------------------------------------------

# Dictionary to store activations
activations = {}
# Register hooks
mymodel.Qmat.register_forward_hook(get_activation('Qmat', activations))
mymodel.Kmat.register_forward_hook(get_activation('Kmat', activations))
mymodel.Vmat.register_forward_hook(get_activation('Vmat', activations))
mymodel.Smax.register_forward_hook(get_activation('Smax', activations))
mymodel.Smax2.register_forward_hook(get_activation('Smax2', activations))

optimizer = torch.optim.Adam(mymodel.parameters(), lr=lrate)

Xtrain=torch.tensor(Xtrain,dtype=int).to(device)
Ytrain=torch.tensor(Ytrain,dtype=int).to(device)
train_results = []
test_results  = []
for epoch in range(epochs):
        if (epoch % 100==0):
          print('INFO about to train epoch:',epoch)
        train(mymodel, device, optimizer, epoch)

print('INFO  done');
/scratch/etrain107/job_40600820/ipykernel_1568145/44505167.py:16: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  Xtrain=torch.tensor(Xtrain,dtype=int).to(device)
/scratch/etrain107/job_40600820/ipykernel_1568145/44505167.py:17: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  Ytrain=torch.tensor(Ytrain,dtype=int).to(device)
INFO about to train epoch: 0
.... loss: tensor(3.1632, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 100
.... loss: tensor(3.0198, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 200
.... loss: tensor(2.8617, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 300
.... loss: tensor(2.6861, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 400
.... loss: tensor(2.5046, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 500
.... loss: tensor(2.3337, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 600
.... loss: tensor(2.1789, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 700
.... loss: tensor(2.0397, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 800
.... loss: tensor(1.9144, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 900
.... loss: tensor(1.8008, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1000
.... loss: tensor(1.6969, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1100
.... loss: tensor(1.6009, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1200
.... loss: tensor(1.5115, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1300
.... loss: tensor(1.4276, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1400
.... loss: tensor(1.3493, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1500
.... loss: tensor(1.2769, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1600
.... loss: tensor(1.2101, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1700
.... loss: tensor(1.1482, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1800
.... loss: tensor(1.0905, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 1900
.... loss: tensor(1.0364, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2000
.... loss: tensor(0.9856, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2100
.... loss: tensor(0.9374, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2200
.... loss: tensor(0.8917, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2300
.... loss: tensor(0.8482, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2400
.... loss: tensor(0.8067, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2500
.... loss: tensor(0.7672, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2600
.... loss: tensor(0.7296, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2700
.... loss: tensor(0.6938, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2800
.... loss: tensor(0.6597, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 2900
.... loss: tensor(0.6274, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3000
.... loss: tensor(0.5966, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3100
.... loss: tensor(0.5674, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3200
.... loss: tensor(0.5397, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3300
.... loss: tensor(0.5133, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3400
.... loss: tensor(0.4883, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3500
.... loss: tensor(0.4646, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3600
.... loss: tensor(0.4421, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3700
.... loss: tensor(0.4209, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3800
.... loss: tensor(0.4010, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 3900
.... loss: tensor(0.3824, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4000
.... loss: tensor(0.3651, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4100
.... loss: tensor(0.3491, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4200
.... loss: tensor(0.3342, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4300
.... loss: tensor(0.3205, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4400
.... loss: tensor(0.3080, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4500
.... loss: tensor(0.2964, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4600
.... loss: tensor(0.2857, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4700
.... loss: tensor(0.2759, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4800
.... loss: tensor(0.2669, device='cuda:0', grad_fn=<AddBackward0>)
INFO about to train epoch: 4900
.... loss: tensor(0.2586, device='cuda:0', grad_fn=<AddBackward0>)
INFO  done
In [23]:
with torch.no_grad():
    PredOutput=[]; Attn_Wts=[]; Vmat=[]
    for i in range(B):
        data, target = Xtrain[i], Ytrain[i]
        data, target = data.to(device), target.to(device)
        output = mymodel(data)                  #get predictions (the hooks capture the activations)
        #Q = np.squeeze(activations['Qmat'].detach().cpu())
        #K = np.squeeze(activations['Kmat'].detach().cpu())
        Vmat.append(np.squeeze(activations['Vmat'].detach().cpu()))
        Attn_Wts_smx = np.squeeze(activations['Smax'].detach().cpu())
        Attn_Wts.append(torch.mul(Attn_Wts_smx, Mskl))
        PredOutput.append(np.squeeze(activations['Smax2'].detach().cpu()))
In [24]:
# -------------------------------------------

Now we want to examine the Output predictions and the Attention Wt matrix to see how they produce the 'the' predictions for each input sequence¶
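Before plotting the heatmaps, a minimal numeric sketch of the comparison described in the exercise goal: print row 3 (0-based; 'the' is the 4th input token) of the attention weights and the output prediction at that position for both sequences, using the lists collected above.

In [ ]:
# Hedged sketch: compare the 'the' row (index 3) across the two input sequences
for bi in range(B):
    print('sequence', bi, ':', ' | '.join(rownamesxb[bi]))
    attn_row = np.asarray(Attn_Wts[bi][3])      # attention weights from 'the' to earlier tokens
    print('  attn from "the"   :', np.round(attn_row, 3))
    pred_row = np.asarray(PredOutput[bi][3])    # output distribution at the 'the' position
    print('  predicted next    :', colnames[int(pred_row.argmax())])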

In [14]:
# Helper functions to plot heat map
import matplotlib.pyplot as plt
import matplotlib
import matplotlib as mpl
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    if ax is None:
        ax = plt.gca()
    if cbar_kw is None:
        cbar_kw = {}

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    #cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    #cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and create white grid.
    ax.spines[:].set_visible(False)

    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im #, cbar


def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts
In [ ]:
 
In [15]:
#OUTPUT predictions collected from the Smax2 hook
Final_Pout_pred = PredOutput

for bi in range(B):
  print(' ---------- Output Predictions TxV (t-th row are predictions at time t) --')
  fig, ax = plt.subplots(figsize=(5,3))
  im = heatmap(Final_Pout_pred[bi],rownamesxb[bi], colnames,ax=ax,cmap="YlGn")
  texts = annotate_heatmap(im, valfmt="{x:.3f}")
  fig.tight_layout()
  plt.show()
 ---------- Output Predictions TxV (t-th row are predictions at time t) --
[heatmap image: output predictions for sequence 0]
 ---------- Output Predictions TxV (t-th row are predictions at time t) --
[heatmap image: output predictions for sequence 1]
In [ ]:
 
In [16]:
#  <<<<<<< ------------ Can you interpret how the Attn Wts are picking out the predictions?
AttnW_output = Attn_Wts

print('---  Note, the head size H was: ',H)
for bi in [0,1]:  #bi=0  plt.figure(figsize=(10,6)
  print(' ----------- Attention Wts TxT --------------')
  fig, ax = plt.subplots(figsize=(5,4))
  im = heatmap(AttnW_output[bi],rownamesxb[bi], rownamesxb[bi],ax=ax,cmap="YlGn", cbarlabel="attn wt")
  #im, cbar = heatmap(AttnW_output[bi],rownamesxb[bi], rownamesxb[bi],ax=ax,cmap="YlGn", cbarlabel="attn wt")
  texts = annotate_heatmap(im, valfmt="{x:.2f}")
  fig.tight_layout()
  plt.show()
---  Note, the head size H was:  20
 ----------- Attention Wts TxT --------------
[heatmap image: attention weights for sequence 0]
 ----------- Attention Wts TxT --------------
[heatmap image: attention weights for sequence 1]
In [17]:
# Here are the value matrices (the 'V' outputs of Vmat) for each sequence
Vmat_values = Vmat

print('---  Note, the head size H was: ',H)
for bi in [0,1]:
  print(' ----------- Value matrix TxV--------------')
  fig, ax = plt.subplots(figsize=(5,4))
  im = heatmap(Vmat_values[bi],rownamesxb[bi], colnames,ax=ax,cmap="YlGn", cbarlabel="Value matrix")
  texts = annotate_heatmap(im, valfmt="{x:.2f}")
  fig.tight_layout()
  plt.show()
---  Note, the head size H was:  20
 ----------- Value matrix TxV--------------
[heatmap image: value matrix for sequence 0]
 ----------- Value matrix TxV--------------
[heatmap image: value matrix for sequence 1]