I have to compute a function separately for each channel of the higher layer, as follows:
    # The forward pass
    output = torch.zeros(bsize, nchannel, 16, 16)
    for i in range(nchannel):
        output[:, i, :, :] = myfunc(input, i)
myfunc() is a complex function. The allocation of output on the first line succeeds, but partway through the loop I get an out-of-memory error.
I checked that if I use a linear or residual layer instead and do all the steps at once, I can compute the output layer with no problem.
Can you help me understand why I am getting the out-of-memory error?
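
For anyone who wants to reproduce the pattern, here is a minimal, self-contained sketch of the loop described above. The body of `myfunc` is only a hypothetical stand-in (the real function is more complex and is not shown here), and `forward_per_channel` and the per-iteration memory logging are names and additions made up for this example. The logging only records anything when the input lives on a CUDA device.

```python
import torch


def myfunc(x, i):
    # Hypothetical stand-in for the real myfunc (assumption, not the original).
    # Maps x of shape (bsize, C, 16, 16) to one output channel (bsize, 16, 16).
    return torch.tanh(x * (i + 1)).sum(dim=1)


def forward_per_channel(x, nchannel):
    """Fill the output one channel at a time and log allocated GPU memory."""
    bsize = x.shape[0]
    output = torch.zeros(bsize, nchannel, 16, 16, device=x.device)
    allocated = []  # bytes allocated after each iteration (CUDA inputs only)
    for i in range(nchannel):
        # In-place write into a slice of the preallocated output tensor.
        output[:, i, :, :] = myfunc(x, i)
        if x.is_cuda:
            allocated.append(torch.cuda.memory_allocated(x.device))
    return output, allocated


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(8, 4, 16, 16, device=device, requires_grad=True)
    out, allocated = forward_per_channel(x, nchannel=32)
    print(out.shape)
    print(allocated)  # empty list when running on CPU
```

Comparing the logged values for an input with `requires_grad=True` against the same loop run under `torch.no_grad()` should show whether the memory growth comes from tensors kept alive for the backward pass or from `myfunc` itself.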