How to fully optimize a custom convolution layer?

@ptrblck Thanks! I followed the link and it was helpful. However, I have another idea for optimizing memory consumption. In that code, we save input, weight, and bias in ctx for the backward-pass computations, which is presumably the source of the huge memory consumption. Since these tensors have already been allocated elsewhere (e.g. weight is the torch.nn.Parameter of the conv layer), I wonder how we can avoid this redundancy. We at least know that the native PyTorch implementation of the Conv layer has already solved this problem. Do you know of any alternatives for this purpose?
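
For context, here is a minimal sketch of the pattern I am referring to (I'm assuming the linked code looks roughly like this; `MyConv2dFunction` is just an illustrative name). The `ctx.save_for_backward(input, weight, bias)` call is what keeps these tensors alive until backward runs, even though weight and bias are already held as the layer's parameters:

```python
import torch
from torch.nn import functional as F
from torch.nn.grad import conv2d_input, conv2d_weight


class MyConv2dFunction(torch.autograd.Function):
    """Illustrative custom conv2d with an explicit backward."""

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # This is the line in question: input, weight, and bias are stashed
        # on ctx and kept alive until the backward pass finishes.
        ctx.save_for_backward(input, weight, bias)
        return F.conv2d(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = conv2d_input(input.shape, weight, grad_output)
        if ctx.needs_input_grad[1]:
            grad_weight = conv2d_weight(input, weight.shape, grad_output)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))
        return grad_input, grad_weight, grad_bias
```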