Update parameters of a network with a loss that is not directly computed by that network


I have a network -> net1 whose output is used to compute a mask, M.

The input (In2) to another network -> net2 is multiplied by this mask and then passed into net2.

In2 = In2 * M

How do I train net1 using the output of net2? I cannot train them sequentially because there is a function between net1 and net2 that uses the output of net1 to compute the mask.

net1 -> mask_compute -> In2*M -> net2
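A minimal sketch of this setup (the network shapes and the 0.5 threshold are made up for illustration) shows where the graph gets cut: the hard comparison that builds the mask has no gradient, so net1 never receives one.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for net1 and net2.
net1 = nn.Linear(8, 1)
net2 = nn.Linear(8, 1)

x1 = torch.randn(4, 8)       # input to net1
In2 = torch.randn(4, 8)      # input to net2

out1 = net1(x1)              # (N, 1)
M = (out1 > 0.5).float()     # hard threshold -> 0/1 mask, no grad_fn
masked = In2 * M             # mask broadcast over In2
loss = net2(masked).sum()
loss.backward()

print(net2.weight.grad is None)  # False: net2 gets gradients
print(net1.weight.grad is None)  # True: the comparison cut the graph
```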

If mask_compute is not using PyTorch methods, it would detach the computation graph and Autograd won't be able to create the backward pass automatically. In that case you would have to implement a custom torch.autograd.Function for it and provide the forward and backward for these operations.
Here is an example.
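A minimal custom torch.autograd.Function, sketched here with a toy squaring operation rather than the actual mask, looks like this: you compute the result in forward and supply the derivative yourself in backward.

```python
import torch

class MySquare(torch.autograd.Function):
    # Toy stand-in: forward computes x**2, backward supplies the
    # derivative 2*x by hand, as a custom Function must.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x

x = torch.tensor([3.0], requires_grad=True)
y = MySquare.apply(x)
y.backward()
print(x.grad)  # tensor([6.])
```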

@ptrblck, Thank you for your response.

So the mask_compute function takes the output of net1 (size = Nx1) and creates a mask of size (Nx1), whose elements are 1 where the corresponding element of net1's output is above a threshold and 0 where it is below.

I am not sure how the forward and backward passes can be implemented for this.
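One common workaround for a hard threshold like this, offered here as a suggestion rather than something from the thread, is a straight-through estimator: keep the 0/1 mask in forward, but define backward to pass the gradient through (optionally clipped), so net1 still gets a learning signal. The threshold value below is arbitrary.

```python
import torch

class STEMask(torch.autograd.Function):
    # Straight-through estimator: hard 0/1 threshold in forward,
    # pass-through gradient in backward.
    @staticmethod
    def forward(ctx, scores, threshold):
        return (scores > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient through unchanged; a clipped or scaled
        # surrogate may work better depending on the problem.
        return grad_output, None

scores = torch.tensor([[0.2], [0.9]], requires_grad=True)
mask = STEMask.apply(scores, 0.5)   # 0 below threshold, 1 above
mask.sum().backward()
print(scores.grad)                  # all ones with this backward
```

With this in place, In2 * mask feeds net2 as before, but net2's loss now backpropagates through the mask into net1.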