PyTorch Gradients

I think a simpler way to do this would be:

from torch.autograd import Variable  # net, crit, optimizer and train_loader are assumed to be defined already

num_epoch = 10
real_batchsize = 100  # I want to update the weights once every `real_batchsize` mini-batches

for epoch in range(num_epoch):
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):

        data, target = Variable(data.cuda()), Variable(target.cuda())
        output = net(data)
        loss = crit(output, target)
        total_loss = total_loss + loss  # keep the loss tensors (and their graphs) so we can backprop through the sum

        # use batch_idx + 1 so the very first batch does not trigger an update on its own
        if (batch_idx + 1) % real_batchsize == 0:
            ave_loss = total_loss / real_batchsize
            optimizer.zero_grad()
            ave_loss.backward()
            optimizer.step()
            total_loss = 0
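
Since backward() accumulates gradients into each parameter's .grad until you call zero_grad(), you can also avoid holding all the intermediate graphs in memory: backpropagate each mini-batch's (pre-scaled) loss as you go and only step every `real_batchsize` batches. This is just a minimal sketch of that variant, assuming the same `net`, `crit`, `optimizer` and `train_loader` as above:

num_epoch = 10
real_batchsize = 100

for epoch in range(num_epoch):
    optimizer.zero_grad()
    for batch_idx, (data, target) in enumerate(train_loader):

        data, target = Variable(data.cuda()), Variable(target.cuda())
        output = net(data)
        loss = crit(output, target) / real_batchsize  # pre-scale so the accumulated gradient is the average
        loss.backward()  # frees this batch's graph and adds its gradient into .grad

        if (batch_idx + 1) % real_batchsize == 0:
            optimizer.step()
            optimizer.zero_grad()

The gradients end up the same as averaging the summed loss, but only one batch's graph is alive at a time, so peak memory stays roughly that of a single mini-batch.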