Suppose I have a very simple ensemble of models like this:
class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
ModelA does some simple preprocessing and feeds its output into ModelB. Normally, the gradient is computed during the backward pass starting from modelB and flowing back into modelA.
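For concreteness, modelA and modelB could be any small differentiable modules. The two classes below are just hypothetical stand-ins (my real models are more involved, but the question is the same):

import torch
import torch.nn as nn

class ModelA(nn.Module):
    # hypothetical stand-in for the preprocessing model
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 10)

    def forward(self, x):
        # squash the output into (0, 1) so the 0.05 threshold below is meaningful
        return torch.sigmoid(self.lin(x))

class ModelB(nn.Module):
    # hypothetical stand-in for the downstream model
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 1)

    def forward(self, x):
        return self.lin(x)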
I wonder what will happen if I do something like this:
def forward(self, x):
    x = self.modelA(x)
    x = torch.where(x > 0.05, 255, 0)
    x = self.modelB(x)
    return x
The torch.where() call here is supposed to act as a kind of threshold on the output of modelA. This code does not throw an error or a warning. However, since this operation is not differentiable, I am wondering what happens during the backward pass. Will the gradients in modelA be unchanged? Is the statement simply ignored, so that the loss just computes the gradient for the output of modelA?
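To make the question concrete, here is the kind of minimal check I have in mind, using the hypothetical stand-in modules from above (the .float() cast is only there because my stand-in Linear expects float input; torch.where with the scalar arguments 255 and 0 seems to give me an integer tensor):

modelA, modelB = ModelA(), ModelB()
x = torch.randn(4, 10)

a_out = modelA(x)                        # differentiable output of modelA
thr = torch.where(a_out > 0.05, 255, 0)  # the thresholding from forward() above
print(thr.dtype, thr.requires_grad)      # is the autograd graph cut at this point?

b_out = modelB(thr.float())              # cast only so the Linear stand-in accepts it
b_out.sum().backward()

print(modelB.lin.weight.grad)            # gradients of modelB's parameters
print(modelA.lin.weight.grad)            # does anything arrive here, or is it None?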
I think the general question is: are model gradients computed individually or in conjunction?
Hoping for some insights on this one.