Hi, I have a neural network that classifies an image as eye or not.
I am training the network with a batch size of 32. For the first 70 or so iterations the model returns predictions quickly, but the 98th iteration started at 00:48:18 and has still not finished.
Could anyone please tell me why the network is taking so much time?
def train(self):
    # print("Training the batch...")
    # Ground truth
    indexes = torch.tensor([self.samples.index(obj) for obj in self.samples if obj.dtype=="pos"])
    gt_imgs = [obj.patch for obj in self.samples if obj.dtype=="pos"]
    cls_label = torch.tensor([1 if obj.dtype=="pos" else 0 for obj in self.samples])
    bbox_gt = torch.FloatTensor([self.PosSet.pos_samples[obj.id].patch_bbox for obj in self.samples if obj.dtype=="pos"])
    bbox_gt = torch.reshape(bbox_gt,(-1,BBOX)) # TODO: BS // convert 4 to derived macro; or compute from existing macros
    lm_gt = torch.FloatTensor([self.PosSet.pos_samples[obj.id].patch_lm.flatten() for obj in self.samples if obj.dtype=="pos"])
    lm_gt = torch.reshape(lm_gt,(-1,LM)) # TODO: BS // Why 12? Derive using existing macros and variables.
    # Training
    image_batch = torch.FloatTensor([obj.patch for obj in self.samples])
    image_batch = image_batch.reshape(BATCH_SIZE,3,TEMPLATE_HGT,TEMPLATE_WDT)
    pred = model(image_batch)
    optimizer.zero_grad()
    # Predictions
    cls_pred = pred[0]
    bbox_pred = torch.index_select(pred[1],0,indexes)
    lm_pred = torch.index_select(pred[2],0,indexes)
    # Loss
    cls_loss = criterion_cls(cls_pred,cls_label)
    bbox_loss = criterion_bbox(bbox_pred,bbox_gt)
    bbox_loss = torch.mean(bbox_loss,1)
    lm_loss = criterion_lm(lm_pred,lm_gt)
    lm_loss = torch.mean(lm_loss,1)
    # Sum up the loss
    nsamples = len([obj for obj in self.samples if obj.dtype=="neg"])
    dummy_loss = torch.zeros(nsamples)
    bbox_loss = torch.cat((bbox_loss,dummy_loss))
    lm_loss = torch.cat((lm_loss,dummy_loss))
    loss = A * cls_loss + B * bbox_loss + (1 - A - B) * lm_loss
    loss.sum().backward()
    optimizer.step()
    return loss
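I think the "return loss" at the end of train() is the problem. Can you try "return loss.detach()" instead? The returned tensor is still attached to the autograd graph, so if the caller keeps it around, the graph may be kept alive with it.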
Problems like these are generally caused when your training loop is holding on to things that it shouldn’t. Also, make sure that you are not storing temporary computations in an ever-growing list without deleting them.
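Here is a minimal sketch of what that could look like. The model, optimizer, criterion and the history list are hypothetical stand-ins, not the ones from the question; the only points being illustrated are returning loss.detach() and storing plain Python floats rather than tensors.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the model, optimizer and criterion in the question.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# reduction="none" gives a per-sample loss (assumed here to mirror the question).
criterion = nn.CrossEntropyLoss(reduction="none")


def train_step(inputs, labels):
    """Run one optimisation step and return a per-sample loss without a graph."""
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.sum().backward()
    optimizer.step()
    # detach() returns a tensor that shares the data but has no autograd history.
    return loss.detach()


history = []  # hypothetical log of per-iteration mean losses
for _ in range(3):
    batch = torch.randn(8, 4)
    labels = torch.randint(0, 2, (8,))
    step_loss = train_step(batch, labels)
    # .item() converts to a Python float, so the list never references tensors.
    history.append(step_loss.mean().item())
```

If the caller only needs a number for logging, returning loss.sum().item() directly from the training function is another option.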
To track this down, you can also time the different parts separately: data loading, network forward, loss computation, backward pass and parameter update. Hopefully only one of them will grow over the iterations, which will show you where the problem is.
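One way to collect those timings is a small context manager around each phase. This is only a sketch: the timed helper, the timings dictionary and fake_work are made-up names, and fake_work stands in for the real data loading, forward, loss, backward and update code.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

timings = defaultdict(list)  # phase name -> list of durations in seconds


@contextmanager
def timed(phase):
    """Record the wall-clock duration of the enclosed block under `phase`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[phase].append(time.perf_counter() - start)


def fake_work(n):
    """Placeholder for a real training phase; it just burns some CPU."""
    return sum(i * i for i in range(n))


for iteration in range(3):
    with timed("data_loading"):
        fake_work(1_000)
    with timed("forward"):
        fake_work(2_000)
    with timed("loss"):
        fake_work(500)
    with timed("backward"):
        fake_work(2_000)
    with timed("optimizer_step"):
        fake_work(500)

# Compare the first and last iteration of each phase to see which one grows.
for phase, durations in timings.items():
    print(f"{phase}: first={durations[0]:.6f}s last={durations[-1]:.6f}s")
```

If the model runs on a GPU, CUDA calls are asynchronous, so you would want to call torch.cuda.synchronize() before reading the clock at the start and end of each phase; otherwise the time can be attributed to the wrong phase.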