Hi,
I'm working on an image regression problem where my net outputs a tensor “out” of size C x H x W, and the ground truth “target” has the same size.
Before computing the loss, I would like to apply a smoothing filter with learnable weights to both the output and the target. It looks roughly like this (simplified; the real code uses DataLoaders etc.):
```python
net = model.MyNet()
MyFilter = model.MyPostProcessingFilter()
img, target = Dataset[i]
criterion = nn.MSELoss()
out = net(img)
loss = criterion(MyFilter(out), MyFilter(target))
```
And MyFilter looks like this:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyPostProcessingFilter(nn.Module):
    def __init__(self):
        super(MyPostProcessingFilter, self).__init__()
        # some_init_weights and softmax2D are defined elsewhere (not shown)
        self.weights = nn.Parameter(some_init_weights)

    def forward(self, x):
        kernel = softmax2D(self.weights)
        return F.conv2d(x, kernel)
```
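For anyone trying to reproduce this, here is a self-contained sketch of such a filter. The post doesn't show softmax2D or the initial weights, so `softmax2d` below is a hypothetical stand-in that applies a softmax over each channel's k x k kernel (entries positive, summing to 1). It also assumes the kernel is applied depthwise (`groups=C`, one kernel per channel) with padding so H x W is preserved, and it expects a batched N x C x H x W input. It is written for a recent PyTorch, not the version in the traceback.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def softmax2d(weights):
    # Hypothetical stand-in for softmax2D: softmax over each channel's
    # k x k kernel entries, so every kernel is positive and sums to 1.
    c, _, kh, kw = weights.shape
    return F.softmax(weights.view(c, -1), dim=1).view(c, 1, kh, kw)


class SmoothingFilter(nn.Module):
    """Hypothetical learnable per-channel (depthwise) smoothing filter."""

    def __init__(self, channels, kernel_size=3):
        super().__init__()
        # Zeros become a uniform box filter after the softmax.
        self.weights = nn.Parameter(
            torch.zeros(channels, 1, kernel_size, kernel_size))
        self.padding = kernel_size // 2  # keeps H x W for odd kernel sizes

    def kernel(self):
        return softmax2d(self.weights)

    def forward(self, x):
        # x: N x C x H x W; groups=C filters each channel independently.
        return F.conv2d(x, self.kernel(), padding=self.padding,
                        groups=x.shape[1])


f = SmoothingFilter(channels=3)
x = torch.ones(2, 3, 5, 5)
y = f(x)
```

Zero initialisation makes the starting kernel a plain box blur, and keeping the convolution depthwise stops the filter from mixing channels, which seems closest to "smoothing".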
The problem is that applying this module to both the output and the target doesn't work, because “nn criterions don't compute the gradient w.r.t. targets”. I understand I can't use the gradient-requiring parameters for both, so I tried adding the following method to the filter:
```python
    def target_fwd(self, x):
        # Same filter, but with the weights detached from the graph
        kernel = softmax2D(self.weights.detach())
        return F.conv2d(x, kernel)
```
But this still doesn't work, and I get the following error:
```
  File "main.py", line 123, in main
    train(train_loader, model, criterion, optim, epoch)
  File "main.py", line 182, in train
    loss.backward()
  File "/home/bertrand/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/autograd/variable.py", line 146, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
  File "/home/bertrand/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/nn/_functions/thnn/auto.py", line 49, in backward
    grad_output_expanded = grad_output.view(*repeat(1, grad_input.dim()))
TypeError: view received an invalid combination of arguments - got (), but expected one of:
 * (int ... size)
 * (torch.Size size)
```
At one point I even got an illegal memory access error, but I couldn't reproduce it.
What's the correct way to achieve this? Many thanks
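For reference, here is a minimal, self-contained sketch of the detach-on-the-target pattern described above. It targets a recent PyTorch release and has not been checked against the old version in the traceback. The filter (a softmax-normalised depthwise kernel), its `target_fwd` method, and the one-layer `net` standing in for MyNet are all hypothetical. The sketch also builds the target a second way, by running the normal forward under `torch.no_grad()`, to show that both routes give the same tensor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmoothingFilter(nn.Module):
    """Hypothetical learnable depthwise smoothing filter."""

    def __init__(self, channels, kernel_size=3):
        super().__init__()
        self.weights = nn.Parameter(
            torch.zeros(channels, 1, kernel_size, kernel_size))
        self.padding = kernel_size // 2

    def _filter(self, x, weights):
        c, _, kh, kw = weights.shape
        # Softmax over each channel's kernel entries so they sum to 1
        kernel = F.softmax(weights.reshape(c, -1), dim=1).reshape(c, 1, kh, kw)
        return F.conv2d(x, kernel, padding=self.padding, groups=c)

    def forward(self, x):
        # Output branch: gradients flow into self.weights
        return self._filter(x, self.weights)

    def target_fwd(self, x):
        # Target branch: detached weights, so the result needs no gradient
        return self._filter(x, self.weights.detach())


torch.manual_seed(0)
net = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # hypothetical stand-in for MyNet
filt = SmoothingFilter(channels=3)
criterion = nn.MSELoss()

img = torch.randn(4, 3, 8, 8)
target = torch.randn(4, 3, 8, 8)

filtered_target = filt.target_fwd(target)
loss = criterion(filt(net(img)), filtered_target)
loss.backward()

# Alternative: run the normal forward without recording autograd history
with torch.no_grad():
    filtered_target_no_grad = filt(target)
```

With this arrangement the filter weights receive gradients only through the output branch; the target branch treats the current kernel as a constant.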