Multiple forward before backward, where backward depends on all forward calls

I’d like to implement something similar to KeypointNet’s multi-view consistency and rotation losses. The gist of it is as follows:

Let M be our model, I an image and T a transformation matrix. For the multi-view consistency loss, one has an affine transformation matrix T and would like to impose equivariance on the model with respect to T: M(TI) == TM(I)

Similarly, for the rotation loss, given M(T*I) and M(I), one would like to estimate T from the two outputs as T_ and enforce T == T_.

The pseudo code I’ve implemented looks as follows:

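y_1 = M(I)                                   # prediction on the original image
y_2 = M(T*I)                                 # prediction on the transformed image
mvc_loss = (T*y_1 - y_2).pow(2).sum()        # ||T*y_1 - y_2||^2
T_ = estimate_transformation(y_1, y_2)       # estimate T from both predictions
rot_loss = (T - T_).pow(2).sum()             # ||T - T_||^2
loss = mvc_loss + rot_loss
loss.backward()                              # single backward through both forward passes
optim.step()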
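For reference, here is a minimal runnable sketch of that training step. Everything in it is a toy stand-in and an assumption, not KeypointNet itself: the "image" is a set of 2D points (one per row), T is a 2x2 rotation applied as `I @ T.T`, M is a small MLP, and `estimate_transformation` is a hypothetical least-squares estimator that solves the normal equations with a small ridge term `eps`.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for M: maps each 2-D point to a 2-D "keypoint".
M = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
optim = torch.optim.SGD(M.parameters(), lr=1e-2)


def rotation(theta: float) -> torch.Tensor:
    """2x2 rotation matrix, standing in for the transformation T."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s], [s, c]])


def estimate_transformation(y_1, y_2, eps=1e-6):
    """Hypothetical estimator: least-squares T_ with y_1 @ T_.T ~= y_2.

    Solves (y_1^T y_1 + eps*I) X = y_1^T y_2 and returns T_ = X.T.
    The eps term keeps the 2x2 system invertible if y_1 degenerates.
    """
    A = y_1.T @ y_1 + eps * torch.eye(y_1.shape[1])
    X = torch.linalg.solve(A, y_1.T @ y_2)
    return X.T


I = torch.randn(32, 2)            # toy "image": 32 points, one per row
T = rotation(0.3)

optim.zero_grad()
y_1 = M(I)                        # first forward pass
y_2 = M(I @ T.T)                  # second forward pass, on the transformed input
mvc_loss = (y_1 @ T.T - y_2).pow(2).sum()
T_ = estimate_transformation(y_1, y_2)
rot_loss = (T - T_).pow(2).sum()
loss = mvc_loss + rot_loss
loss.backward()                   # one backward covers both forward passes
optim.step()
```

If NaNs still appear in a setup like this, wrapping the step in `torch.autograd.set_detect_anomaly(True)` makes autograd report the backward operation that produced them; an ill-conditioned solve inside the transformation estimator is one possible culprit worth checking.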

I’ve seen the threads “Multiple forward before backward call”, “How to implement accumulated gradient in pytorch (i.e. iter_size in caffe prototxt)” and “How to implement accumulated gradient?”, which are about calling forward multiple times before backpropagating. However, they all call backward() directly after each forward call. I need to call forward twice before I can call backward, since the loss depends on both forward calls. I’ve been getting NaNs and am wondering whether I’m handling this correctly.
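Whether two forward passes followed by one backward() are handled correctly can be checked directly. Autograd records every forward call, so a single backward() on a loss that combines both passes sends gradient through both of them. The sketch below uses a toy `nn.Linear` in place of M (an assumption for illustration) and compares that with calling backward() once per pass, which accumulates into `.grad`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
M = nn.Linear(3, 3)
x1, x2 = torch.randn(4, 3), torch.randn(4, 3)

# (a) Two forward passes, then one backward on the combined loss.
M.zero_grad()
loss = M(x1).pow(2).sum() + M(x2).pow(2).sum()
loss.backward()
joint = [p.grad.clone() for p in M.parameters()]

# (b) One backward per forward pass; gradients accumulate in .grad.
M.zero_grad()
M(x1).pow(2).sum().backward()
M(x2).pow(2).sum().backward()
separate = [p.grad.clone() for p in M.parameters()]

# Both give the same result: the sum of the two passes' contributions.
for g_joint, g_sep in zip(joint, separate):
    assert torch.allclose(g_joint, g_sep)
```

So the pattern in the question is valid on its own; NaNs would have to come from a specific operation in the loss or the estimator rather than from the double forward.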

Can anybody confirm that what I’m doing is correct? If it isn’t, could someone indicate what the correct way would be?

Thanks!

Anyone have an answer?

What you want to do is not difficult, though.
You just have to create an additional nn.Module class, like:

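import torch.nn as nn

model = ur_model  # your nn.Module

class double_forward(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.single = model

    def forward(self, I, T):
        y_1 = self.single(I)
        y_2 = self.single(T*I)
        # and so on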

In the end you can have as many outputs as you want with your custom loss. You can run the module as many times as you want (but define it only once), paying attention to the order of operations so as not to build a wrong graph.
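The "define once, run many times" point can be sketched as follows. `MultiForward` is a hypothetical generalisation of the `double_forward` idea to any number of transforms, and treating each T as a 2x2 matrix acting on the rows of I is an assumption for the toy example.

```python
import torch
import torch.nn as nn


class MultiForward(nn.Module):
    """Hypothetical wrapper: one forward pass per transform, shared weights."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.single = model  # defined (and registered) only once

    def forward(self, I, Ts):
        # Every pass reuses self.single, so all outputs share its parameters.
        # Assumption: each T is a 2x2 matrix acting on the rows of I.
        return [self.single(I @ T.T) for T in Ts]


base = nn.Linear(2, 2)
net = MultiForward(base)
I = torch.randn(5, 2)
Ts = [torch.eye(2), torch.tensor([[0.0, -1.0], [1.0, 0.0]]), 2 * torch.eye(2)]
outs = net(I, Ts)
```

The wrapper registers no parameters of its own, so an optimizer built from `net.parameters()` updates exactly the weights of the single wrapped model.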

Thank you for your response. So the code snippet which I posted is correct?

Well, I’m not a PyTorch expert, but I would recommend wrapping both forward passes in an nn.Module like this:

=================defining model=================
class My_model(nn.Module):
    def __init__(self):
        super(My_model, self).__init__()
        self.model = your_original_model

    def forward(self, I, T):
        y_1 = self.model(I)
        y_2 = self.model(T*I)
        T_ = estimate_transformation(y_1, y_2)
        return T_, y_1, y_2
===========training stage========================
def my_custom_loss(T, T_, y_1, y_2):
    # Use torch functions so autograd can differentiate the loss
    mvc_loss = (T*y_1 - y_2).pow(2).sum()   # ||T*y_1 - y_2||^2
    rot_loss = (T - T_).pow(2).sum()        # ||T - T_||^2
    return torch.add(mvc_loss, rot_loss)

model = My_model()
T_, y_1, y_2 = model(I, T)
loss = my_custom_loss(T, T_, y_1, y_2)
optimizer.zero_grad()
loss.backward()
optimizer.step()

About the loss, I would recommend reading
[Solved] What is the correct way to implement custom loss function??

You can either build a custom loss function (see the link above)
or include the loss inside the bigger nn.Module:
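As a sketch of the first option, the loss from the question can be packaged as its own `nn.Module`. The class name `KeypointConsistencyLoss` is hypothetical, and the convention that T acts on the rows of the keypoint tensors (`y_1 @ T.T`) is an assumption carried over from the toy examples, not something fixed by the thread.

```python
import torch
import torch.nn as nn


class KeypointConsistencyLoss(nn.Module):
    """Hypothetical loss module: mvc_loss + rot_loss from the question."""

    def forward(self, T, T_, y_1, y_2):
        # ||T*y_1 - y_2||^2, assuming T acts on the rows of y_1
        mvc_loss = (y_1 @ T.T - y_2).pow(2).sum()
        # ||T - T_||^2 (squared Frobenius norm)
        rot_loss = (T - T_).pow(2).sum()
        return mvc_loss + rot_loss
```

Usage mirrors the built-in losses: `criterion = KeypointConsistencyLoss()` once, then `loss = criterion(T, T_, y_1, y_2)` in the training loop. Because it only uses native torch ops, autograd differentiates it without any extra work.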

=================defining model=================
class My_model(nn.Module):
    def __init__(self):
        super(My_model, self).__init__()
        self.model = your_original_model

    def forward(self, I, T):
        y_1 = self.model(I)
        y_2 = self.model(T*I)
        T_ = estimate_transformation(y_1, y_2)
        mvc_loss = (T*y_1 - y_2).pow(2).sum()   # ||T*y_1 - y_2||^2
        rot_loss = (T - T_).pow(2).sum()        # ||T - T_||^2
        loss = torch.add(mvc_loss, rot_loss)
        return loss, T_, y_1, y_2               # loss plus any other outputs you need
===========training stage========================
model = My_model()
loss, T_, y_1, y_2 = model(I, T)
optimizer.zero_grad()
loss = loss.sum()  # with nn.DataParallel, one loss per GPU is gathered; reduce to a scalar
loss.backward()
optimizer.step()

I don’t really know how performing operations on the loss outside an autograd Function or nn.Module class would affect things. Remember that any operation you do must use autograd-compatible functions (that is, native differentiable torch functions).
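The autograd-compatibility point can be illustrated with a small sketch. Computing the same value through native torch ops keeps it attached to the graph, while a NumPy round trip (which requires `.detach()` first) produces a tensor with no `grad_fn`, so no gradient can reach the parameters through it. The tensor `w` here is a toy stand-in for model parameters.

```python
import numpy as np
import torch

w = torch.randn(3, requires_grad=True)

# Native torch ops: the result stays attached to the autograd graph.
good = (w * 2).pow(2).sum()
assert good.requires_grad and good.grad_fn is not None

# NumPy round trip: .numpy() refuses tensors that require grad, so the value
# has to be detached first, which cuts the graph. Same number, no gradient.
bad = torch.from_numpy(np.square(w.detach().numpy() * 2)).sum()
assert not bad.requires_grad and bad.grad_fn is None
```

Calling `bad.backward()` would raise an error, whereas `good.backward()` fills `w.grad`; a loss silently built the second way would simply never train the model.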


Why would you recommend wrapping it in a custom class? Just for readability? As far as I can see, it shouldn’t affect the behaviour otherwise.
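One way to check this on a single device is to compute the same loss with and without a wrapper and compare gradients. In the sketch below, the toy `nn.Linear`, the 90-degree rotation T and the `Wrapper` class are all hypothetical stand-ins. The two variants produce identical gradients, which is consistent with the wrapper being organizational in that setting. (With `nn.DataParallel`, `forward()` is what gets replicated across GPUs, so what is placed inside it does change where the work runs.)

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
M = nn.Linear(2, 2)
M_copy = copy.deepcopy(M)  # identical weights for the wrapped variant


class Wrapper(nn.Module):
    """Hypothetical wrapper running both forward passes inside forward()."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, I, T):
        return self.model(I), self.model(I @ T.T)


I = torch.randn(8, 2)
T = torch.tensor([[0.0, -1.0], [1.0, 0.0]])  # toy transform on rows of I

# Plain: two calls to M directly in the training loop.
y_1, y_2 = M(I), M(I @ T.T)
(y_1 @ T.T - y_2).pow(2).sum().backward()

# Wrapped: the same two calls, inside a module's forward().
w_1, w_2 = Wrapper(M_copy)(I, T)
(w_1 @ T.T - w_2).pow(2).sum().backward()

for p, q in zip(M.parameters(), M_copy.parameters()):
    assert torch.allclose(p.grad, q.grad)
```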

I have the same question, have you managed to sort it out?

No, unfortunately not. I’ve since abandoned the project. Let me know if you figure it out.