Out of memory error in sequential processing

(Taha) #1

I need to compute a function separately for each channel of the higher layer, as follows:

# The forward pass
output = torch.zeros(bsize, nchannel, 16, 16)
for i in range(nchannel):
    output[:, i, :, :] = myfunc(input, i)

Here myfunc() is a complex function. The allocation in the first line succeeds, but partway through the loop I get an out-of-memory error. If I instead use a linear or residual layer and compute all the channels in one step, I can produce the output layer with no problem.

Can you help me understand why I am getting the out-of-memory error?
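For reference, here is a minimal, self-contained sketch of the loop pattern above. The real myfunc() is not shown, so the version here is a hypothetical stand-in, and the helper name forward_loop and the tensor sizes are assumptions chosen only to make the example run on CPU. It illustrates one possible contributor rather than a confirmed diagnosis: when the input requires gradients, each slice assignment attaches that iteration's autograd graph to output, so the intermediates from every call to myfunc() may stay alive until the backward pass.

```python
import torch

def myfunc(x, i):
    # Hypothetical stand-in for the real function (not shown in the post).
    # Returns one (bsize, 16, 16) map for channel i.
    scaled = x * (i + 1)
    # tanh keeps its result around for the backward pass when grad is enabled.
    return torch.tanh(scaled).mean(dim=1)

def forward_loop(x, nchannel):
    # Same structure as the loop in the question.
    bsize = x.shape[0]
    output = torch.zeros(bsize, nchannel, 16, 16)
    for i in range(nchannel):
        output[:, i, :, :] = myfunc(x, i)
    return output

x = torch.randn(2, 3, 16, 16, requires_grad=True)

# With autograd enabled, output is tied to the graph of every iteration.
out = forward_loop(x, nchannel=4)
print(out.requires_grad, out.grad_fn is not None)

# Without autograd, no graph is recorded for the loop.
with torch.no_grad():
    out_ng = forward_loop(x, nchannel=4)
print(out_ng.requires_grad)
```

If the error disappears when the loop runs under torch.no_grad(), that would suggest the memory is going to intermediates saved for backward rather than to the output tensor itself.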