I have a problem that has me absolutely baffled and that took almost a full day just to pinpoint. My training loop is as follows:
import time
import torch

def train(model, params, train_dataset, epoch, optimizer, criterion):
    clip = params.get('clip', 10)
    log_interval = params.get('log_interval', 10)
    batch_size = params.get('batch_size', 64)
    sequence_length = params.get('bptt', 50)
    with torch.enable_grad():
        model.train()
        total_loss = 0
        start_time = time.time()
        hidden_teacher = model.init_hidden(batch_size)
        for batch, i in enumerate(range(0, train_dataset.size(0) - 1, sequence_length)):
            input_seq, target_seq = get_batch(sequence_length, train_dataset, i)
            # Detach hidden states so gradients do not flow across batch boundaries.
            hidden_teacher = model.repackage_hidden(hidden_teacher)
            hidden_freeflow = model.repackage_hidden(hidden_teacher)
            optimizer.zero_grad()
            loss = compute_loss(model, input_seq, target_seq, batch_size,
                                hidden_freeflow, hidden_teacher, criterion)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # (gradient clipping with clip and the log_interval/start_time
            # logging are omitted from this excerpt)
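For context, get_batch and repackage_hidden behave like their counterparts in the standard PyTorch word_language_model example. A rough sketch of their shape, not my exact code (in my code repackage_hidden is a method on the model, but it does the same detach):

def get_batch(sequence_length, source, i):
    # Slice a (seq_len, batch, ...) chunk and the same chunk shifted by one step.
    seq_len = min(sequence_length, source.size(0) - 1 - i)
    input_seq = source[i:i + seq_len]
    target_seq = source[i + 1:i + 1 + seq_len]
    return input_seq, target_seq

def repackage_hidden(h):
    # Detach hidden state from the old graph so BPTT stops at batch boundaries.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)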
The compute_loss function is defined as:
def compute_loss(model, input_seq, target_seq, batch_size, hidden_freeflow, hidden_teacher, criterion):
    decoded = input_seq[0].unsqueeze(0)  # seed the free-flow pass with the first time step
    decs = []
    hids = []
    # Free-flow loss: feed the model its own previous output at every step.
    for _ in range(input_seq.size(0)):
        decoded, hidden_freeflow, output = model.forward(decoded, hidden_freeflow)
        decs.append(decoded)
        hids.append(output)
    out_seq1 = torch.cat(decs, dim=0)
    hids = torch.cat(hids, dim=0)
    loss1 = criterion(out_seq1.contiguous().view(batch_size, -1),
                      target_seq.contiguous().view(batch_size, -1))
    # Teacher-forcing loss: feed the model the ground-truth input sequence.
    out_seq2, hidden_teacher, output_2 = model.forward(input_seq, hidden_teacher)
    loss2 = criterion(out_seq2.contiguous().view(batch_size, -1),
                      target_seq.contiguous().view(batch_size, -1))
    # Modified/weird professor-forcing loss: pull the free-flow activations
    # toward the detached teacher-forced ones.
    loss3 = criterion(hids.contiguous().view(batch_size, -1),
                      output_2.contiguous().view(batch_size, -1).detach())
    loss = loss1 + loss2 + loss3
    return loss
Now, the puzzling thing: if I copy the body of compute_loss and paste it inline into the main training loop, the results are different. It is not due to a random seed or anything like that; the difference really comes from the function call itself. Does some state change happen because of the call? Can anyone shed light on this?
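To rule out randomness, I seed everything and force deterministic kernels at the start of each run, roughly like this (a sketch; torch.use_deterministic_algorithms needs PyTorch >= 1.8):

import random
import numpy as np
import torch

def make_deterministic(seed=0):
    # Seed every RNG the run might touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds all CUDA devices
    # Fail fast if an op without a deterministic implementation is ever hit.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)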
The difference is small. Here are the losses when the function is used (each row shows loss1, loss2 and loss3 for one batch):
1.018952488899231 1.0198941230773926 4.705279570771381e-05
0.9998740553855896 0.997614860534668 4.655172597267665e-05
0.9933467507362366 0.9878082275390625 4.792845720658079e-05
and without:
1.018952488899231 1.0198941230773926 4.705279570771381e-05
0.9998368620872498 0.997587263584137 4.6545839722966775e-05
0.9931793212890625 0.9876329302787781 4.7916713810991496e-05
Interestingly, the first batch always agrees exactly, and the results only diverge after that. Again, the only difference is whether compute_loss is inlined.
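My next step is to snapshot the gradients after each backward pass in both variants and compare them bitwise, to see where the drift first appears. Along these lines (a sketch; snapshot_grads and first_divergence are helper names I made up):

def snapshot_grads(model):
    # Clone the gradients so later optimizer steps cannot mutate the snapshot.
    return {name: p.grad.detach().clone()
            for name, p in model.named_parameters()
            if p.grad is not None}

def first_divergence(grads_a, grads_b):
    # Report the first parameter whose gradients differ bitwise, with the max abs diff.
    for name, g_a in grads_a.items():
        g_b = grads_b[name]
        if not torch.equal(g_a, g_b):
            return name, (g_a - g_b).abs().max().item()
    return None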