Hi,
I’m trying to write a loss function that maximizes the mutual information between two images (the image produced by a model and the target image).
The mutual information calculation uses a 2D histogram, and I get an error saying that the derivative of that function is not implemented.
Is there anything I can do to fix it?
This is the code that calculates the mutual information loss:
import torch as th

class MaxMutualLoss(th.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        # One (target, input) pair of pixel values per row
        th2d = th.stack((targets.ravel(), inputs.ravel()), -1)
        hist_2d, _ = th.histogramdd(th2d.cpu(), bins=20)
        pxy = hist_2d / hist_2d.sum()  # joint probabilities
        px = th.sum(pxy, dim=1)  # marginal for x over y
        py = th.sum(pxy, dim=0)  # marginal for y over x
        px_py = px[:, None] * py[None, :]  # Broadcast to multiply marginals
        # Now we can do the calculation using the pxy, px_py 2D arrays
        nzs = pxy > 0  # Only non-zero pxy values contribute to the sum
        return -1 * th.sum(pxy[nzs] * th.log(pxy[nzs] / px_py[nzs]))
The error I get when calling loss.backward():
RuntimeError: derivative for aten::_histogramdd_from_bin_cts is not implemented
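For context, a hard histogram is piecewise constant in its inputs, so even if a derivative were implemented, the gradient would be zero almost everywhere. One possible workaround is a soft (kernel-weighted) histogram, where each pixel contributes a smooth weight to every bin. The sketch below is only an illustration and is not taken from the code above: the names `soft_joint_hist` and `SoftMutualInfoLoss`, the Gaussian kernel, the bandwidth `sigma`, and the assumption that pixel values lie in [0, 1] are all hypothetical choices.

```python
import torch as th


def soft_joint_hist(x, y, bins=20, sigma=0.05):
    """Differentiable stand-in for a normalized 2D histogram.

    Assumes x and y are 1-D tensors of equal length with values in [0, 1].
    """
    centers = th.linspace(0.0, 1.0, bins, device=x.device, dtype=x.dtype)
    # Gaussian weight of every sample for every bin center: shape (N, bins)
    wx = th.exp(-0.5 * ((x[:, None] - centers[None, :]) / sigma) ** 2)
    wy = th.exp(-0.5 * ((y[:, None] - centers[None, :]) / sigma) ** 2)
    # Normalize so each sample spreads a total weight of 1 along each axis
    wx = wx / (wx.sum(dim=1, keepdim=True) + 1e-10)
    wy = wy / (wy.sum(dim=1, keepdim=True) + 1e-10)
    joint = wx.t() @ wy  # accumulate over samples: shape (bins, bins)
    return joint / joint.sum()


class SoftMutualInfoLoss(th.nn.Module):
    def __init__(self, bins=20, sigma=0.05, eps=1e-10):
        super().__init__()
        self.bins = bins
        self.sigma = sigma
        self.eps = eps

    def forward(self, inputs, targets):
        pxy = soft_joint_hist(
            targets.reshape(-1), inputs.reshape(-1), self.bins, self.sigma
        )
        px = pxy.sum(dim=1, keepdim=True)  # marginal for x over y
        py = pxy.sum(dim=0, keepdim=True)  # marginal for y over x
        # eps keeps the log finite where the soft histogram is close to zero
        mi = th.sum(pxy * th.log((pxy + self.eps) / (px * py + self.eps)))
        return -mi  # minimizing the loss maximizes mutual information


# Usage: gradients now flow back to the model output
pred = th.rand(1, 1, 64, 64, requires_grad=True)
target = th.rand(1, 1, 64, 64)
loss = SoftMutualInfoLoss()(pred, target)
loss.backward()
```

The result is an approximation of the histogram-based value, and it depends on `sigma` and the number of bins. Memory grows with the number of pixels times the number of bins, so large images may need subsampling.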
Thank you!
Naomi