I’m working on an image regression problem: my net outputs a tensor “out” of size C x H x W, and the ground truth “target” has the same size.
Before computing the loss, I would like to apply a filter with learnable weights (some kind of smoothing) to both the output and the target. It would look like this (leaving out the DataLoaders etc.):
```python
net = model.MyNet()
MyFilter = model.MyPostProcessingFilter()
img, target = Dataset[i]
criterion = nn.MSELoss()
out = net(img)
loss = criterion(MyFilter(out), MyFilter(target))
```
And MyPostProcessingFilter looks like this:
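```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def softmax2D(w):
    # Softmax over each kernel's spatial (kH x kW) entries, so every kernel sums to 1
    return F.softmax(w.flatten(-2), dim=-1).view_as(w)

class MyPostProcessingFilter(nn.Module):
    def __init__(self):
        super(MyPostProcessingFilter, self).__init__()
        # some_init_weights: initial kernel, shape (out_channels, in_channels, kH, kW)
        self.weights = nn.Parameter(some_init_weights)

    def forward(self, x):
        kernel = softmax2D(self.weights)
        return F.conv2d(x, kernel)
```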
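To make the intended behaviour of the filter concrete, here is a minimal sketch under two assumptions that the post does not state: that `softmax2D` means a softmax over the k x k entries of each kernel, and that each channel is smoothed separately (a depthwise conv with `groups=C`). `SoftmaxSmoothing` is a hypothetical stand-in, not the poster's module. Because every kernel sums to 1, a constant image comes out unchanged (only shrunk by the unpadded border):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftmaxSmoothing(nn.Module):
    """Hypothetical depthwise smoothing filter with learnable, softmax-normalized kernels."""

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        self.channels = channels
        # One k x k kernel per channel; all zeros -> softmax gives a uniform box filter.
        self.weights = nn.Parameter(torch.zeros(channels, 1, kernel_size, kernel_size))

    def kernel(self) -> torch.Tensor:
        # Softmax over the k*k spatial entries of each kernel, so each kernel sums to 1.
        w = self.weights
        return F.softmax(w.flatten(1), dim=1).view_as(w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # groups=channels: each channel is convolved with its own kernel (no channel mixing).
        return F.conv2d(x, self.kernel(), groups=self.channels)


filt = SoftmaxSmoothing(channels=3, kernel_size=5)
with torch.no_grad():
    filt.weights.normal_()  # random weights, so the kernels are no longer uniform

const = torch.full((1, 3, 16, 16), 2.5)
smoothed = filt(const)
print(smoothed.shape)  # torch.Size([1, 3, 12, 12])
print(torch.allclose(smoothed, torch.full_like(smoothed, 2.5)))  # True
```

The softmax keeps the kernel non-negative and normalized whatever values the raw weights take, which is one way to keep a learnable filter a "smoothing" filter during training.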
The problem is that applying this module to both the output and the target doesn’t work, because “nn criterions don’t compute the gradient w.r.t. targets”. I understand I can’t use the learnable parameters (with gradients) for both, so I tried adding the following method:
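```python
def target_fwd(self, x):
    # Same filter, but with the weights detached so the target branch carries no gradient
    kernel = softmax2D(self.weights.detach())
    return F.conv2d(x, kernel)
```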
But this still doesn’t work, and I get this error:
```
  File "main.py", line 123, in main
    train(train_loader, model, criterion, optim, epoch)
  File "main.py", line 182, in train
    loss.backward()
  File "/home/bertrand/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/autograd/variable.py", line 146, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
  File "/home/bertrand/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/nn/_functions/thnn/auto.py", line 49, in backward
    grad_output_expanded = grad_output.view(*repeat(1, grad_input.dim()))
TypeError: view received an invalid combination of arguments - got (), but expected one of:
 * (int ... size)
 * (torch.Size size)
```
At some point I even got an illegal memory access error, but I haven’t managed to reproduce it.
What’s the correct way to achieve this? Many thanks!
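For reference, here is a sketch of one way to express “same filter on the target, but treated as a constant” in current PyTorch: compute the filtered target under `torch.no_grad()`, so it does not require grad, and let the filter weights get their gradient only through the output branch. It is not verified against the old PyTorch version in the traceback. The net (`nn.Conv2d`), the `SoftmaxSmoothing` filter and the random tensors are hypothetical stand-ins for `MyNet`, `MyPostProcessingFilter` and the dataset:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftmaxSmoothing(nn.Module):
    """Hypothetical depthwise filter whose k x k kernels are softmax-normalized."""

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        self.channels = channels
        self.weights = nn.Parameter(torch.randn(channels, 1, kernel_size, kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weights
        # Each kernel is normalized over its spatial entries so it sums to 1.
        kernel = F.softmax(w.flatten(1), dim=1).view_as(w)
        return F.conv2d(x, kernel, groups=self.channels)


torch.manual_seed(0)
net = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # hypothetical stand-in for MyNet
filt = SoftmaxSmoothing(channels=3)               # hypothetical stand-in for MyPostProcessingFilter
criterion = nn.MSELoss()

img = torch.randn(4, 3, 32, 32)     # dummy input batch (N, C, H, W)
target = torch.randn(4, 3, 32, 32)  # dummy ground truth of the same size

out = net(img)
# Target branch: same weights, but no graph is recorded, so the filtered
# target is a plain constant tensor with requires_grad=False.
with torch.no_grad():
    filtered_target = filt(target)

# Gradients reach both net and filt.weights, via the output branch only.
loss = criterion(filt(out), filtered_target)
loss.backward()
```

`filt(target).detach()` would give the same values; `torch.no_grad()` simply avoids building the graph for the target branch in the first place.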