Multiple forward passes before backward, where backward depends on all forward calls

Well, I’m not a PyTorch expert, but I would recommend wrapping both forward passes in an nn.Module like this:

=================defining model=================
class My_model(nn.Module):
    def __init__(self):
        super(My_model, self).__init__()
        self.model = your_original_model  # the network you already have

    def forward(self, I, T):
        y_1 = self.model(I)
        y_2 = self.model(T * I)
        T_ = estimate_transformation(y_1, y_2)
        return T_, y_1, y_2
===========training stage========================
T_, y_1, y_2 = model(I, T)
# mvc_loss = ||T*y_1 - y_2||^2 and rot_loss = ||T - T_||^2, both built from
# torch loss/functions so autograd can track them
loss = my_custom_loss(T, T_, y_1, y_2)
optimizer.zero_grad()
loss.backward()
optimizer.step()

About the loss, I would recommend reading
[Solved] What is the correct way to implement custom loss function? - #4 by Hengck

You can either build a custom loss function (see the link above; a rough sketch follows)
or include it inside the bigger nn.Module, as in the second version below.
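
A minimal sketch of the first option, assuming T acts on y_1 by elementwise multiplication (the same way the input is transformed as T*I in the forward pass) and that both terms are plain squared L2 norms:

import torch

def my_custom_loss(T, T_, y_1, y_2):
    # mvc_loss = ||T*y_1 - y_2||^2 : the second output should match the transformed first output
    mvc_loss = torch.sum((T * y_1 - y_2) ** 2)
    # rot_loss = ||T - T_||^2 : the estimated transformation should match the applied one
    rot_loss = torch.sum((T - T_) ** 2)
    return mvc_loss + rot_loss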

=================defining model=================
class My_model(nn.Module):
    def __init__(self):
        super(My_model, self).__init__()
        self.model = your_original_model

    def forward(self, I, T):
        y_1 = self.model(I)
        y_2 = self.model(T * I)
        T_ = estimate_transformation(y_1, y_2)
        mvc_loss = torch.sum((T * y_1 - y_2) ** 2)   # ||T*y_1 - y_2||^2
        rot_loss = torch.sum((T - T_) ** 2)          # ||T - T_||^2
        loss = mvc_loss + rot_loss
        return loss, T_, y_1, y_2
===========training stage========================
loss, T_, y_1, y_2 = model(I, T)
optimizer.zero_grad()
loss = loss.sum()  # only needed if using DataParallel (one loss value per GPU)
loss.backward()
optimizer.step()
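
The .sum() above is only there because nn.DataParallel gathers one loss value per GPU replica, so the loss comes back as a small vector rather than a scalar. A rough sketch of that setup (model, data and optimizer names are placeholders):

import torch
import torch.nn as nn

model = nn.DataParallel(My_model())   # replicate the wrapper across the visible GPUs
loss, T_, y_1, y_2 = model(I, T)      # loss has one entry per replica
loss = loss.sum()                     # reduce to a single scalar before backward
optimizer.zero_grad()
loss.backward()
optimizer.step()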

I don’t really know how performing operations on the loss outside of an autograd Function or nn.Module would behave. Just remember that every operation you do must use autograd-compatible functions (that is, native differentiable torch functions).
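
For example (purely illustrative, not part of the original question), a squared norm written with torch ops keeps the graph alive, while detaching or converting to NumPy in the middle cuts it:

import torch

x = torch.randn(3, requires_grad=True)

# Fine: torch.norm is differentiable, so backward reaches x
loss_ok = torch.norm(x) ** 2
loss_ok.backward()
print(x.grad)                          # gradients are populated

# Broken: leaving autograd (detach / numpy / Python floats) cuts the graph
loss_broken = (x.detach() ** 2).sum()
# loss_broken.backward()  would fail: the result no longer requires grad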
