I saw the following bit of code on GitHub and was curious about the behavior of the return statement in the forward() function. Specifically, I'm wondering why something is subtracted, detached, and then added back. Is this a standard technique, or a hack to get the computation graph to behave a certain way?
import math
import torch
import torch.nn as nn
from torch.distributions.multivariate_normal import MultivariateNormal

class DoE(nn.Module):
    def __init__(self, dim, hidden, layers, pdf):
        super(DoE, self).__init__()
        # PDF and ConditionalPDF are defined elsewhere in the repo
        self.qY = PDF(dim, pdf)
        self.qY_X = ConditionalPDF(dim, hidden, layers, pdf)

    def forward(self, X, Y, XY_package):
        hY = self.qY(Y)         # presumably an estimate of H(Y)
        hY_X = self.qY_X(Y, X)  # presumably an estimate of H(Y|X)
        loss = hY + hY_X
        mi_loss = hY_X - hY
        # the return value equals mi_loss, but gradients appear to flow only through loss
        return (mi_loss - loss).detach() + loss
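
To make the question concrete, here is a small experiment with stand-in scalars (hY and hY_X below are just hypothetical values, not the real model outputs):

import torch

# stand-ins for the two entropy terms
hY = torch.tensor(2.0, requires_grad=True)
hY_X = torch.tensor(0.5, requires_grad=True)

loss = hY + hY_X        # 2.5
mi_loss = hY_X - hY     # -1.5

out = (mi_loss - loss).detach() + loss
print(out.item())       # -1.5, i.e. the value of mi_loss

out.backward()
print(hY.grad.item())   # 1.0, i.e. d(loss)/d(hY), not d(mi_loss)/d(hY) = -1.0
print(hY_X.grad.item()) # 1.0, i.e. d(loss)/d(hY_X)

So the returned tensor carries the value of mi_loss while backpropagating the gradients of loss. Is that the intended effect here?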