I am trying to use the new parametrization tool described in the PyTorch 1.12 documentation for torch.nn.utils.parametrize.register_parametrization.
More specifically, I'm trying to use the orthogonal parametrization, torch.nn.utils.parametrizations.orthogonal (also in the PyTorch 1.12 documentation).
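For reference, here is a small sketch of what orthogonal does to an nn.Linear, based on its documented behavior. The 3 and 6 sizes mirror the low/high values in the example that follows, and atol=1e-5 is an arbitrary tolerance I picked. The detail that matters most here is the last line: the parametrized weight is recomputed on every access rather than stored.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrizations, parametrize

# nn.Linear(3, 6) stores its weight as a (6, 3) matrix (out_features x in_features).
lin = nn.Linear(3, 6)
# orthogonal() registers the parametrization on lin in place and returns lin itself.
lin = parametrizations.orthogonal(lin)

print(parametrize.is_parametrized(lin, "weight"))  # True

W = lin.weight
# For a tall (6 x 3) weight the columns are orthonormal, so W^T W = I_3.
print(W.shape, torch.allclose(W.T @ W, torch.eye(3), atol=1e-5))

# Each access to lin.weight re-runs the parametrization and returns a new tensor.
print(lin.weight is lin.weight)  # False
```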
I have created the following minimal example to show the problem:
```python
import torch
import torch.nn as nn


class network_test(nn.Module):
    def __init__(self, low_dim, high_dim):
        super().__init__()
        self.lin = torch.nn.Linear(low_dim, high_dim)
        self.ortho = torch.nn.utils.parametrizations.orthogonal(self.lin)
        self.register_buffer("K", self.lin.weight)

    def forward(self, x):
        y = x @ self.K.T
        return y


low = 3
high = 6
batch_size = 10
model = network_test(low, high)
device = 'cuda:0'
model.to(device)
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
x = torch.randn(batch_size, low, device=device)
y = model(x)
loss = y.sum()
loss.backward()
optimizer.step()
```
This code crashes because the gradient of the parameter is somehow not on CUDA, while the parameter itself is. I'm not sure how this is happening, or what I should be doing differently.
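A likely cause (my reading of the behavior, not verified against the PyTorch internals): on a parametrized module, self.lin.weight is recomputed from the underlying parameter on every access, so register_buffer("K", self.lin.weight) stores a single CPU snapshot taken in __init__, together with the autograd graph that produced it. model.to(device) moves the parameter's data to the GPU in place and copies the buffer, but the graph recorded at construction time still runs on the CPU, so the gradient it delivers to the parameter is a CPU tensor. The snapshot would also never reflect optimizer.step(), and a second backward() through it would fail because that graph is freed after the first one. The sketch below assumes the weight only needs to be read inside forward: it drops the buffer and reads self.lin.weight there instead. The class name NetworkTest and the CPU fallback are illustrative choices, not part of the original code.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrizations


class NetworkTest(nn.Module):
    def __init__(self, low_dim, high_dim):
        super().__init__()
        # orthogonal() parametrizes the Linear in place and returns it,
        # so a separate self.ortho reference is not needed.
        self.lin = parametrizations.orthogonal(nn.Linear(low_dim, high_dim))

    def forward(self, x):
        # Read the weight here, not in __init__: each access recomputes the
        # orthogonal matrix from the current parameter, on its current device.
        K = self.lin.weight
        return x @ K.T


low, high, batch_size = 3, 6, 10
# Fall back to the CPU so the sketch also runs without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = NetworkTest(low, high).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(2):  # two steps, to check the weight is rebuilt each iteration
    x = torch.randn(batch_size, low, device=device)
    loss = model(x).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

W = model.lin.weight
print(W.device, torch.allclose(W.T @ W, torch.eye(low, device=device), atol=1e-5))
```

If the same weight is needed several times within one forward pass, wrapping that computation in the torch.nn.utils.parametrize.cached() context manager should avoid recomputing it on each access, while still building a fresh graph on every call.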