Hi. I just came from this post. I wanted to examine the values of the gradients to get a better understanding and to see whether they are mathematically the same for a single large batch versus accumulated mini-batch gradients.
You can find my notebook with the results here.
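My expectation, for reference: since the loss uses reduction='sum' and the weights are only updated after all the mini-batch gradients have been accumulated, the two runs should agree exactly in exact arithmetic, because the gradient of the full-batch sum is the sum of the mini-batch gradients:

grad_w( sum_{i=1..64} loss_i(w) ) = sum_{b=1..8} grad_w( sum_{i in b} loss_i(w) )

so any difference should come only from floating-point summation order.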
Batch size = 64
import torch

set_seed(0)  # helper that seeds the RNGs (the same seed is used for the mini-batch run below)
device = torch.device('cpu')
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)
w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)
learning_rate = 1e-6
loss_fn = torch.nn.MSELoss(reduction='sum')
for t in range(500):
    y_pred = x.mm(w1).clamp(min=0).mm(w2)
    loss = loss_fn(y_pred, y)
    loss.backward()
    print(t, loss.item(), w2.grad[0, :3].tolist())
    # Apply the update and reset the gradients after every full batch
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        w1.grad.zero_()
        w2.grad.zero_()
Results:
494 3.13658601953648e-05 [-0.005660085007548332, 0.008431014604866505, -0.008294136263430119]
495 3.106993972323835e-05 [-0.006205810233950615, 0.01971215382218361, -0.007334284484386444]
496 3.0749477446079254e-05 [-0.0073434882797300816, 0.015438897535204887, -0.00767811294645071]
497 3.0343362595885992e-05 [-0.008765282109379768, 0.013503124937415123, -0.003114561550319195]
498 2.9831506253685802e-05 [-0.006252425257116556, 0.015132719650864601, -0.014233914203941822]
499 2.9525415811804123e-05 [-0.009816620498895645, 0.01921480894088745, -0.008706835098564625]
Batch size = 64, mini-batch size = 8, number of mini-batches = 8
def batch_generator(tensor, bs):
    # Yield consecutive chunks of size bs until the tensor is exhausted
    while len(tensor) != 0:
        yield tensor[:bs]
        tensor = tensor[bs:]

set_seed(0)
device = torch.device('cpu')
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
x_data = torch.randn(N, D_in, device=device)
y_data = torch.randn(N, D_out, device=device)
net_subdiv = 8  # number of mini-batches the full batch is split into
w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-6
loss_scalar = 0
for t in range(500):
    x_dl = batch_generator(x_data, int(N / net_subdiv))
    y_dl = batch_generator(y_data, int(N / net_subdiv))
    for i, (x, y) in enumerate(zip(x_dl, y_dl)):
        y_pred = x.mm(w1).clamp(min=0).mm(w2)
        loss = loss_fn(y_pred, y)
        loss.backward()  # gradients accumulate across mini-batches
        loss_scalar += loss.item()
        # Update after every 8 subdivisions --> 8 subdivisions x 8 data points each = 64 data points
        if (i + 1) % net_subdiv == 0:
            print(t, loss_scalar, w2.grad[0, :3].tolist())
            with torch.no_grad():
                w1 -= learning_rate * w1.grad
                w2 -= learning_rate * w2.grad
                w1.grad.zero_()
                w2.grad.zero_()
            loss_scalar = 0
Results:
494 3.1384501994580205e-05 [-0.007196771912276745, 0.02287140302360058, -0.009920414537191391]
495 3.097124204032298e-05 [-0.010900281369686127, 0.02085154317319393, -0.007979532703757286]
496 3.064212341996608e-05 [-0.010124756954610348, 0.021325696259737015, -0.015906663611531258]
497 3.016368737007724e-05 [-0.009693929925560951, 0.02757120504975319, -0.007344169542193413]
498 2.9777340159853338e-05 [-0.006317904219031334, 0.01508003007620573, -0.014124374836683273]
499 2.9521241344809823e-05 [-0.00753612769767642, 0.026998091489076614, -0.010012011975049973]
The resulting gradients at the end seem to be slightly different. Did I do something wrong, or is this within an acceptable range of rounding error?
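For what it's worth, a single-step comparison along these lines might show whether the gradients already differ before any weight update, so that rounding differences cannot compound over the 500 iterations (just a sketch: torch.manual_seed stands in for my set_seed helper, and the allclose tolerances are arbitrary):

import torch

torch.manual_seed(0)
N, D_in, H, D_out, net_subdiv = 64, 1000, 100, 10, 8
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Full batch: one backward pass
loss_fn(x.mm(w1).clamp(min=0).mm(w2), y).backward()
full_grad = w2.grad.clone()
w1.grad.zero_()
w2.grad.zero_()

# Accumulated: 8 backward passes of 8 samples each, no update in between
bs = N // net_subdiv
for i in range(net_subdiv):
    xb, yb = x[i * bs:(i + 1) * bs], y[i * bs:(i + 1) * bs]
    loss_fn(xb.mm(w1).clamp(min=0).mm(w2), yb).backward()
acc_grad = w2.grad.clone()

print(torch.allclose(full_grad, acc_grad, rtol=1e-5, atol=1e-6))
print((full_grad - acc_grad).abs().max())

If the per-step difference reported here is tiny, then the larger gap after 500 updates would presumably just be that rounding compounding through the weight updates.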