Help with basic automatic differentiation usage in classes

Hello everyone,
I am starting to build a physical simulation framework that will need optimisation in the future.
It is a port of a set of MATLAB routines that do mostly linear algebra. I have already ported one part to PyTorch, directly replacing the NumPy functionality. Being able to send workloads to the GPU is beyond great (a 100x speed-up in our case). But now I am refactoring to be able to do gradient descent on the objects and classes defined in the framework.
I would like some help understanding the best practices for doing so. I am using a lot of code from Kornia and Kaolin, as my problem is describable in terms of 3D meshes.
I have some questions:

  • How to store class state that is going to be differentiated afterwards. A minimal example would be something like this:
import torch
from torch import Tensor

class Foo:
    def __init__(self, x: Tensor):
        self.x = x
    def transform(self, p: Tensor):
        self.x = self.x + p
    def loss(self):
        return self.x.norm() ** 2

In this class the state is x. I transform this state using built-in methods such as transform, and would like to be able to compute a loss on its new, transformed state. So I do this:

test = Foo(Tensor([1., 3., 5.]))
p = torch.tensor([-3., 1., 4.], requires_grad=True)

And then a minimal training loop:

# training loop
learning_rate = 1e-5
for t in range(10):
    test.transform(p)
    loss = test.loss()
    loss.backward()
    print(f'loss: {loss}')
    with torch.no_grad():
        print(f' p: {p},\n grad: {p.grad}')
        p -= learning_rate * p.grad
        if p.grad.norm() < 1e-4:
            break
        p.grad.zero_()

But this is not working as expected, probably because I am overwriting the state x with the modified x + p at each step. In my case the transformations are a little more complicated (rotations, scalings and translations).

  • How would you proceed? (an nn.Module extension?)
  • Are there tutorials on this?
    Thank you =)

Hi, welcome!

In general, to avoid this issue, we make (most) nn.Module instances stateless.
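For example, a stateless version of Foo could keep x as a buffer and return new tensors instead of overwriting its own state. A minimal sketch along those lines (the name StatelessFoo and the split of transform/loss are just illustrative):

import torch
from torch import Tensor, nn

class StatelessFoo(nn.Module):
    def __init__(self, x: Tensor):
        super().__init__()
        # store the initial state as a non-trainable buffer
        self.register_buffer('x', x)

    def transform(self, p: Tensor) -> Tensor:
        # return the transformed state instead of mutating self.x
        return self.x + p

    def loss(self, x: Tensor) -> Tensor:
        return x.norm() ** 2

The training loop then threads the state through explicitly: x_new = module.transform(p); loss = module.loss(x_new). Since self.x is never overwritten, no graph accumulates across iterations.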

I think the problem you're seeing is that if you run this for loop for a while, it will eat all your memory.
This is because every new x you compute is given by x = transform(x, p).
At the very first iteration, x does not require gradients and p does.
So the backward sets the gradients of p and you're happy.
But at the next iteration, the x that resulted from the op above does require gradients. So at the second iteration, when you call backward, you will backward through both the second and the first iterations.
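
You can check this directly on your example (a quick sanity check, reusing the Foo and p defined above):

test = Foo(Tensor([1., 3., 5.]))
print(test.x.requires_grad)  # False: the initial state is independent of p
test.transform(p)
print(test.x.requires_grad)  # True: x is now part of p's graph
test.transform(p)
# x now depends on p through two transform calls, so a backward()
# from the loss would traverse both iterations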

To prevent this, you want to break the link (in the autograd sense) between the x of one iteration and the next. You can do so with self.x = self.x.detach().

If you want to make sure that the graph considered in the backward pass is what you expect, you can use a tool like torchviz to plot it.
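
For example (assuming the torchviz package is installed, e.g. pip install torchviz, along with the graphviz binaries it relies on):

from torchviz import make_dot

loss = test.loss()
# writes loss_graph.png with one node per autograd op; 'p' is labelled by name
make_dot(loss, params={'p': p}).render('loss_graph', format='png')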

Hope this helps!

It helps a lot!
You are completely right; I will take a look.

My current solution is to modify the class Foo as follows:

class Foo:
    def __init__(self, x: Tensor):
        self.x = x
    def transform(self, p: Tensor):
        # detach the stored state so backward() only sees
        # the graph built in the current iteration
        self.x = self.x.detach()
        self.x = self.x + p
    def loss(self):
        return self.x.norm() ** 2
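
With this change, each call to transform starts from a detached leaf, so every backward() traverses only the current iteration's graph and memory no longer grows across the training loop.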