@ptrblck Thanks! I followed the link and it was helpful. However, I have another idea for reducing memory consumption. In this code, we save input, weight, and bias in the ctx for the backward-pass computations, which seems to be the source of the huge memory consumption. Since we have already allocated these tensors elsewhere (e.g. weight is the torch.nn.Parameter in a conv layer), I wonder how we can avoid this redundancy. At least we know that the native PyTorch implementation of the Conv layer has already solved this problem. Do you know of any alternatives for this purpose?
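For reference, here is a minimal sketch of the pattern I mean (not the actual code from the thread): a custom autograd Function for conv2d that stashes input, weight, and bias on ctx via save_for_backward and then reads them back in backward. The gradient helpers torch.nn.grad.conv2d_input / conv2d_weight are real PyTorch APIs; the class and variable names are just illustrative, and stride/padding are left at their defaults.

```python
import torch
import torch.nn.functional as F


class MyConv2dFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # The pattern in question: input, weight, bias are stashed on ctx
        # so that backward() can compute gradients from them later.
        ctx.save_for_backward(input, weight, bias)
        return F.conv2d(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output)
        grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)
        grad_bias = grad_output.sum(dim=(0, 2, 3)) if bias is not None else None
        return grad_input, grad_weight, grad_bias
```

My question is essentially whether the tensors saved this way duplicate the storage already held by the layer's parameters and activations, and if so, how the native conv implementation avoids that.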