I’m trying to do language modeling on a custom dataset using nn.TransformerEncoder
. I’m using https://github.com/pytorch/examples/tree/master/word_language_model as a reference.
Previously, I used Google’s Trax and its TransformerLM
model to train a transformer with this dataset, based on this example: https://github.com/jalammar/jalammar.github.io/blob/master/notebooks/Trax_TransformerLM_Intro.ipynb. There, I reduced the Adam learning rate to 1e-04
, replaced the data with my actual dataset, adjusted the hyperparameters for my use-case, and managed to get very good results.
I’m now trying to replicate the same result in PyTorch, but without much luck.
Here are the things I modified from the word_language_model
example above:
- Changed the batching and dataset load logic. Since my dataset consists of separate sentences, my input data is of shape
(num_examples, seq_len)
, and each element is a token index, with 0 reserved for padding. I generate inputs of shape(max_seq_len, batch_size)
. - Calculated a padding mask that contains
True
wherever the original data has 0 (the padding token), passing it assrc_key_padding_mask
tonn.TransformerEncoder
. - Added an Adam optimizer instead of the previously inline LR annealing on plateau.
- Removed
F.log_softmax
, so now returning raw data from the lastnn.Linear
layer, and changed the loss tonn.CrossEntropyLoss(ignore_index=0)
. - Added accuracy measurements by masking the final logits using the key padding mask, taking an argmax, and comparing to the actual output.
My training plateaus within a few dozen batches to an accuracy of ~10% and a constant loss, and simply predicts the most common overall token, no matter the input (this gives an accuracy of ~10%, as that token is around 10% of the training data). I’ve tried using a learning rate scheduler, adding some warmup steps, played around with hyperparameters and the initial learning rate, tried changing the initialization of the embedding and final nn.Linear
weights to nn.init.xavier_uniform_
, but nothing helps.
For comparison, I tried randomly shuffling my data around to make it nonsense, and the model arrives at similar numbers. So I’m pretty sure it’s learning nothing.
Here’s an example of how my training data looks (right before it goes into the model):
data =
[[ 1, 1, 1, 1],
[ 2, 2, 2, 2],
[ 3, 107, 81, 81],
[115, 4, 111, 46],
[ 5, 5, 5, 5],
[ 80, 41, 63, 61],
# some zeroes here, starting at different rows,
# depending on the size of the example
...]]
target =
[ 2, 3, 115, 5, 80, 42, ...
0, 0, 0, 0, 0, 0, ...
2, 107, 4, 5, 41, 96, ...
0, 0, 0, 0, 0, 0, ...]]
(Almost every example starts with 1 2
, so I imagined it’d be easy to fit on at least this, but not even that happens.)
Here’s the model:
# PositionalEncoding definition from word_language_model omitted
# for brevity
class Transformer(nn.Module):
def __init__(self, n_tokens, d_model, n_heads, d_ff, n_layers, dropout=0.1,
max_len=4096, activation='relu'):
super(Transformer, self).__init__()
self.mask = None
self.d_model = d_model
self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
self.embedding = nn.Embedding(n_tokens, d_model, padding_idx=0)
enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, d_ff, dropout)
self.tf_encoder = nn.TransformerEncoder(enc_layer, n_layers)
self.decoder = nn.Linear(d_model, n_tokens)
self.init_weights()
def _from_binary_mask(self, mask):
return mask.float() \
.masked_fill(mask == False, float('-inf')) \
.masked_fill(mask == True, float(0.0))
def _generate_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
return self._from_binary_mask(mask)
def init_weights(self):
initrange = 0.1
nn.init.uniform_(self.embedding.weight, -initrange, initrange)
nn.init.zeros_(self.decoder.weight)
nn.init.uniform_(self.decoder.weight, -initrange, initrange)
# Alternative I tried:
# nn.init.xavier_uniform_(self.embedding.weight)
# nn.init.xavier_uniform_(self.decoder.weight)
# nn.init.normal_(self.decoder.bias, 1e-6)
def forward(self, src, use_mask=True, src_key_padding_mask=None):
if use_mask:
device = src.device
if self.mask is None or self.mask.size(0) != len(src):
mask = self._generate_mask(len(src)).to(device)
self.mask = mask
else:
self.mask = None
# embed and add positional information
src = self.embedding(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
output = self.tf_encoder(src, self.mask,
src_key_padding_mask=src_key_padding_mask)
output = self.decoder(output)
return output
And here’s my training loop:
model = model.Transformer(n_tokens, args.d_model, args.n_heads, args.d_ff,
args.n_layers, args.dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-3)
# init data, etc.
def train(train_data):
model.train()
bs = args.batch_size
total_loss = 0.0
total_correct = 0
total_total = 0
start_time = time.time()
for batch_i in range(0, train_data.size(0) // bs):
transposed_data, target = get_batch(train_data, batch_i) # (bs, seq_len), (bs*seq_len)
data = transposed_data.transpose(0, 1).type(torch.LongTensor).to(device) # (seq_len, bs)
target = target.type(torch.LongTensor).to(device) # (bs*seq_len)
optimizer.zero_grad()
key_padding_mask = calc_key_padding_mask(transposed_data, target) # (bs, seq_len), 1 if not padding
output = model(data, src_key_padding_mask=torch.logical_not(key_padding_mask))
output = output.view(-1, n_tokens) # (bs*seq_len, n_tokens)
logits = F.log_softmax(output, dim=-1) # (bs*seq_len, n_tokens)
masked_logits = apply_mask_to_logits(logits, key_padding_mask) # (bs*seq_len, n_tokens), sets element to -inf if padding
confidences, predictions = torch.max(masked_logits.exp(), 1) # (bs*seq_len), confidence will be 0 if padding
total_correct += (torch.logical_and(predictions == target, confidences > 0.0)).float().sum()
total_total += key_padding_mask.float().sum()
e = loss(output, target)
e.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
optimizer.step()
total_loss += e.item()
if batch_i % args.log_interval == 0 and batch_i > 0:
curr_loss = total_loss / args.log_interval / bs
curr_accuracy = total_correct / total_total
total_loss = 0.
total_correct = 0
total_total = 0
# print the stats...
Please help - I’m at a loss (no pun intended). I feel like I’ve tried everything.