I’m trying to find the rotation and translation which best maps one set of 3D points to another. My code:
reg_verts_tensor = torch.from_numpy(reg_verts).float().to(device)
verts_tensor = torch.from_numpy(verts).float().to(device)
tensor_0 = torch.zeros(1).float().to(device)
tensor_1 = torch.ones(1).float().to(device)
criterion = torch.nn.MSELoss(reduction='sum')
n_register_epochs = 1000
register_initial_learning_rate = 0.001
register_learning_rate = register_initial_learning_rate
optimizer = torch.optim.SGD([translation, rotation_x, rotation_y, rotation_z], lr=register_learning_rate)
loss_list = []
for epoch in range(n_register_epochs):
optimizer.zero_grad()
epoch_loss = None
transformed_verts = verts_tensor + translation
rotation_matrix_x = torch.stack([torch.stack([tensor_1,tensor_0,tensor_0]),
torch.stack([tensor_0,torch.cos(rotation_x), -torch.sin(rotation_x)]),
torch.stack([tensor_0, torch.sin(rotation_x), torch.cos(rotation_x)])]).reshape(3,3)
rotation_matrix_y = torch.stack([torch.stack([torch.cos(rotation_y), tensor_0, - torch.sin(rotation_y)]),
torch.stack([tensor_0, tensor_1, tensor_0]),
torch.stack([torch.sin(rotation_y), tensor_0, torch.cos(rotation_y)])]).reshape(3,3)
rotation_matrix_z = torch.stack([torch.stack([torch.cos(rotation_z), torch.sin(rotation_z), tensor_0]),
torch.stack([-torch.sin(rotation_z), torch.cos(rotation_z), tensor_0]),
torch.stack([tensor_0,tensor_0,tensor_1])]).reshape(3,3)
for i in range(transformed_verts.shape[0]):
transformed_verts[i] = torch.matmul(rotation_matrix_x, transformed_verts[i])
transformed_verts[i] = torch.matmul(rotation_matrix_y, transformed_verts[i])
transformed_verts[i] = torch.matmul(rotation_matrix_z, transformed_verts[i])
#
loss = criterion(transformed_verts, reg_verts_tensor)
epoch_loss = loss
if epoch % 10 == 0:
print(epoch_loss.item())
epoch_loss.backward()
This fails on the “backwards” call with:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [3]], which is output 0 of AsStridedBackward0, is at version 1323; expected version 1322 instead.
What am I doing wrong?