Hi PyTorch lovers,
I’m facing an issue, or perhaps a misunderstanding of how PyTorch autograd works.
I’m working with a big model where I can only feed it one image at a time. Nevertheless, I would like to train it with an effective batch size greater than 1. My images don’t all have the same size, and they are so big that I need to split them into patches and feed those one by one, so my data loader already yields lists of images. If I understand correctly, a custom collate_fn is not the solution here (in my case it would produce a list of lists, and I’m not sure that helps much). I read somewhere on the forum about performing the forward pass on X images and then running the backward pass on the accumulated graphs, but sadly I couldn’t find that post again, sorry…
I propose my solution below (a dummy example of my code) and would like to ask whether I handle the pipeline of X times forward → loss → backward → optimizer.step() → optimizer.zero_grad() properly.
When I set cfg["Config"]["backwardEach"] to 1 (batch size of 1) it works, but when I use 4 or 8 my metrics fail and my model ends up inferring only all-zero tensors as output.
My model is split into a feature extractor (model1) and a segmentation model (model2). The code below is heavily simplified, because the real one is written to handle different configs, but I think it explains my idea =) If this version is correct I can share my more complex implementation.
The only other thing that could have an impact is the losses: in the real code I store them iteratively in a dict and sum them just before the backward pass.
```python
loss = 0
backwardEach = 1
for epoch in range(cfg["Config"]["N_Epochs"]):
    for Input, Gt in trainingset:          # iterate over samples
        for In, gt in zip(Input, Gt):      # iterate over patches of one sample
            In = In.to(cfg["Cuda"]["device"])
            gt = gt.to(cfg["Cuda"]["device"])
            features = model1(In)
            output = model2(features)
            loss += criterion(output, gt)  # keeps every patch's graph alive until backward
            if backwardEach == cfg["Config"]["backwardEach"]:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                loss = 0
                backwardEach = 1
            else:
                backwardEach += 1
```