Okay, that seems like a better solution than what I came up with. Thanks for the solution. I'm glad to know that assigning to indices like that works; I thought it wouldn't, since I've had errors when setting values with masks, etc.
Here’s the messy torch.cat solution in case anyone is interested.
def _matrix_from_rows(rows):
    """Stack a 3x3 nested tuple of (n,) tensors into an (n, 3, 3) batch, row-major."""
    # Inner stack turns each row's three entries into (n, 3); outer stack puts
    # the rows on dim -2, so out[b, i, j] == rows[i][j][b].
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)


def rotation_tensor(theta, phi, psi, n_comps=None):
    """Build a batch of 3-D rotation matrices R = Rz(psi) @ Ry(phi) @ Rx(theta).

    Each factor is the standard right-handed (counter-clockwise) rotation about
    the named axis, so R applied to a column vector rotates it about x first,
    then y, then z.

    Args:
        theta, phi, psi: Angle tensors (fed directly to cos/sin, so radians),
            each holding ``n_comps`` elements, e.g. shape ``(n_comps, 1, 1)``
            or ``(n_comps,)``. Gradients flow back to them.
        n_comps: Number of matrices in the batch. If None, it is inferred
            from ``theta.numel()``.

    Returns:
        Tensor of shape ``(n_comps, 3, 3)`` with the dtype/device of ``theta``.

    Raises:
        RuntimeError: if an angle tensor does not hold ``n_comps`` elements.
    """
    if n_comps is None:
        n_comps = theta.numel()
    theta = theta.reshape(n_comps)
    phi = phi.reshape(n_comps)
    psi = psi.reshape(n_comps)

    # Constants follow theta's dtype/device so float64 or CUDA inputs work.
    one = torch.ones_like(theta)
    zero = torch.zeros_like(theta)

    cos_t, sin_t = theta.cos(), theta.sin()
    cos_p, sin_p = phi.cos(), phi.sin()
    cos_s, sin_s = psi.cos(), psi.sin()

    # Written row by row: each literal reads exactly like the matrix it builds.
    # (Concatenating (n, 3, 1) pieces along dim 2 assembles *columns*, which
    # makes it easy to end up with a transposed, i.e. reversed, rotation.)
    rot_x = _matrix_from_rows((
        (one, zero, zero),
        (zero, cos_t, -sin_t),
        (zero, sin_t, cos_t),
    ))
    rot_y = _matrix_from_rows((
        (cos_p, zero, sin_p),
        (zero, one, zero),
        (-sin_p, zero, cos_p),
    ))
    rot_z = _matrix_from_rows((
        (cos_s, -sin_s, zero),
        (sin_s, cos_s, zero),
        (zero, zero, one),
    ))
    return torch.bmm(rot_z, torch.bmm(rot_y, rot_x))
# Example usage: a batch of 5 rotations, all starting at zero angle.
n_comps = 5
# Plain tensors with requires_grad=True replace the deprecated Variable
# wrapper (merged into Tensor since PyTorch 0.4); gradients flow the same way.
theta = torch.zeros(n_comps, 1, 1, requires_grad=True)
phi = torch.zeros(n_comps, 1, 1, requires_grad=True)
psi = torch.zeros(n_comps, 1, 1, requires_grad=True)
rotation_mat = rotation_tensor(theta, phi, psi, n_comps)