OOM. Is my batch function not acting right?

I have an array of n-grams I've encoded from Beatles lyrics. However, when I run it through my transformer model's training function, I instantly run out of memory. What are some ways to overcome this, and do I even have my batch function working properly?

Batch function:

def get_batches(arr_x, arr_y, batch_size):  # takes two arrays of n-grams
    # iterate through the arrays
    prv = 0
    for n in range(batch_size, arr_x.shape[0], batch_size):
        x = arr_x[prv:n]
        y = arr_y[prv:n]
        prv = n
        yield x, y
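
For reference, this is how the generator behaves on a small toy pair of arrays (made-up values, not the real encoded lyrics):

import numpy as np

# toy stand-ins: 8 "n-grams" of length 6 (made-up values)
arr_x = np.arange(48).reshape(8, 6)
arr_y = arr_x + 1

# range(4, 8, 4) -> [4], so only rows 0:4 come back;
# rows 4:8 are never yielded
for x, y in get_batches(arr_x, arr_y, batch_size=4):
    print(x.shape, y.shape)   # prints "(4, 6) (4, 6)" once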

Train function:

import torch
import torch.nn as nn
import torch.optim as optim

device= torch.device("cuda" if torch.cuda.is_available() else "cpu")

src1= torch.tensor([[1,2,3,4,5,0],[3,4,5,6,7,8]]).to(device)
trg1=torch.tensor([[6,7,8,9,10,11],[8,9,10,11,12,13]]).to(device)

model=Transformer(src_vocab_size=vocab_size, trg_vocab_size=vocab_size, src_pad_idx=0, trg_pad_idx=0).to(device)

opt=torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn=nn.functional.cross_entropy

def train_model(src, trg, epochs=501, batch_size=32):
    total_loss=0
    train_loss_list, validation_loss_list = [], []
    for e in range(epochs):
        model.train()
        for x, y in get_batches(src, trg, batch_size):
            src, trg = torch.from_numpy(src.astype('int64')).to(device), torch.from_numpy(trg.astype('int64')).to(device)
            pred=model(src, trg)
            if e%50==0:
                print("-"*25, f"Epoch {e + 1}","-"*25)
                print(pred.shape)
                print(torch.argmax(pred,dim=2))
                print(trg)
            loss=loss_fn(pred, torch.nn.functional.one_hot(trg, num_classes=14).type(torch.FloatTensor))
        
            opt.zero_grad()
            loss.backward()
            opt.step()
        
            total_loss += loss.detach().item()/(e+1)
            train_loss_list += [total_loss]

            if e%50==0: print(f"Training loss: {total_loss:.4f}")
            print('...')
        
    return train_loss_list
print(model)
train_loss_list = train_model(src=x4_int,trg=y4_int)

Any help is appreciated.

Fixed it. The inner loop was ignoring the x, y batches yielded by get_batches and pushing the full src and trg arrays through the model on every step, which is what was blowing up memory, and get_batches was also dropping the last batch. The batch and train functions are now as follows:

def get_batches(arr_x, arr_y, batch_size):      
    # iterate through the arrays
    prv = 0
    for n in range(batch_size, arr_x.shape[0]+1, batch_size):
        x = arr_x[prv:n,:]
        y = arr_y[prv:n,:]
        prv = n
        yield x, y
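
On the same toy arrays as before, the +1 in the range stop means the final full batch is no longer dropped (still just a sketch with made-up values):

import numpy as np

arr_x = np.arange(48).reshape(8, 6)   # 8 toy n-grams of length 6
arr_y = arr_x + 1

# range(4, 9, 4) -> [4, 8], so both halves of the data are yielded
for x, y in get_batches(arr_x, arr_y, batch_size=4):
    print(x.shape, y.shape)   # prints "(4, 6) (4, 6)" twice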

model=Transformer(src_vocab_size=vocab_size, trg_vocab_size=vocab_size, src_pad_idx=0, trg_pad_idx=0).to(device)
opt=torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn=nn.functional.cross_entropy
    
def train_model(src, trg, epochs=51, batch_size=32, classes=vocab_size):
    print(model)
    
    train_loss_list, validation_loss_list = [], []
    for e in range(epochs):
        model.train()
        
        # per-epoch bookkeeping
        N=np.ceil(src.shape[0]/batch_size)   # approximate number of batches per epoch
        running_loss=0
        correct=0
        total=0
        
        if e%5==0: print("-"*25, f"Epoch {e + 1}","-"*25)
        for x, y in get_batches(src, trg, batch_size):
            x, y = torch.from_numpy(x.astype('int64')).to(device), torch.from_numpy(y.astype('int64')).to(device)
            opt.zero_grad()
            pred=model(x, y)
            predicted=torch.argmax(pred,dim=2)
            # one-hot the targets and keep them on the same device as the model output
            loss=loss_fn(pred, torch.nn.functional.one_hot(y, num_classes=classes).float().to(device))
            loss.backward()
            opt.step()
            if e%5==0:
                print(pred.shape)
                print(predicted[0])
                print(y[0])

            running_loss += loss.detach().item()
            train_loss_list += [running_loss]
            total+=y.size(0)*y.size(1)
            correct+=predicted.eq(y).sum().item()
            accu=100.*correct/total
            
        if e%5==0: print(f"Training loss: {running_loss/N:.4f}, Accuracy: {accu:.2f}%")
        print('...')
        
    return train_loss_list
train_loss_list = train_model(src=x4_int,trg=y4_int)
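
If anyone wants to eyeball the result: the returned list holds the running (cumulative) loss after every batch, and the running total resets at the start of each epoch, so plotting it gives a rising sawtooth rather than a smooth curve. The matplotlib bit below is just my own sketch, not part of the training code:

import matplotlib.pyplot as plt

# train_loss_list gets a new entry after every batch; the running total
# resets at each epoch boundary, hence the sawtooth shape
plt.plot(train_loss_list)
plt.xlabel("batch index")
plt.ylabel("running training loss (per epoch)")
plt.show()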