I have an array of N-Grams I’ve encoded from Beatles lyrics. However when I run it through my transformer model’s training function. I instantly run out of memory. What are some ways to overcome this and do I even have my batch function working properly?
Batch function:
def get_batches(arr_x, arr_y, batch_size): #takes two arrays of n-grams
# iterate through the arrays
prv = 0
for n in range(batch_size, arr_x.shape[0], batch_size):
x = arr_x[prv:n]
y = arr_y[prv:n]
prv = n
yield x, y
Train function:
import torch.optim as optim
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
src1= torch.tensor([[1,2,3,4,5,0],[3,4,5,6,7,8]]).to(device)
trg1=torch.tensor([[6,7,8,9,10,11],[8,9,10,11,12,13]]).to(device)
model=Transformer(src_vocab_size=vocab_size, trg_vocab_size=vocab_size, src_pad_idx=0, trg_pad_idx=0).to(device)
opt=torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn=nn.functional.cross_entropy
def train_model(src, trg, epochs=501, batch_size=32):
total_loss=0
train_loss_list, validation_loss_list = [], []
for e in range(epochs):
model.train()
for x, y in get_batches(src, trg, batch_size):
src, trg = torch.from_numpy(src.astype('int64')).to(device), torch.from_numpy(trg.astype('int64')).to(device)
pred=model(src, trg)
if e%50==0:
print("-"*25, f"Epoch {e + 1}","-"*25)
print(pred.shape)
print(torch.argmax(pred,dim=2))
print(trg)
loss=loss_fn(pred, torch.nn.functional.one_hot(trg, num_classes=14).type(torch.FloatTensor))
opt.zero_grad()
loss.backward()
opt.step()
total_loss += loss.detach().item()/(e+1)
train_loss_list += [total_loss]
if e%50==0: print(f"Training loss: {total_loss:.4f}")
print('...')
return train_loss_list
print(model)
train_loss_list = train_model(src=x4_int,trg=y4_int)
Any help is appreciated.