Hi, I want to create a Module to solve the rotation matrix by MSE loss.
In 3D space, the rotation matrix is determined by 6 variables: roll, yaw, pitch, dx, dy, dz.
So I want to set those 6 variables as nn.Parameters of my module, and use them to construct a rotation matrix. The forward pass just does a torch. matmul (rot_mat, coordinates). But I found that the gradient couldn’t be passed back to the parameters.
Anyone can help me with it?
Following is my code:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from torch import sin, cos
from torch import nn
model_points = np.array([
[34.375, 41.25, 16.875], # left eye
[90.625, 41.25, 16.875], # right eye
[62.5, 62.5, 0], # nose tip
[43.75, 81.25, 15.625], # left mouth
[81.25, 81.25, 15.625], # right mouth
[62.5, 103.75, -8.125] # chin
]).T
class RotationMat(nn.Module):
def __init__(self):
super(RotationMat, self).__init__()
self.roll = nn.Parameter(torch.FloatTensor([0.1]))
self.yaw = nn.Parameter(torch.FloatTensor([0.2]))
self.pitch = nn.Parameter(torch.FloatTensor([0.1]))
self.dx = nn.Parameter(torch.FloatTensor([0.2]))
self.dy = nn.Parameter(torch.FloatTensor([0.1]))
self.dz = nn.Parameter(torch.FloatTensor([0.2]))
self.mat = torch.Tensor([[cos(self.roll) * cos(self.yaw), -cos(self.yaw) * sin(self.roll), sin(self.yaw), self.dx],
[cos(self.pitch) * sin(self.roll) + cos(self.roll) * sin(self.pitch) * sin(self.yaw),
cos(self.pitch) * cos(self.roll) - sin(self.pitch) * sin(self.roll) * sin(self.yaw),
-cos(self.yaw) * sin(self.pitch), self.dy],
[sin(self.pitch) * sin(self.roll) - cos(self.pitch) * cos(self.roll) * sin(self.yaw),
cos(self.roll) * sin(self.pitch) + cos(self.pitch) * sin(self.roll) * sin(self.yaw),
cos(self.pitch) * cos(self.yaw), self.dz]]).requires_grad_()
def forward(self, x):
out = torch.matmul(self.mat, x)
return out
model_points_translate = model_points+10+np.random.rand(*model_points.shape)
model_points_translate = np.vstack((model_points_translate, np.ones((1, 6))))
model_points_translate = torch.Tensor(model_points_translate)
model_points = torch.Tensor(model_points)
model = RotationMat()
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1)
for i in range(100):
pred = model(model_points_translate)
loss = loss_func(pred, model_points)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1)%10 == 0:
print('Step: {}\t Loss: {:.4f}'.format(i+1, loss.item()))