Hi PyTorch lovers,
I’m facing an issue, or maybe a misunderstanding of how PyTorch autograd works.
I’m working with a big model where I’m only able to feed it image by image. Nevertheless, I would like to train it with an effective batch size greater than 1. My images don’t all have the same size, and my data loader already pushes lists of images (the images are so big that I need to split them into patches and feed them one by one). If I understand correctly, a custom collate_fn is not the solution here (in my case it would produce a list of lists, and I’m not sure that helps much). I read somewhere on the forum about performing the forward on X images and then running the backward on the constructed graphs, but sadly I couldn't find that post again, sorry…
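For context, a collate_fn that returns plain lists (instead of stacking tensors) does let a DataLoader batch variable-size images. A minimal sketch; the dataset and sizes below are dummy assumptions, not your real pipeline:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class VariableSizeDataset(Dataset):
    # Dummy dataset producing images of different spatial sizes (assumption).
    def __init__(self):
        self.sizes = [16, 24, 32, 40]

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        s = self.sizes[idx]
        return torch.randn(1, s, s), torch.zeros(1, s, s)

def list_collate(batch):
    # Keep images and targets as lists instead of stacking them into one
    # tensor, so different spatial sizes can coexist in a "batch".
    images = [img for img, _ in batch]
    targets = [tgt for _, tgt in batch]
    return images, targets

loader = DataLoader(VariableSizeDataset(), batch_size=2,
                    collate_fn=list_collate)
batch_images, batch_targets = next(iter(loader))
# batch_images is a list of 2 tensors with different shapes
```

As you note, this only changes how batches are delivered; each image still has to go through the model one at a time.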
I propose my solution below (a dummy example of my code), and I would like to ask whether I’m handling the pipeline of X times forward-loss, then backward-optimizer.step()-optimizer.zero_grad(), properly.
When I set cfg["Config"]["backwardEach"] to 1 (batch size of 1), it works, but when I use 4 or 8 my metrics fail and my model ends up inferring only all-zero tensors as output.
My model is split into a feature extractor (model1) and a segmentation model (model2). The code below is heavily simplified, because the real one handles several configs, but I think it explains my idea =) If the version proposed here is correct, I can share my more complex implementation.
The only other thing that could have an impact is the losses: in the real code I save them iteratively into a dict and sum them just before the backward.
```python
loss = 0
backwardEach = 1
for epoch in range(cfg["Config"]["N_Epochs"]):
    for Input, Gt in trainingset:  # Iteration over samples
        for In, gt in zip(Input, Gt):  # Iteration over patches of one sample
            In = In.to(cfg["Cuda"]["device"])
            gt = gt.to(cfg["Cuda"]["device"])
            features = model1(In)
            output = model2(features)
            loss += criterion(output, gt)
            if backwardEach / cfg["Config"]["backwardEach"] == 1:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                loss = 0
                backwardEach = 1
            else:
                backwardEach += 1
```
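For comparison, here is a minimal, self-contained sketch of the usual gradient-accumulation pattern: call backward on each patch's loss immediately (which frees that patch's graph and accumulates into `.grad`), scale each loss by the accumulation count so the result matches the mean over the virtual batch, and step every N patches. The tiny conv models, patch sizes, and hyperparameters below are placeholders standing in for model1/model2, not your real setup:

```python
import torch
import torch.nn as nn

# Dummy stand-ins for the post's model1 / model2 (shapes are assumptions).
model1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)   # "feature extractor"
model2 = nn.Conv2d(4, 1, kernel_size=3, padding=1)   # "segmentation head"
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(
    list(model1.parameters()) + list(model2.parameters()), lr=0.1
)

backward_each = 4  # accumulate gradients over 4 patches
step = 0
optimizer.zero_grad()

# Snapshot one weight to verify below that the optimizer actually steps.
w0 = model1.weight.detach().clone()

# Patches of different spatial sizes, as in the post (random dummy data).
patches = [(torch.randn(1, 1, s, s), torch.rand(1, 1, s, s).round())
           for s in (16, 24, 32, 16, 24, 32, 16, 24)]

for In, gt in patches:
    output = model2(model1(In))
    # Scale so the accumulated gradient equals the mean over the virtual
    # batch; backward immediately frees this patch's graph instead of
    # keeping every patch's graph alive until the step.
    loss = criterion(output, gt) / backward_each
    loss.backward()
    step += 1
    if step % backward_each == 0:
        optimizer.step()
        optimizer.zero_grad()
```

The main differences from your loop: the per-patch backward keeps memory flat (your accumulated `loss` tensor retains all X graphs until the step), and the `1 / backward_each` scaling keeps the effective gradient magnitude comparable to the batch-size-1 case, which may be why larger `backwardEach` values destabilize your training.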