Here are three equivalent versions of the code, with different runtime/memory consumption.
Assume that you want to run SGD with a batch size of 100.
(I didn't run the code below, so there might be some typos; sorry in advance.)
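For reference, the snippets below assume a setup roughly like the following; net, crit, opt and dataset are just placeholder names, and the random data here only stands in for your real data.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical setup: a tiny model, a standard criterion, plain SGD,
# and a DataLoader over random tensors standing in for the real dataset.
net = nn.Linear(20, 1)
crit = nn.MSELoss()
opt = optim.SGD(net.parameters(), lr=0.01)
data = TensorDataset(torch.randn(1000, 20), torch.randn(1000, 1))
dataset = DataLoader(data, batch_size=100)  # use batch_size=10 for cases 2 and 3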
1: single batch of 100 (least runtime, more memory)
# some code
# Initialize dataset with batch size 100
for input, target in dataset:
    pred = net(input)
    loss = crit(pred, target)
    # one graph is created here
    opt.zero_grad()
    loss.backward()
    # graph is cleared here
    opt.step()
2: multiple small batches of 10 (more runtime, least memory)
# some code
# Initialize dataset with batch size 10
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    loss = crit(pred, target)
    # one graph is created here
    loss.backward()
    # graph is cleared here
    if (i+1) % 10 == 0:
        # every 10 iterations of batches of size 10
        opt.step()
        opt.zero_grad()
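One detail to be careful about with case 2 (assuming crit averages over each batch, as most criteria do by default): summing the gradients of 10 small batches gives a gradient 10 times larger than the single batch of 100 in case 1. If you want the two to match exactly, you can divide each small-batch loss by the number of accumulation steps before calling backward, e.g.:

accum_steps = 10
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    loss = crit(pred, target)
    # scale so the accumulated gradient matches the batch-of-100 gradient
    (loss / accum_steps).backward()
    if (i+1) % accum_steps == 0:
        opt.step()
        opt.zero_grad()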
3: accumulate loss for multiple batches (more runtime, more memory)
# some code
# Initialize dataset with batch size 10
loss = 0
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    current_loss = crit(pred, target)
    # current graph is appended to existing graph
    loss = loss + current_loss
    if (i+1) % 10 == 0:
        # every 10 iterations of batches of size 10
        opt.zero_grad()
        loss.backward()
        # huge graph is cleared here
        opt.step()
        # reset the accumulated loss so the freed graph is not reused next cycle
        loss = 0
It should be clear that case 3 is not what you want: it has the extra runtime of case 2's many small forward passes, but because all the small graphs are kept until the update, it also has roughly the memory cost of case 1.
The choice between case 1 and case 2 is a tradeoff between memory and speed, so it depends on what you want to do.
Note that if you can fit a batch size of 50 in memory, you can do a variation of case 2 with a batch size of 50 and an update every 2 iterations.
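For concreteness, that variation might look like this (a sketch, assuming the DataLoader was built with batch_size=50):

# Variation of case 2: batches of 50, update every 2 iterations
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    loss = crit(pred, target)
    loss.backward()
    if (i+1) % 2 == 0:
        # every 2 iterations of batches of size 50 -> effective batch size 100
        opt.step()
        opt.zero_grad()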