Hi everybody,

I've trained a model on 3D data and now I'd like to add a learnable rotation matrix and apply it to each new batch right after it is loaded, while keeping the model parameters frozen.

This is my `Rotation_Matrix` class:
```python
class Rotation_Matrix(nn.Module):
    def __init__(self):
        super(Rotation_Matrix, self).__init__()
        self.a = nn.Parameter(torch.tensor(1.))  # init value of Rx cosine
        self.b = nn.Parameter(torch.tensor(1.))  # init value of Ry cosine
        self.c = nn.Parameter(torch.tensor(1.))  # init value of Rz cosine

    def get_rotation_matrix(self):
        Rx = torch.tensor([[1., 0., 0.],
                           [0., self.a, -torch.sqrt(1 - torch.pow(self.a, 2))],
                           [0., torch.sqrt(1 - torch.pow(self.a, 2)), self.a]])
        Ry = torch.tensor([[self.b, 0., torch.sqrt(1 - torch.pow(self.b, 2))],
                           [0., 1., 0.],
                           [-torch.sqrt(1 - torch.pow(self.b, 2)), 0., self.b]])
        Rz = torch.tensor([[self.c, -torch.sqrt(1 - torch.pow(self.c, 2)), 0.],
                           [torch.sqrt(1 - torch.pow(self.c, 2)), self.c, 0.],
                           [0., 0., 1.]])
        return torch.mm(Rx, torch.mm(Ry, Rz))

    def forward(self, x):
        self.matrix = self.get_rotation_matrix()
        return torch.mm(x, self.matrix)
```
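While reading related threads I started to suspect that building the matrices with `torch.tensor([...])` copies the parameter values into new leaf tensors, which would cut the graph back to `self.a`, `self.b` and `self.c`. Here is a reworked version I'm considering that assembles the rows with `torch.stack` instead (the 0.9 initial cosines are my choice, just to stay away from `a = 1`, where `sqrt(1 - a**2)` has an infinite slope):

```python
import torch
import torch.nn as nn


class RotationMatrixStack(nn.Module):
    """Same three-cosine idea, but the matrices are assembled with
    torch.stack so the parameters stay connected to the autograd graph."""

    def __init__(self):
        super().__init__()
        # start slightly below 1: the gradient of sqrt(1 - a**2) diverges at a = 1
        self.a = nn.Parameter(torch.tensor(0.9))
        self.b = nn.Parameter(torch.tensor(0.9))
        self.c = nn.Parameter(torch.tensor(0.9))

    def get_rotation_matrix(self):
        one, zero = torch.ones(()), torch.zeros(())
        sa = torch.sqrt(1 - self.a ** 2)  # sine recovered from the cosine
        sb = torch.sqrt(1 - self.b ** 2)
        sc = torch.sqrt(1 - self.c ** 2)
        Rx = torch.stack([torch.stack([one, zero, zero]),
                          torch.stack([zero, self.a, -sa]),
                          torch.stack([zero, sa, self.a])])
        Ry = torch.stack([torch.stack([self.b, zero, sb]),
                          torch.stack([zero, one, zero]),
                          torch.stack([-sb, zero, self.b])])
        Rz = torch.stack([torch.stack([self.c, -sc, zero]),
                          torch.stack([sc, self.c, zero]),
                          torch.stack([zero, zero, one])])
        return Rx @ Ry @ Rz

    def forward(self, x):
        # x: (batch, 3) point coordinates
        return x @ self.get_rotation_matrix()
```

With this version the result of `get_rotation_matrix()` carries a `grad_fn`, so I would expect gradients to reach the three parameters.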
Then come the initialisation and training of the rotation matrix. I haven't taken care of forcing the rotation parameters to stay in the [-1, 1] interval just yet.
```python
# load pre-trained model (load_state_dict modifies the model in place,
# so its return value should not be assigned back to model)
model.load_state_dict(...)
model.eval()

rotation_matrix = Rotation_Matrix()
optimizer = torch.optim.Adam(rotation_matrix.parameters(), lr=l_rate)
rotation_matrix.train()

for i in range(n_epochs):
    for j, x in enumerate(dataloader):
        x = rotation_matrix(x)
        x = model(x)
        loss = loss_fun(x)
        optimizer.zero_grad()
        rotation_matrix.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
```
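On the [-1, 1] issue I mentioned above: the plan I have in mind (untested) is to drop the cosine parameters entirely and learn three Euler angles instead, taking `cos`/`sin` inside `forward`, so the matrix entries are bounded by construction and no clamping is needed. A rough sketch (class and attribute names are mine):

```python
import torch
import torch.nn as nn


class RotationByAngles(nn.Module):
    """Hypothetical alternative: learn three Euler angles directly.
    cos/sin are always in [-1, 1], so no constraint handling is needed."""

    def __init__(self):
        super().__init__()
        self.angles = nn.Parameter(torch.zeros(3))  # rotations about x, y, z

    def forward(self, x):
        ax, ay, az = self.angles  # unpacking keeps the autograd graph
        one, zero = torch.ones(()), torch.zeros(())
        cx, sx = torch.cos(ax), torch.sin(ax)
        cy, sy = torch.cos(ay), torch.sin(ay)
        cz, sz = torch.cos(az), torch.sin(az)
        Rx = torch.stack([torch.stack([one, zero, zero]),
                          torch.stack([zero, cx, -sx]),
                          torch.stack([zero, sx, cx])])
        Ry = torch.stack([torch.stack([cy, zero, sy]),
                          torch.stack([zero, one, zero]),
                          torch.stack([-sy, zero, cy])])
        Rz = torch.stack([torch.stack([cz, -sz, zero]),
                          torch.stack([sz, cz, zero]),
                          torch.stack([zero, zero, one])])
        return x @ (Rx @ Ry @ Rz)
```

At `angles = 0` this is the identity rotation, which also seems like a saner starting point than cosines initialised to 1.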
Printing the `grad` of any parameter from `rotation_matrix` gives `None`. I read some related topics but didn't find a solution. I tried calling `retain_grad()` on the parameters before `loss.backward()`, and I played with the `autograd.grad()` function, but with no results. What am I missing here? In addition, should I wrap `x = model(x)` in `with torch.no_grad():`, or would that prevent backpropagation to the rotation parameters?
Thank you for all answers,
MS.