The `histc` function does not implement a backward operation because it is a discrete operation (I really don't know how you would define its gradient exactly). However, its value can still be added to some other loss. I tested this with the following code.
Note: in the code below, the parameters are updated through the gradient of `lossB`. I am not sure how you would apply that to your problem.
The setup:
class TestModule(nn.Module):
    """Tiny test network: a padded 3x3 single-channel convolution, then a sigmoid.

    With a 3x3 kernel, stride 1 and ``padding=1`` the spatial size of the input
    is preserved; the sigmoid maps every output value into (0, 1).
    The input must have exactly one channel (``in_channels=1``).
    """

    def __init__(self):
        super().__init__()
        # Registration order (sigmoid first, then conv) is kept so the module's
        # repr and named_children() ordering stay the same.
        self.sigmoid = nn.Sigmoid()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``sigmoid(conv(x))``."""
        return self.sigmoid(self.conv(x))
# Module-level abort flag read by loop_stack; it starts out set and is cleared
# by print_backprop before each walk.
isAborted = True
# Network under test and its optimizer (Adam with default hyper-parameters).
model = TestModule()
optimizer = optim.Adam(params=model.parameters())
def _print_path(nodes):
    """Print the class names of *nodes* (given root first) in reverse, deepest first."""
    # str(node) looks like "<MseLossBackward0 object at 0x...>"; keep the first word.
    print(list(reversed([str(node)[1:-1].split(" ")[0] for node in nodes])))


def loop_stack(loss, acc):
    """Depth-first walk of an autograd graph, printing every path through it.

    ``loss`` is the current autograd node (a ``grad_fn``-like object exposing
    ``next_functions``) or ``None``; ``acc`` is the list of nodes already on
    the current path, root first. Whenever a path ends, its node names are
    printed deepest first. A path ends when:

    * ``loss`` is ``None`` (an input that autograd does not track); the
      printed path stops at the parent node;
    * ``loss`` has no children, i.e. its ``next_functions`` is empty or the
      attribute is missing; the printed path includes ``loss`` itself.

    The walk obeys the module-level ``isAborted`` flag: it returns
    immediately while the flag is set, and a KeyboardInterrupt (Ctrl-C) sets
    it so every pending frame stops early.

    Every path is enumerated without memoisation, so graphs with shared
    sub-graphs can produce very many lines. Graphs deeper than Python's
    recursion limit raise RecursionError instead of printing a truncated path.
    """
    global isAborted
    if isAborted:
        return
    if loss is None:
        _print_path(acc)
        return
    new_acc = acc + [loss]
    children = [edge[0] for edge in getattr(loss, "next_functions", ())]
    if not children:
        # Terminal node (e.g. a leaf's gradient accumulator): report the path
        # including it, otherwise paths that reach the parameters are lost.
        _print_path(new_acc)
        return
    try:
        for child in children:
            loop_stack(child, new_acc)
            if isAborted:
                break
    except KeyboardInterrupt:
        isAborted = True
def print_backprop(loss):
    """Walk and print the autograd graph behind ``loss`` using ``loop_stack``.

    Reads ``loss.grad_fn`` first, then clears the module-level ``isAborted``
    flag so the walk is allowed to run, and starts from an empty path.
    """
    global isAborted
    root_node = loss.grad_fn
    isAborted = False
    loop_stack(root_node, [])
The execution:
source = torch.rand(1, 1, 5, 5)
target = model(source)

# Common histogram range covering the values of both tensors (computed once).
values = torch.cat((source.reshape(-1), target.detach().reshape(-1)), 0)
t_min = values.min().item()
t_max = values.max().item()
n_bins = 4

# torch.histc has no backward, so build the histograms from detached tensors;
# lossA is then a plain constant as far as autograd is concerned.
s_his = torch.histc(source.detach(), bins=n_bins, min=t_min, max=t_max)
t_his = torch.histc(target.detach(), bins=n_bins, min=t_min, max=t_max)
lossA = F.mse_loss(s_his, t_his)
lossB = F.mse_loss(source, target)

optimizer.zero_grad()
# lossB / lossB.detach() is numerically 1, so `loss` has the value of lossA,
# but gradients flow only through the non-detached lossB:
#     d(loss)/d(theta) = (lossA / lossB) * d(lossB)/d(theta)
# The update direction therefore comes from lossB, rescaled by lossA / lossB.
# Adam normalises each parameter's gradient by its running RMS, which largely
# cancels such a scale factor (plain SGD would not). If lossB is exactly 0
# the ratio yields NaN; if lossA is 0 the gradient is 0.
loss = lossB / lossB.detach() * lossA
loss.backward(retain_graph=True)
optimizer.step()

print("Loss: {}\tBack prop path".format(loss))
print_backprop(loss)
print()
print("Before:\n{}\n\n{}\n=============".format(source, target))
# Inference only: no need to build a new autograd graph for the "after" output.
with torch.no_grad():
    after = model(source)
print("After:\n{}\n\n{}".format(source, after))
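The final result shows that the weights are updated (somehow).
There are a few things I still don't know, such as whether the weights are updated with the gradient of `lossB` at a magnitude of 1, or at a magnitude set by the value of `lossA`.
(By the chain rule, since `lossB.detach()` and `lossA` are constants to autograd, the gradient of `loss` is `(lossA / lossB) * d(lossB)/dθ`: the direction comes from `lossB` and the scale is `lossA / lossB`.)
If I missed anything, please reply or message me (I'm really curious).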