I’ve been banging my head against a wall trying to figure out the source of this error. Any help would be greatly appreciated!
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Encoder(nn.Module):
def __init__(self, input_dim, embed_dim, n_heads, ff_dim, droprate, max_leng = 100):
super(Encoder, self).__init__()
self.tok_embed = nn.Embedding(input_dim, embed_dim)
self.pos_embed = nn.Embedding(max_leng, embed_dim)
self.norm = nn.LayerNorm(embed_dim)
self.encoder = nn.TransformerEncoderLayer(embed_dim, n_heads, ff_dim, droprate)
self.decoder = nn.Linear(embed_dim, max_leng)
self.dropout = nn.Dropout(droprate)
self.scale = torch.sqrt(torch.FloatTensor([embed_dim]))
def forward(self, input_ids, attn_mask):
batch_size = input_ids.shape[0]
input_len = input_ids.shape[1]
pos = torch.arange(0, input_len).unsqueeze(0).repeat(batch_size, 1)
input_ids = self.dropout(self.norm((self.tok_embed(input_ids) * self.scale) + self.pos_embed(pos)))
encoding = self.encoder(input_ids, src_mask=None, src_key_padding_mask = attn_mask)
output_ids = self.decoder(encoding)
return output_ids
model = Encoder(input_dim = len(tokenizer.vocab),
embed_dim = 256,
n_heads = 4,
ff_dim = 1024,
droprate = .5)
n_epochs = 50
bos_id = tokenizer.bos_token_id
eos_id = tokenizer.eos_token_id
pad_id = tokenizer.pad_token_id
mask_id = tokenizer.mask_token_id
ntokens = len(tokenizer.vocab)
criterion = nn.CrossEntropyLoss(label_smoothing = .2, ignore_index = pad_id)
optimizer = torch.optim.Adam(model.parameters(), lr = .0005)
def fit(model, train_iterator, train_dataset, optimizer, criterion):
print('Training')
model.train()
train_running_loss = 0.0
counter = 0
prog_bar = tqdm(enumerate(train_iterator), total = int(len(train_dataset) / train_iterator.batch_size))
for idx, batch in prog_bar:
counter += 1
optimizer.zero_grad()
input_ids, attn_mask = batch['input_ids'].clone(), batch['attention_mask']
attn_mask = torch.transpose(attn_mask, 0, 1).to(torch.bool)
rand_value = torch.rand(batch['input_ids'].shape)
rand_mask = (rand_value < .15) * (input_ids != bos_id) * (input_ids != eos_id) * (input_ids != pad_id)
mask_idx = (rand_mask.flatten() == True).nonzero().view(-1)
input_ids = input_ids.flatten()
input_ids[mask_idx] = mask_id
input_ids = input_ids.view(batch['input_ids'].size())
output = model(input_ids.to(device), attn_mask.to(device))
loss = criterion(output.view(-1), batch['input_ids'].view(-1).to(device))
train_running_loss += loss.item()
loss.backward()
optimizer.step()
train_loss = train_running_loss / counter
return train_loss
def validate(model, val_iterator, val_dataset, criterion):
print('Validating')
model.eval()
val_running_loss = 0.0
counter = 0
prog_bar = tqdm(enumerate(val_iterator), total = int(len(val_dataset) / val_iterator.batch_size))
with torch.no_grad():
for idx, batch in prog_bar:
counter += 1
input_ids, attn_mask = batch['input_ids'].clone(), batch['attention_mask']
attn_mask = torch.transpose(attn_mask, 0, 1).to(torch.bool)
rand_value = torch.rand(batch['input_ids'].shape)
rand_mask = (rand_value < .15) * (input_ids != bos_id) * (input_ids != eos_id) * (input_ids != pad_id)
mask_idx = (rand_mask.flatten() == True).nonzero().view(-1)
input_ids = input_ids.flatten()
input_ids[mask_idx] = mask_id
input_ids = input_ids.view(batch['input_ids'].size())
output = model(input_ids.to(device), attn_mask.to(device))
loss = criterion(output.view(-1), batch['input_ids'].view(-1).to(device))
val_running_loss += loss.item()
val_loss = val_running_loss / counter
return val_loss
train_losses = []
val_losses = []
start = time.time()
for epoch in range(n_epochs):
print(f"Epoch {epoch + 1}/{n_epochs}")
train_epoch_loss = fit(model, train_dataloader, train_dataset, optimizer, criterion)
train_losses.append(train_epoch_loss)
val_epoch_loss = validate(model, val_dataloader, val_dataset, optimizer, criterion)
val_losses.append(val_epoch_loss)
early_stopping(val_epoch_loss)
if early_stopping.early_stop:
break
print(f"Training Loss: {train_epoch_loss:.5} \n Validation Loss: {val_epoch_loss:.5}")
end = time.time()
print(f"Training Time: {(end - start) / 60: .3f} minutes")