Greetings all,
I’m having some trouble with my code. I’m building a transformer, see, and I’ve got some problems at the inference stage. Here’s my code:
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size=embed_size
self.heads=heads
self.head_dim=embed_size // heads
assert (self.head_dim*heads==embed_size)
self.values=nn.Linear(self.embed_size, self.embed_size, bias=False)
self.keys=nn.Linear(self.embed_size, self.embed_size, bias=False)
self.queries=nn.Linear(self.embed_size, self.embed_size, bias=False)
self.fc_out=nn.Linear(heads*self.head_dim, embed_size)
def forward(self, values, keys, query, mask):
N=query.shape[0]
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
values=self.values(values)
values= values.reshape(N, value_len, self.heads, self.head_dim)
keys=self.keys(keys)
keys= keys.reshape(N, key_len, self.heads, self.head_dim)
queries=self.queries(query)
queries=queries.reshape(N, query_len, self.heads, self.head_dim)
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
#energy shape: (N, heads, query_len, key_len)
if mask is not None:
energy=energy.masked_fill(mask==0, float("-1e20"))
attention=torch.softmax(energy/(self.embed_size**(1/2)), dim=3)
out=torch.einsum("nhql,nlhd->nqhd", [attention,values]).reshape(N, query_len, self.heads*self.head_dim)
#out shape= (N, query_len, heads, head_dim)
out=self.fc_out(out)
return out
class TransformerBlock(nn.Module):
def __init__(self, embed_size, heads, dropout, forward_expansion):
super(TransformerBlock, self).__init__()
self.attention=SelfAttention(embed_size, heads)
self.norm1=nn.LayerNorm(embed_size)
self.norm2=nn.LayerNorm(embed_size)
self.feed_forward=nn.Sequential(
nn.Linear(embed_size, forward_expansion*embed_size),
nn.ReLU(),
nn.Linear(forward_expansion*embed_size, embed_size)
)
self.dropout=nn.Dropout(dropout)
def forward(self, value, key, query, mask):
attention=self.attention(value, key, query, mask)
x=self.norm1(query+self.dropout(attention))
forward=self.feed_forward(x)
out=self.norm2(x+self.dropout(forward))
return out
class Encoder(nn.Module):
def __init__(
self,
src_vocab_size,
embed_size,
num_layers,
heads,
device,
forward_expansion,
dropout,
max_length
):
super(Encoder, self).__init__()
self.embed_size=embed_size
self.device=device
self.word_embedding=nn.Embedding(src_vocab_size, embed_size) # Embeds src to size vocab and embed
self.position_embedding=nn.Embedding(max_length, embed_size) #
self.layers = nn.ModuleList(
[
TransformerBlock(
embed_size,
heads,
dropout=dropout,
forward_expansion=forward_expansion
)
for _ in range(num_layers)]
)
self.dropout=nn.Dropout(dropout)
def forward(self, x, mask):
N, seq_length=x.shape # N is the batch size, seq_length is the size of n-gram
positions=torch.arange(0, seq_length).expand(N, seq_length).to(self.device) #positions
out= self.dropout((self.word_embedding(x)+self.position_embedding(positions))) #dropout of word and position embedder
for layer in self.layers:
out=layer(out, out, out, mask)
return out
class DecoderBlock(nn.Module):
def __init__(self, embed_size, heads, forward_expansion, dropout, device):
super(DecoderBlock, self).__init__()
self.norm = nn.LayerNorm(embed_size)
self.attention= SelfAttention(embed_size, heads)
self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
self.dropout=nn.Dropout(dropout)
def forward(self, x, value, key, src_mask, trg_mask):
attention=self.attention(x,x,x,trg_mask)
query=self.norm(x+self.dropout(attention))
out=self.transformer_block(value, key, query, src_mask)
return out
class Decoder(nn.Module):
def __init__(
self,
trg_vocab_size,
embed_size,
num_layers,
heads,
forward_expansion,
dropout,
device,
max_length
):
super(Decoder, self).__init__()
self.device = device
self.word_embedding=nn.Embedding(trg_vocab_size, embed_size)
self.position_embedding=nn.Embedding(max_length, embed_size)
self.layers=nn.ModuleList(
[DecoderBlock(embed_size, heads, forward_expansion, dropout, device) for _ in range(num_layers)]
)
self.fc_out=nn.Linear(embed_size, trg_vocab_size)
self.dropout=nn.Dropout(dropout)
def forward(self, x, enc_out, src_mask, trg_mask):
N, seq_length=x.shape
positions=torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
x=self.dropout((self.word_embedding(x)+self.position_embedding(positions)))
for layer in self.layers:
x= layer(x, enc_out, enc_out, src_mask, trg_mask)
out=self.fc_out(x)
return out
class Transformer(nn.Module):
def __init__(
self,
src_vocab_size,
trg_vocab_size,
src_pad_idx,
trg_pad_idx,
embed_size=128,
num_layers=3,
forward_expansion=2,
heads=4,
dropout=0.05,
device="cpu",
max_length=6
):
super(Transformer, self).__init__()
self.encoder = Encoder(
src_vocab_size,
embed_size,
num_layers,
heads,
device,
forward_expansion,
dropout,
max_length
)
self.decoder = Decoder(
trg_vocab_size,
embed_size,
num_layers,
heads,
forward_expansion,
dropout,
device,
max_length,
)
self.src_pad_idx=src_pad_idx
self.trg_pad_idx= trg_pad_idx
self.device=device
self.sm_out=nn.Softmax(dim=2)
def make_src_mask(self, src):
src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
#N,1,1,src_len
return src_mask.to(self.device)
def make_trg_mask(self, trg):
N, trg_len=trg.shape
trg_mask=torch.tril(torch.ones((trg_len, trg_len))).expand(
N, 1, trg_len, trg_len
)
return trg_mask.to(self.device)
def forward(self, src, trg):
src_mask = self.make_src_mask(src)
trg_mask=self.make_trg_mask(trg)
enc_src= self.encoder(src, src_mask)
out=self.decoder(trg, enc_src, src_mask, trg_mask)
out=self.sm_out(out)
return out
Then I have a batching function and random input generation.
def get_batches(arr_x, arr_y, batch_size):
# iterate through the arrays
prv = 0
for n in range(batch_size, arr_x.shape[0]+1, batch_size):
x = arr_x[prv:n,:]
y = arr_y[prv:n,:]
prv = n
yield x, y
import torch.optim as optim
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
src2=np.random.randint(8,size=(128,6))+1
trg2=src2+5
print(trg2) #outputs a 128,6 array of random ints 5 more than src2
Training for 100 epochs, lr is high so it catches on pretty quickly
model=Transformer(src_vocab_size=14, trg_vocab_size=14, src_pad_idx=0, trg_pad_idx=0).to(device)
def train_model(src, trg, epochs=101, batch_size=32, classes=14):
print(model)
opt=torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn=nn.functional.cross_entropy
total_loss=0
train_loss_list, validation_loss_list = [], []
total_tokens=0
for e in range(epochs):
model.train()
loss=0
if e%50==0:
print("-"*25, f"Epoch {e + 1}","-"*25)
for x, y in get_batches(src, trg, batch_size):
x, y = torch.from_numpy(x.astype('int64')).to(device), torch.from_numpy(y.astype('int64')).to(device)
pred=model(x, y)
if e%50==0:
print(pred.shape)
print(torch.argmax(pred,dim=2)[0])
print(y[0])
loss=loss_fn(pred, torch.nn.functional.one_hot(y, num_classes=classes).type(torch.FloatTensor))
total_tokens+=batch_size
opt.zero_grad()
loss.backward()
opt.step()
total_loss += loss.detach().item()
train_loss_list += [total_loss/total_tokens]
if e%50==0: print(f"Training loss: {loss/batch_size:.4f}, Total loss: {total_loss:.4f}")
print('...')
return train_loss_list
train_loss_list = train_model(src=src2,trg=trg2)
Now for my woes. When I try to infer with the model on a simple input, I get garbage output.
def greedy_decode(model, src, max_len):
ys = torch.zeros(1, 6).type_as(src)
print(ys)
for i in range(max_len):
prob=model(src, ys)
print(prob)
next_word = torch.argmax(prob, dim = 2)
print(next_word)
next_word = next_word.data[0]
print(next_word.data[i])
ys[0,i] += next_word.data[i]
return ys
model.eval()
src = torch.LongTensor([[2,3,4,5,6,7]])
print(greedy_decode(model, src, max_len=6))
tensor([[0, 0, 0, 0, 0, 0]])
tensor([[[0.0288, 0.0460, 0.0148, 0.0124, 0.0357, 0.0194, 0.0374, 0.0296,
0.1791, 0.1655, 0.1428, 0.0337, 0.0816, 0.1730],
[0.0349, 0.0780, 0.0151, 0.0332, 0.0361, 0.0238, 0.0523, 0.0159,
0.0896, 0.0648, 0.1253, 0.0349, 0.2191, 0.1770],
[0.0480, 0.0754, 0.0158, 0.0214, 0.0664, 0.0754, 0.0651, 0.0177,
0.2211, 0.0690, 0.0944, 0.0270, 0.1367, 0.0665],
[0.0338, 0.0433, 0.0159, 0.0246, 0.0476, 0.0320, 0.0393, 0.0061,
0.1299, 0.1012, 0.1060, 0.0221, 0.1791, 0.2190],
[0.0320, 0.0527, 0.0133, 0.0271, 0.0659, 0.0392, 0.0235, 0.0104,
0.0354, 0.1256, 0.0601, 0.0214, 0.1289, 0.3646],
[0.0267, 0.0525, 0.0186, 0.0310, 0.0473, 0.0292, 0.0658, 0.0112,
0.1424, 0.1685, 0.1065, 0.0492, 0.0946, 0.1564]]],
grad_fn=<SoftmaxBackward0>)
tensor([[ 8, 12, 8, 13, 13, 9]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
[3.0471e-02, 1.0963e-01, 2.6189e-02, 2.8180e-02, 4.6828e-02,
2.7923e-02, 4.8535e-02, 1.7435e-02, 2.1706e-01, 5.1900e-02,
1.0358e-01, 3.3658e-02, 1.9026e-01, 6.8364e-02],
[3.7262e-02, 7.9151e-02, 1.8621e-02, 1.5883e-02, 6.6042e-02,
7.3598e-02, 4.7058e-02, 1.4851e-02, 3.6171e-01, 5.1541e-02,
7.0141e-02, 2.7503e-02, 1.0729e-01, 2.9345e-02],
[3.1110e-02, 5.1011e-02, 1.9348e-02, 2.1894e-02, 5.1082e-02,
3.4915e-02, 3.4925e-02, 5.6201e-03, 2.1383e-01, 8.9307e-02,
9.1074e-02, 2.3128e-02, 1.9180e-01, 1.4095e-01],
[3.1558e-02, 6.9342e-02, 1.7462e-02, 2.6760e-02, 7.8564e-02,
4.4574e-02, 2.2642e-02, 1.0699e-02, 5.6768e-02, 1.4076e-01,
5.7005e-02, 2.4431e-02, 1.3645e-01, 2.8298e-01],
[2.3861e-02, 6.0398e-02, 2.2116e-02, 2.6945e-02, 4.8900e-02,
2.9926e-02, 6.3012e-02, 9.8275e-03, 2.0789e-01, 1.6895e-01,
9.8798e-02, 4.9619e-02, 8.3311e-02, 1.0644e-01]]],
grad_fn=<SoftmaxBackward0>)
tensor([[ 8, 8, 8, 8, 13, 8]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
[5.0727e-04, 7.7619e-04, 4.7299e-04, 1.5359e-04, 2.1579e-04,
1.5301e-04, 2.3866e-04, 4.8841e-04, 9.9550e-01, 1.5122e-04,
1.5233e-04, 4.4606e-04, 6.0918e-04, 1.3881e-04],
[2.4091e-02, 6.6588e-02, 2.0767e-02, 1.1727e-02, 5.5241e-02,
6.0745e-02, 3.3838e-02, 1.2357e-02, 5.2257e-01, 3.3011e-02,
4.7262e-02, 2.3244e-02, 7.4178e-02, 1.4385e-02],
[2.5230e-02, 5.1829e-02, 2.1725e-02, 1.9427e-02, 4.9629e-02,
3.4386e-02, 3.0234e-02, 5.2530e-03, 3.1320e-01, 7.1912e-02,
7.7694e-02, 2.3634e-02, 1.7828e-01, 9.7569e-02],
[2.8949e-02, 8.1070e-02, 2.2353e-02, 2.6622e-02, 8.7250e-02,
4.9413e-02, 2.2407e-02, 1.1211e-02, 8.7391e-02, 1.4620e-01,
5.3241e-02, 2.5863e-02, 1.3452e-01, 2.2351e-01],
[2.0373e-02, 6.2544e-02, 2.5682e-02, 2.4110e-02, 4.8708e-02,
2.9638e-02, 6.1629e-02, 9.1208e-03, 2.8317e-01, 1.5011e-01,
8.9950e-02, 4.8338e-02, 7.1531e-02, 7.5104e-02]]],
grad_fn=<SoftmaxBackward0>)
tensor([[ 8, 8, 8, 8, 13, 8]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
[5.0727e-04, 7.7619e-04, 4.7299e-04, 1.5359e-04, 2.1579e-04,
1.5301e-04, 2.3866e-04, 4.8841e-04, 9.9550e-01, 1.5122e-04,
1.5233e-04, 4.4606e-04, 6.0918e-04, 1.3881e-04],
[6.5134e-04, 7.3119e-04, 5.0800e-04, 1.1433e-04, 3.7585e-04,
5.0196e-04, 3.0340e-04, 5.2323e-04, 9.9490e-01, 2.1529e-04,
1.6897e-04, 3.4757e-04, 5.7473e-04, 8.5880e-05],
[2.0070e-02, 5.0681e-02, 2.4773e-02, 1.8060e-02, 4.9905e-02,
3.3661e-02, 2.7250e-02, 5.5127e-03, 4.0430e-01, 5.2429e-02,
6.6296e-02, 2.2382e-02, 1.6011e-01, 6.4569e-02],
[2.6628e-02, 8.8863e-02, 2.7116e-02, 2.6948e-02, 9.6683e-02,
5.3906e-02, 2.1981e-02, 1.2533e-02, 1.2193e-01, 1.3180e-01,
5.0072e-02, 2.6949e-02, 1.3609e-01, 1.7850e-01],
[1.7873e-02, 6.2392e-02, 2.8520e-02, 2.2571e-02, 4.9959e-02,
2.9230e-02, 6.0232e-02, 9.2869e-03, 3.4494e-01, 1.2532e-01,
8.2190e-02, 4.8685e-02, 6.4112e-02, 5.4695e-02]]],
grad_fn=<SoftmaxBackward0>)
tensor([[ 8, 8, 8, 8, 13, 8]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
[5.0727e-04, 7.7619e-04, 4.7299e-04, 1.5359e-04, 2.1579e-04,
1.5301e-04, 2.3866e-04, 4.8841e-04, 9.9550e-01, 1.5122e-04,
1.5233e-04, 4.4606e-04, 6.0918e-04, 1.3881e-04],
[6.5134e-04, 7.3119e-04, 5.0800e-04, 1.1433e-04, 3.7585e-04,
5.0196e-04, 3.0340e-04, 5.2323e-04, 9.9490e-01, 2.1529e-04,
1.6897e-04, 3.4757e-04, 5.7473e-04, 8.5880e-05],
[3.4224e-04, 3.4134e-04, 3.1641e-04, 9.3156e-05, 1.9721e-04,
1.2940e-04, 1.0635e-04, 1.6255e-04, 9.9706e-01, 1.8321e-04,
1.1403e-04, 2.0476e-04, 5.8318e-04, 1.6127e-04],
[2.3426e-02, 9.6285e-02, 3.3548e-02, 2.5118e-02, 1.0165e-01,
5.6494e-02, 2.1736e-02, 1.3782e-02, 1.7155e-01, 1.1895e-01,
4.6460e-02, 2.7188e-02, 1.2749e-01, 1.3631e-01],
[1.4750e-02, 6.0416e-02, 3.1490e-02, 1.9243e-02, 4.8390e-02,
2.8151e-02, 5.5594e-02, 9.3160e-03, 4.1814e-01, 1.0462e-01,
7.4928e-02, 4.4255e-02, 5.3903e-02, 3.6806e-02]]],
grad_fn=<SoftmaxBackward0>)
tensor([[8, 8, 8, 8, 8, 8]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
[5.0727e-04, 7.7619e-04, 4.7299e-04, 1.5359e-04, 2.1579e-04,
1.5301e-04, 2.3866e-04, 4.8841e-04, 9.9550e-01, 1.5122e-04,
1.5233e-04, 4.4606e-04, 6.0918e-04, 1.3881e-04],
[6.5134e-04, 7.3119e-04, 5.0800e-04, 1.1433e-04, 3.7585e-04,
5.0196e-04, 3.0340e-04, 5.2323e-04, 9.9490e-01, 2.1529e-04,
1.6897e-04, 3.4757e-04, 5.7473e-04, 8.5880e-05],
[3.4224e-04, 3.4134e-04, 3.1641e-04, 9.3156e-05, 1.9721e-04,
1.2940e-04, 1.0635e-04, 1.6255e-04, 9.9706e-01, 1.8321e-04,
1.1403e-04, 2.0476e-04, 5.8318e-04, 1.6127e-04],
[7.2871e-04, 9.4293e-04, 4.8325e-04, 1.6221e-04, 4.3626e-04,
3.0433e-04, 1.8114e-04, 7.3524e-04, 9.9360e-01, 4.9332e-04,
7.6330e-05, 5.0833e-04, 1.0233e-03, 3.2313e-04],
[1.1731e-02, 5.6806e-02, 3.3593e-02, 1.6554e-02, 4.6226e-02,
2.6792e-02, 4.8839e-02, 9.7105e-03, 4.9776e-01, 7.5886e-02,
5.9116e-02, 4.2508e-02, 4.8548e-02, 2.5929e-02]]],
grad_fn=<SoftmaxBackward0>)
tensor([[8, 8, 8, 8, 8, 8]])
tensor(8)
tensor([[8, 8, 8, 8, 8, 8]]) # expected [7,8,9,10,11,12]
Where am I going wrong? Am I on the right path here? Is my model learning anything?