I need to compute a function separately for each channel of a higher layer, as follows:
```python
# The forward pass
output = torch.zeros(bsize, nchannel, 16, 16)
for i in range(nchannel):
    output[:, i, :, :] = myfunc(input, i)
```
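For reference, here is a self-contained sketch of the loop above. The body of `myfunc` is a hypothetical stand-in, since the real function is not shown, and the sizes `bsize`, `cin` and `nchannel` are arbitrary assumptions. The last line is an equivalent formulation that collects the per-channel results with `torch.stack` instead of writing them into a preallocated tensor. It produces the same values, but it is not claimed here to use less memory.

```python
import torch

# Hypothetical stand-in for the real myfunc(): any differentiable
# computation that maps the input to a (bsize, 16, 16) tensor for channel i.
def myfunc(x, i):
    # x: (bsize, cin, 16, 16); scale by a channel-dependent factor, then reduce over cin
    return torch.tanh(x * (i + 1)).sum(dim=1)

bsize, cin, nchannel = 4, 3, 8  # assumed sizes for illustration
x = torch.randn(bsize, cin, 16, 16, requires_grad=True)

# Loop version from the question: write each channel into a preallocated tensor.
output = torch.zeros(bsize, nchannel, 16, 16)
for i in range(nchannel):
    output[:, i, :, :] = myfunc(x, i)

# Equivalent values: build the channels as a list and stack them along dim 1.
output_stacked = torch.stack([myfunc(x, i) for i in range(nchannel)], dim=1)
```

Because each assigned slice requires grad, `output` itself ends up tracked by autograd after the loop, even though it was created with `torch.zeros`.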
Here, `myfunc()` is a complex function. Allocating `output` on the first line succeeds, but partway through the loop I get an out-of-memory error. I checked that if I use a linear or residual layer and compute all channels in one step, the output layer is computed without any problem.
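One hedged way to check whether autograd bookkeeping is involved: every call to `myfunc` inside the loop can record intermediate tensors for the backward pass, and all of them stay alive until backward runs. The sketch below uses `torch.autograd.graph.saved_tensors_hooks` to roughly total the bytes autograd saves during the loop, with and without `torch.no_grad()`. As before, `myfunc` is a hypothetical stand-in, and `forward_loop` and `saved_bytes` are illustrative helpers, not part of the original code. The byte count is only an estimate, since a tensor saved twice is counted twice.

```python
import torch

# Hypothetical stand-in for the real myfunc(); the actual function is not shown.
def myfunc(x, i):
    return torch.tanh(x * (i + 1)).sum(dim=1)

def forward_loop(x, nchannel):
    # Same structure as the loop in the question.
    out = torch.zeros(x.shape[0], nchannel, 16, 16)
    for i in range(nchannel):
        out[:, i, :, :] = myfunc(x, i)
    return out

def saved_bytes(fn):
    """Call fn() and return roughly how many bytes autograd saved for backward."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # keep the tensor unchanged; we only measure it

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        fn()
    return total

x = torch.randn(4, 3, 16, 16, requires_grad=True)
with_grad_4 = saved_bytes(lambda: forward_loop(x, 4))
with_grad_8 = saved_bytes(lambda: forward_loop(x, 8))
with torch.no_grad():
    no_grad_8 = saved_bytes(lambda: forward_loop(x, 8))

print(with_grad_4, with_grad_8, no_grad_8)
```

If the saved total grows with the number of channels under autograd but is zero under `torch.no_grad()`, the memory pressure likely comes from the intermediates of `myfunc` rather than from `output` itself.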
Can you help me understand why I am getting the out-of-memory error?