Plotting epoch loss

I want to plot epoch loss curve, I’ve tried codes from Plotting loss curve but i’m getting errors like

TypeError: ‘DataLoader’ object is not subscriptable

train(args.epochs, args.batch_size, args.lr, args.num_classes)

This is my code:

def train(epochs, batch_size, learning_rate, num_classes):

    # fetch data
    train_loader, test_loader = get_data_loader(batch_size)

    # Loss and optimizer
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = LeNet(num_classes).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # start train
    total_step = len(train_loader)
    losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):

            
            # get image and label
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            epoch_loss = running_loss / len(train_loader['train'])
            losses.append(epoch_loss)

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, epochs, i + 1, total_step, loss.item()))
                running_loss = 0.0

        # evaluate after epoch train
        evaluate(model, test_loader, device)

    # save the trained model
    save_model(model, save_path='lenet.pth')
    plt.plot(np.array(losses), 'r')
    return model

Could you show the code for get_data_loader? From the error TypeError: ‘DataLoader’ object is not subscriptable, it looks like there’s an issue with how you’re using your DataLoader. But it’s difficult to tell from just looking at your code. Could you also state at which line you are getting your errors?