Hello,
I’m trying to implement RNN-based text generation on sequences of different lengths, using padding and a masked cross-entropy loss. Here is a snippet of the critical code.
Each backward step of the loss takes about 30 s, compared with less than 2 s for all the steps above it.
Thanks in advance for any suggestions!
import csv
import numpy as np
import logging
import time
import string
from itertools import chain
import torch
import torch.nn as nn
import torch.optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from textloader import *
# from generate import *
import logging
logging.basicConfig(level=logging.INFO)
def maskedCrossEntropy(output, target, padcar):
    """Return the mean cross-entropy over the non-padding positions.

    Args:
        output: Logits with the class dimension last, shape
            ``(d0, d1, num_classes)``. The tensor is permuted to
            ``(d0, num_classes, d1)`` because ``cross_entropy`` expects the
            class dimension at index 1.
        target: Class indices of shape ``(d0, d1)``; cast to ``long``.
        padcar: Index of the padding symbol. Positions where
            ``target == padcar`` contribute nothing to the loss.

    Returns:
        A scalar tensor: the summed per-position loss divided by the number
        of non-padding positions. It is ``0`` (not ``NaN``) when every
        position is padding.
    """
    mask = target != padcar
    per_position = torch.nn.functional.cross_entropy(
        output.permute(0, 2, 1), target.long(), reduction="none"
    )
    # Clamp the denominator so an all-padding batch yields 0 instead of 0/0.
    return (per_position * mask).sum() / mask.sum().clamp(min=1)
class RNN(nn.Module):
    """Minimal Elman-style recurrent network with a linear decoder.

    The recurrence is ``h_t = tanh(lin_hs(h_{t-1}) + lin_ft(x_t))`` and
    ``decode`` maps hidden states to ``out`` scores with ``lin_dec``.

    Args:
        latent: Size of the hidden state.
        dim: Size of each input feature vector.
        out: Size of the decoder output.
    """

    def __init__(self, latent, dim, out):
        super().__init__()
        self.latent = latent
        self.dim = dim
        self.out = out
        # Kept for backward compatibility; not used by any method of this
        # class. It is a 0-d tensor holding the value of ``latent``.
        self.hidden_state = torch.tensor(latent)
        self.lin_hs = nn.Linear(latent, latent)
        self.lin_ft = nn.Linear(dim, latent)
        self.lin_dec = nn.Linear(latent, out)

    def decode(self, hs):
        """Project hidden states ``hs`` (last dim ``latent``) to ``out`` scores."""
        return self.lin_dec(hs)

    def forward(self, batch, hs):
        """Run the recurrence over the first dimension of ``batch``.

        Args:
            batch: Input sequence; the first dimension is iterated as time and
                the last dimension must be ``dim``.
            hs: Initial hidden state, broadcastable with one projected step.

        Returns:
            The hidden states of every step stacked along a new first
            dimension. An empty sequence yields an empty tensor.
        """
        # The input projection does not depend on the hidden state, so it is
        # applied once to the whole sequence instead of once per time step.
        projected = self.lin_ft(batch)
        if projected.shape[0] == 0:
            # torch.stack([]) would raise; return an empty result instead.
            return projected.new_empty(projected.shape)
        states = []
        for step_input in projected:
            hs = torch.tanh(self.lin_hs(hs) + step_input)
            states.append(hs)
        return torch.stack(states)

    def one_step(self, batch, hs):
        """Return the next hidden state for a single time step.

        Args:
            batch: Inputs for one step, last dimension ``dim``.
            hs: Previous hidden state, last dimension ``latent``.
        """
        # No clone needed: nn.Linear does not modify its input in place.
        return torch.tanh(self.lin_hs(hs) + self.lin_ft(batch))
# Load the whole corpus with a single read. The previous loop read one
# character at a time and grew the string with ``+=``, which produces the
# same text but is far slower on a large file.
with open('data/full_speech.txt') as f:
    speech = f.read()
# ---- Hyper-parameters ----
LR = 10e-3  # NOTE(review): 10e-3 is 0.01, not 1e-3 -- confirm this is intended.
SEQ_LEN = 100
PRED_LEN = 10
LATENT_DIM = 50
BATCH_SIZE = 500
EPOCH_RANGE = 5
EMBED_DIM = 50  # size of the character embeddings fed to the RNN

# ---- Data and model ----
embedding = nn.Embedding(len(id2lettre), EMBED_DIM)
speech_dataset = TextDataset(speech)
speech_dataloader = DataLoader(
    speech_dataset,
    BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # hs below has exactly BATCH_SIZE rows
    collate_fn=collate_fn,
)
# The RNN consumes embeddings, so its input size is EMBED_DIM rather than the
# vocabulary size.
model = RNN(LATENT_DIM, EMBED_DIM, len(id2lettre))
# Optimise the embedding together with the RNN; otherwise the embedding is
# never updated and its gradients are never reset by zero_grad().
optim = torch.optim.Adam(
    chain(model.parameters(), embedding.parameters()), lr=LR
)
loss = torch.nn.CrossEntropyLoss()  # unused below; the masked loss is used instead
hs = torch.zeros(BATCH_SIZE, LATENT_DIM)  # initial hidden state, shared by all batches

# ---- Training loop ----
for epoch in range(EPOCH_RANGE):
    print(epoch)
    i = 0  # never incremented in the visible code
    for x in speech_dataloader:
        tokens = x.long()
        # Next-character prediction: the model reads tokens[:-1] and is scored
        # against tokens[1:]. The first dimension is treated as time, as in
        # RNN.forward -- assumes collate_fn returns (seq_len, batch); TODO confirm.
        inputs, targets = tokens[:-1], tokens[1:]
        optim.zero_grad()
        hst = model(embedding(inputs), hs)
        hst = model.decode(hst)
        l = maskedCrossEntropy(hst, targets, PAD_IX)
        l.backward()
        optim.step()