The code below is the complete code of the model, so if you copy and paste it, it should run as is.
I am sorry for posting such long code; I wish I could limit it to just the part causing the problem, but I have been stuck for a week struggling to even find where the problem is, so I have to post the full code. I am using PyTorch Lightning and trying to build a pointer-generator with a Transformer, based on the papers https://arxiv.org/pdf/1704.04368.pdf and https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1194/reports/custom/15784595.pdf.
I am trying to mirror the architecture described in the papers exactly. To test the model, I created a task that separates even and odd numbers: given a list of numbers {0, 2, 5, 7, 3, 5, 1, 8, 4, 19}, the model should return {0, [oddToken], 5, 7, 3, 5, 1, [oddToken], [evenToken], 2, 8, 4, [evenToken], 19}, where 0 and 19 are the start and end tokens respectively.
During training, the model only ever reaches about 50% accuracy, and around epoch 2 the training loss and everything else goes to NaN.
I tried to mirror the architecture exactly, but when computing pGen in the getPgen class I had to add an extra layer normalization, because otherwise pGen swings to either 0 or 1.
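For reference, my reading of equation (8) in the See et al. paper is that pGen = sigmoid(w_h · h*_t + w_s · s_t + w_x · x_t + b_ptr), i.e. a plain sigmoid over the sum of three learned linear projections with no normalization step, so the LayerNorm inside getPgen is my own addition on top of the paper's formulation.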
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.autograd import Variable
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import math, copy
c = copy.deepcopy
class LabelSmoothing(nn.Module):
"Implement label smoothing."
def __init__(self, tgt_size, smoothing=0.0):
super(LabelSmoothing, self).__init__()
self.criterion = nn.KLDivLoss(reduction='sum')
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.tgt_size = tgt_size
def forward(self, y, target):
# y has size (batch_size x tgt_length, vocab_size)
assert y.size(1) == self.tgt_size
smooth = self.smoothing / (self.tgt_size - 2)
true_dist = torch.zeros(size=(target.size(0), target.size(1), y.size(-1))).fill_(smooth)
true_dist.scatter_(-1, target.unsqueeze(-1), 1)
true_dist = true_dist.view(-1, true_dist.size(-1))
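# note: torch.log(y) is -inf wherever y is exactly 0; if that coincides with a nonzero entry of true_dist, the KL loss becomes inf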
return self.criterion(torch.log(y), Variable(true_dist, requires_grad=False))
class Batch:
"Object for holding a batch of data with mask during training."
def __init__(self, src, trg=None, pad=0):
self.src = src
self.src_mask = (src != pad).unsqueeze(-2)
if trg is not None:
self.trg = trg[:, :-1]
self.trg_y = trg[:, 1:]
self.trg_mask = \
self.make_std_mask(self.trg, pad)
self.ntokens = (self.trg_y != pad).data.sum()
@staticmethod
def make_std_mask(tgt, pad):
"Create a mask to hide padding and future words."
tgt_mask = (tgt != pad).unsqueeze(-2)
tgt_mask = tgt_mask & Variable(
subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
return tgt_mask
def subsequent_mask(size):
"Mask out subsequent positions."
attn_shape = (1, size, size)
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
return torch.from_numpy(subsequent_mask) == 0
class getPgen(pl.LightningModule):
def __init__(self, d_model):
super(getPgen, self).__init__()
self.linear1 = nn.Linear(d_model, 1)
self.linear2 = nn.Linear(d_model, 1)
self.linear3 = nn.Linear(d_model, 1)
self.norm = LayerNorm(13)
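# note: 13 is hard-coded to the decoder input length (tgt_length - 1), so this only works for this fixed-length toy task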
def forward(self, contextVector, x, out):
# all (batch_size, tgt_length, d_model)
output = self.linear1(contextVector) + self.linear2(x) + self.linear3(out)
a = torch.sigmoid(self.norm(output.squeeze(-1)))
return a
# return (batch_size, tgt_length)
class GetContextVector(pl.LightningModule):
def __init__(self):
super(GetContextVector, self).__init__()
def forward(self, memory, attentions):
# memory (batch_size, src_length, d_model)
# attentions (batch_size, tgt_length, src_length)
attentions = F.softmax(attentions, dim=-1)
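# note: get_final_answer already applies F.softmax to these attention scores before calling this module, so they get normalized twice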
output = torch.tensor([])
for count in range(0, attentions.size(1)):
out = self.one_time_step(attentions[:, count], memory)
output = torch.cat((output, out), dim=1)
return output
# return (batch_size, tgt_length, d_model)
def one_time_step(self, attention, memory):
output = torch.zeros(attention.size(0), memory.size(2))
# attention (batch_size, src_length)
# memory (batch_size, src_length, d_model)
for count in range(0, attention.size(1)):
attn = attention[:, count].unsqueeze(1)
# attention (batch_size, 1)
memoryPart = memory[:, count] * attn
# memoryPart (batch_size, d_model)
output += memoryPart
return output.unsqueeze(1)
# return (batch_size, 1, d_model)
class GenerateAttention(pl.LightningModule):
def __init__(self, d_model):
super(GenerateAttention, self).__init__()
self.linears1 = nn.Linear(d_model, d_model*2)
self.linears2 = nn.Linear(d_model, d_model*2)
self.linears3 = nn.Linear(d_model*2, 1)
def forward(self, encoded, decoded):
# context (batch_size, src_length, d_model)
# prediction (batch_size, tgt_length, d_model)
attentionScore = torch.tensor([])
for timeStep in range(0, decoded.size(1)):
if attentionScore is None:
attentionScore = self.generate_attention(encoded, decoded[:, timeStep, :].unsqueeze(1), timeStep)
else:
attentionScore = torch.cat((attentionScore, self.generate_attention(encoded, decoded[:, timeStep, :].unsqueeze(1), timeStep)), dim=1)
return attentionScore
# attentionScore (batch_size, tgt_length, src_length)
# each row of src_length represents attention weight of each input.
def generate_attention(self, context, prediction, timeStep):
# context (batch_size, src_length, d_model)
# predictionSoFar (batch_size, 1, d_model)
context = self.linears1(context)
prediction = self.linears2(prediction)
c_p = torch.tanh(context + prediction)
a = self.linears3(c_p).transpose(-1, -2)
return a
# return dimension of (batch_size, 1, src_length)
class PointerGenerator(pl.LightningModule):
def __init__(self):
super(PointerGenerator, self).__init__()
pass
# x: (batch_size, tgt_length, src_length)
def forward(self, x):
return F.softmax(x, dim=-1)
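# note: this module is constructed and stored on EncoderDecoder but never called; get_final_answer applies F.softmax to the attention scores directly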
class Generator(pl.LightningModule):
"Define standard linear + softmax generation step."
def __init__(self, d_model, vocab):
super(Generator, self).__init__()
self.linear1 = nn.Linear(d_model * 2, d_model*4)
self.linear2 = nn.Linear(d_model*4, vocab)
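# note: there is no nonlinearity between linear1 and linear2, so together they are equivalent to a single linear layer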
def forward(self, decoded, contextVector):
# decoded (batch_size, tgt_length, d_model)
# contextVector (batch_size, tgt_length, d_model)
combine = torch.cat((decoded, contextVector), dim=-1)
return F.softmax(self.linear2(self.linear1(combine)), dim=-1)
# return generativeProb (batch_size, tgt_length, vocab)
def clones(module, N):
"Produce N identical layers."
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
class Encoder(pl.LightningModule):
"Core encoder is a stack of N layers"
def __init__(self, layer, N):
super(Encoder, self).__init__()
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def forward(self, x, mask):
"Pass the input (and mask) through each layer in turn."
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
class LayerNorm(pl.LightningModule):
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.features = features
self.a_2 = nn.Parameter(torch.ones(features))
self.b_2 = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
a = self.a_2 * (x - mean)
b = a / (std + self.eps) + self.b_2
return b
class SublayerConnection(pl.LightningModule):
def __init__(self, size, dropout):
super(SublayerConnection, self).__init__()
self.norm = LayerNorm(size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, sublayer):
"Apply residual connection to any sublayer with the same size."
return x + self.dropout(sublayer(self.norm(x)))
class EncoderLayer(pl.LightningModule):
def __init__(self, size, self_attn, feed_forward, dropout):
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 2)
self.size = size
def forward(self, x, mask):
# at this point x has just been passed through word embedding and positional embedding
# dimension (batch_size, length, d_model)
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
return self.sublayer[1](x, self.feed_forward)
class Decoder(pl.LightningModule):
def __init__(self, layer, N):
super(Decoder, self).__init__()
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def forward(self, x, memory, src_mask, tgt_mask):
for count in range(0, len(self.layers)):
x, x1 = self.layers[count](x, memory, src_mask, tgt_mask, count)
return self.norm(x), x1
class DecoderLayer(pl.LightningModule):
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 3)
def forward(self, x, memory, src_mask, tgt_mask, currentHead):
m = memory
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
pointerDecoded = x.clone()
return self.sublayer[2](x, self.feed_forward), pointerDecoded
# first object for generative, second object for pointer
# both (batch_size, tgt_length, d_model)
def attention(query, key, value, mask=None, dropout=None):
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) \
/ math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim=-1)
if dropout is not None:
p_attn = dropout(p_attn)
x = torch.matmul(p_attn, value)
return x, p_attn
class MultiHeadedAttention(pl.LightningModule):
def __init__(self, h, d_model, dropout=0.1):
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0
# We assume d_v always equals d_k
self.d_k = d_model // h
self.h = h
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(p=dropout)
def forward(self, query, key, value, mask=None):
if mask is not None:
mask = mask.unsqueeze(1)
nbatches = query.size(0)
query, key, value = \
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for l, x in zip(self.linears, (query, key, value))]
x, self.attn = attention(query, key, value, mask=mask,
dropout=self.dropout)
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
x = self.linears[-1](x)
return x
class PositionwiseFeedForward(pl.LightningModule):
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class Embeddings(pl.LightningModule):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, x):
a = self.lut(x) * math.sqrt(self.d_model)
return a
class PositionalEncoding(pl.LightningModule):
def __init__(self, d_model, dropout, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + Variable(self.pe[:, :x.size(1)],
requires_grad=False)
return self.dropout(x)
def duplicate(x, sample):
output = torch.tensor([])
for count in range(0, sample):
output = torch.cat((output, x), dim=-1)
return output.long()
def getAccuracy(trg_hat, trg, generative_token, src): # x (batch_size, tgt_length, full vocab size) trg (batch_size, tgt_length)
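# note: generative_token and src are accepted but not used below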
totalValAcc = 0
totalValAccToken = 0
trg = trg.contiguous().view(-1)
_, index = torch.max(trg_hat, dim=-1)
index = index.contiguous().view(-1)
correct = (trg == index).sum().item()
totalValAcc += correct
totalValAccToken += trg.size(0)
return totalValAcc / totalValAccToken
class EncoderDecoder(pl.LightningModule):
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator, generateAttention, pointerGenerator, getPGen, getContextVector):
super(EncoderDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.generator = generator
self.generateAttention = generateAttention
self.pointerGenerator = pointerGenerator
self.customOptimizer = None
self.automatic_optimization = False
self.totalValLoss = 0
self.totalValToken = 0
self.accuracyScore = 0
self.minValLoss = 10000
self.getPGen = getPGen
self.getContextVector = getContextVector
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src, tgt, src_mask, tgt_mask):
encoded = self.encode(src, src_mask)
decoded, pointerDecoded = self.decode(encoded, src_mask, tgt, tgt_mask)
return decoded, pointerDecoded, encoded
# return (batch_size, tgt_length, d_model) x 2, encoded (batch_size, src_length, d_model)
def encode(self, src, src_mask):
return self.encoder(self.src_embed(src), src_mask)
# self.src_embed(src) pass through turn dimension of (batch_size, length) into (batch_size, length, d_model) through nn.embedding
# then positional embedding is also applied
# it then passes through encoder layer N times, output is still the same (batch_size, length, d_model)
def decode(self, memory, src_mask, tgt, tgt_mask):
decoded, pointerDecoded = self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
return decoded, pointerDecoded
# decoder takes in memory which is the output of the encoder (batch_size, length, d_model),
# and self.tgt_embed(tgt) this passes target value through word embedding and positional embedding,
# and give (batch_size, tgt_length, d_model) (tgt_length is original length - 1)
def get_final_answer(self, encoded, decoded, pointerDecoded, target, source):
attentions = self.generateAttention(encoded, pointerDecoded) # attentions (batch_size, tgt_length-1, src_length)
attentions = F.softmax(attentions, dim=-1)
pointerProb = self.one_hot_format(attentions, source) # (batch_size, tgt_length, V)
contextVector = self.getContextVector(encoded, attentions) # context vector (batch_size, tgt_length, d_model)
generativeProbabilityD = self.generator(decoded, contextVector) # shape (batch_size, tgt_length, V)
x = self.tgt_embed(target) # x (batch_size, tgt_length, d_model)
pGen = self.getPGen(contextVector, x, decoded).unsqueeze(-1) # pGen (batch_size, tgt_length, 1)
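# note: in eq. (9) of See et al., p_gen weights the vocabulary (generative) distribution and (1 - p_gen) weights the copy distribution; the roles are reversed here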
finalOut = pointerProb * (pGen) + generativeProbabilityD * (1 - pGen) # (batch_size, tgt_length, V)
return finalOut
def one_hot_format(self, attention, index):
# index (batch_size, src_length)
# value (batch_size, tgt_length, src_length)
final = None
for count in range(0, index.size(0)):
oneSample = index[count].unsqueeze(0)
finalSample = None
for _ in range(0, attention.size(1)):  # renamed so the inner loop no longer shadows the outer count
if finalSample is None:
finalSample = oneSample
else:
finalSample = torch.cat((finalSample, oneSample), dim=0)
finalSample = finalSample.unsqueeze(0)
if final is None:
final = finalSample
else:
final = torch.cat((final, finalSample), dim=0)
a = Variable(torch.zeros(attention.size(0), attention.size(1), V), requires_grad=True)
b = a.clone()
b.scatter_(-1, final.long(), attention)
return b
def training_step(self, batch, batch_idx): # batch should contain both x, y label
if self.customOptimizer is None:
self.customOptimizer = self.optimizers()
batch = Batch(batch[0], batch[1])
decoded, pointerDecoded, encoded = self(batch.src, batch.trg, batch.src_mask, batch.trg_mask) # out (batch_size, tgt_length, d_model), # memory (batch_size, src_length, d_model)
finalOut = self.get_final_answer(encoded, decoded, pointerDecoded, batch.trg_y, batch.src)
criterion = LabelSmoothing(tgt_size=V, smoothing=0)
loss = criterion(finalOut.contiguous().view(-1, finalOut.size(-1)), batch.trg_y) / batch.ntokens
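# note: with automatic_optimization=False, the PyTorch Lightning docs recommend self.manual_backward(loss) rather than calling loss.backward() directly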
loss.backward()
self.customOptimizer.step()
self.customOptimizer.zero_grad()
self.log("train_loss", loss)
if batch_idx % 5 == 0:
print(loss)
if loss < 0.001:
print("")
return {'idx': batch_idx}
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
def train_dataloader(self):
randomSeed = 0
totalTrainingSample = 3000
train = data_gen(V, randomSeed, totalTrainingSample)
return DataLoader(train, batch_size=30, shuffle=False, num_workers=1, pin_memory=True)
def val_dataloader(self):
randomSeed = 1
totalTrainingSample = 1000
train = data_gen(V, randomSeed, totalTrainingSample) # second parameter is random seed
return DataLoader(train, batch_size=10, shuffle=False, num_workers=1, pin_memory=True)
def validation_step(self, batch, batch_idx):
if batch_idx == 0:
self.totalValLoss = 0
self.totalValToken = 0
batch = Batch(batch[0], batch[1])
decoded, pointerDecoded, encoded = self(batch.src, batch.trg, batch.src_mask, batch.trg_mask) # out (batch_size, tgt_length, d_model), # memory (batch_size, src_length, d_model)
finalOut = self.get_final_answer(encoded, decoded, pointerDecoded, batch.trg_y, batch.src)
criterion = LabelSmoothing(tgt_size=V, smoothing=0)
loss = criterion(finalOut.contiguous().view(-1, finalOut.size(-1)), batch.trg_y) / batch.ntokens
self.totalValLoss += loss * batch.ntokens
self.totalValToken += batch.ntokens
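# 99 is the index of the last validation batch (1000 samples / batch size 10 = 100 batches)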
if batch_idx == 99:
self.totalValLoss = self.totalValLoss / self.totalValToken
print(f"valLoss: {self.totalValLoss}")
if self.minValLoss > self.totalValLoss:
self.minValLoss = self.totalValLoss
torch.save(self.state_dict(), "reverseNumberTask{0}".format(self.current_epoch))
self.log("val_loss", self.totalValLoss)
return {"x": finalOut, "trg": batch.trg_y, "index": batch_idx, "src": batch.src}
def validation_step_end(self, batch):
x, trg, idx, src = batch['x'], batch['trg'], batch['index'], batch['src']
self.accuracyScore += getAccuracy(x, trg, generative_token, src)
if idx == 99:
self.log('accuracy', self.accuracyScore/100)
print(self.accuracyScore/100)
self.accuracyScore = 0
def data_gen(V, randomSeed, totalTrainingSample):
evenToken = [V-2]
oddToken = [V-1]
y = []
np.random.seed(randomSeed)
x = torch.from_numpy(np.random.randint(1, V - 3, size=(totalTrainingSample, src_length)))
x[:, 0] = 0
x[:, -1] = V-3
odd = [[item.item() for item in sample if item % 2 == 1 and item not in [0, V-3]] for sample in x]
even = [[item.item() for item in sample if item % 2 == 0 and item not in [0, V-3]] for sample in x]
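# each target is: 1 (start), oddToken, the odd source numbers, oddToken, evenToken, the even source numbers, evenToken, V-3 (end)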
for count in range(0, len(odd)):
oddSample = oddToken + odd[count] + oddToken
evenSample = evenToken + even[count] + evenToken
y.append([1] + oddSample + evenSample + [V-3])
data = list(zip(x, torch.tensor(y)))
return data
generative_vocab = 2
src_vocab = 20
src_length = 10
tgt_length = 14
V = generative_vocab + src_vocab
generative_token = torch.tensor([V-2, V-1])
d_model = 512
d_ff = 2048
dropout = 0.1
h = 8
N = 2
attn = MultiHeadedAttention(h, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
position = PositionalEncoding(d_model, dropout)
def fake_mask(size):
return torch.ones(size=(size, size)).unsqueeze(0)
if __name__ == '__main__':
if True:
model = EncoderDecoder(
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
Decoder(DecoderLayer(d_model, c(attn), c(attn),
c(ff), dropout), N),
nn.Sequential(Embeddings(d_model, V), c(position)),
nn.Sequential(Embeddings(d_model, V), c(position)),
Generator(d_model, V), GenerateAttention(d_model), PointerGenerator(), getPgen(d_model), GetContextVector())
earlyStopping = EarlyStopping(monitor='val_loss', patience=3)
trainer = pl.Trainer(max_epochs=10, callbacks=[earlyStopping])
trainer.fit(model)
if True:
model = EncoderDecoder(
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
Decoder(DecoderLayer(d_model, c(attn), c(attn),
c(ff), dropout), N),
nn.Sequential(Embeddings(d_model, V), c(position)),
nn.Sequential(Embeddings(d_model, V), c(position)),
Generator(d_model, V), GenerateAttention(d_model), PointerGenerator(), getPgen(d_model), GetContextVector())
model.load_state_dict(torch.load("reverseNumberTask1"))
model.eval()
data = data_gen(V, 1, 10)
accuracy = 0
data = DataLoader(data, batch_size=10, shuffle=False, num_workers=2, pin_memory=True)
for i, batch in enumerate(data):
batch = Batch(batch[0], batch[1])
decoded, pointerDecoded, encoded = model.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask) # out (batch_size, tgt_length, d_model), memory (batch_size, src_length, d_model)
finalOut = model.get_final_answer(encoded, decoded, pointerDecoded, batch.trg_y, batch.src)
print(getAccuracy(finalOut, batch.trg_y, generative_token, batch.src))