Parametrization moves the parameter to device, but not the gradient?

I am trying to use the new parametrization tool described here: torch.nn.utils.parametrize.register_parametrization — PyTorch 1.12 documentation

More specifically, I’m trying to use the orthogonal parametrization: torch.nn.utils.parametrizations.orthogonal — PyTorch 1.12 documentation

I have created the following minimal example to show the problem:

import torch
import torch.nn as nn

class network_test(nn.Module):
    def __init__(self, low_dim, high_dim):
        super().__init__()

        self.lin = torch.nn.Linear(low_dim,high_dim)
        self.ortho = torch.nn.utils.parametrizations.orthogonal(self.lin)
        self.register_buffer("K", self.lin.weight)

    def forward(self,x):
        y = x @ self.K.T
        return y

low = 3
high = 6
batch_size = 10
model = network_test(low,high)
device = 'cuda:0'

model.to(device)
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

x = torch.randn(batch_size,low,device=device)
y = model(x)
loss = y.sum()
loss.backward()
optimizer.step()

This code crashes because the gradient of the parameter is somehow not on CUDA, while the parameter itself is. How is this happening, and what should I be doing differently?

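When you call self.register_buffer("K", self.lin.weight), you store in the buffer the tensor that self.lin.weight returns at that moment. Once the orthogonal parametrization is registered, self.lin.weight is no longer a stored parameter: every access recomputes it from the underlying parameter self.lin.parametrizations.weight.original. So K is a one-off result, computed in __init__ while everything is still on the CPU, and it keeps the autograd graph that produced it. model.to(device) does move the buffer to CUDA (and the underlying parameter too), but the gradient flowing back through K still goes through that CPU graph, so it reaches the parameter as a CPU tensor. K is also a stale snapshot: it is never recomputed after optimizer.step(). The docs of register_buffer (Module — PyTorch 1.12 documentation) describe the tensor argument as follows: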

tensor (Tensor or None) – buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored. If None, the buffer is not included in the module’s state_dict.
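A small CPU-only sketch of the mechanism, assuming PyTorch 1.12 or later. It shows that each access to a parametrized weight is a fresh, non-leaf computation, and that keeping one of those results (which is what register_buffer("K", ...) does) gives a snapshot that does not follow later changes. The in-place nudge to original[1, 0] is an arbitrary stand-in for an optimizer update, chosen only for illustration:

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

lin = orthogonal(nn.Linear(3, 6))
original = lin.parametrizations.weight.original  # the actual nn.Parameter

# Every access to lin.weight runs the parametrization again.
w_a = lin.weight
w_b = lin.weight
print(w_a is w_b)   # False: two separate computations
print(w_a.is_leaf)  # False: it is the output of a graph rooted at `original`

# register_buffer("K", lin.weight) would keep one such result.
snapshot = lin.weight
with torch.no_grad():
    original[1, 0] += 0.5  # illustrative stand-in for an optimizer update
print(torch.allclose(snapshot, lin.weight))  # False: the snapshot did not follow
```

The same reasoning applies to devices: the graph stored with the snapshot was recorded on whatever device the parameter lived on at construction time.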

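To make your code work, don't route the computation through the buffer K. Use the parametrized module directly: call self.lin(x) (self.ortho is the very same module, since orthogonal registers the parametrization on the module you pass in and returns it), or read self.lin.weight inside forward, which recomputes the orthogonal weight from the current parameter on every call. Note that self.lin(x) also adds the bias; x @ self.lin.weight.T matches the bias-free product you wrote.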
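A minimal sketch of that fix, assuming PyTorch 1.12 or later. NetworkFixed is a hypothetical name, and the device falls back to the CPU when no GPU is available so the sketch runs anywhere. After a few Adam steps it checks that the gradient lives on the same device as its parameter and that the 6x3 weight still has orthonormal columns:

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal


class NetworkFixed(nn.Module):
    """Hypothetical fixed version of network_test: no buffer, weight read in forward."""

    def __init__(self, low_dim, high_dim):
        super().__init__()
        # orthogonal() registers the parametrization in place and returns the same module
        self.lin = orthogonal(nn.Linear(low_dim, high_dim))

    def forward(self, x):
        # self.lin(x) reads self.lin.weight, which is recomputed from the current
        # underlying parameter on every call. Use x @ self.lin.weight.T instead
        # for the bias-free product from the question.
        return self.lin(x)


device = "cuda:0" if torch.cuda.is_available() else "cpu"
low, high, batch_size = 3, 6, 10

model = NetworkFixed(low, high).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(3):  # a fresh graph is built on every iteration
    x = torch.randn(batch_size, low, device=device)
    loss = model(x).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

original = model.lin.parametrizations.weight.original
W = model.lin.weight.detach()
print(original.grad.device == original.device)  # True
print(torch.allclose(W.T @ W, torch.eye(low, device=device), atol=1e-4))  # True
```

Reading the weight inside forward, rather than storing it once in __init__, is what keeps the parameter, its gradient and the optimizer state on the same device.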