Replicate operation tensor using `torch.tensordot()` and `torch.stack()`

I have a tensor operation that I would like to replicate using a combination of torch.stack() and torch.tensordot() to generalize it further on a larger program. In summary, I want to replicate the tensor V_1 using said operations into another tensor called V_2.

N, t , J  = 4, 2 , 3
K_f , K_r = 1, 1 
R = 5
K = K_f + K_r
id = torch.arange(N).repeat(t).sort()
X = torch.randn(N*t, K , J)
Y = torch.randn(N*t, 1)
D = torch.randn(N, K_r , R)
Draw = D.repeat_interleave(t,0) 
beta = torch.randn(2*K_r + K_f, 1)
beta_R = (beta[0:K_r,0] + beta[K_r:2*K_r,0] * Draw ).repeat(1,J,1)
print("shape beta_R:", beta_R.shape)
beta_F = beta[2*K_r:2*K_r + K_f,0].repeat(N*t, J, R)
print("shape beta_F:", beta_F.shape)
XX_0 =X[:,0,:].unsqueeze(2).repeat(1,1,R) 
print("shape XX_0:", XX_0.shape)
XX_1 =X[:,1,:].unsqueeze(2).repeat(1,1,R)
print("shape XX_1:", XX_1.shape)
V_1 = XX_0 * beta_R  + XX_1 * beta_F
print("shape V_1:",V_1.shape)
#shape beta_R: torch.Size([8, 3, 5])
#shape beta_F: torch.Size([8, 3, 5])
#shape XX_0: torch.Size([8, 3, 5])
#shape XX_1: torch.Size([8, 3, 5])
#shape V_1: torch.Size([8, 3, 5])

Now I want to do the same but stacking my tensors (using torch.stack()) and applying a generalized version of the dot-product (using torch.tensordot()), but I am a bit confused with the dims argument which is not doing what I expected.

#%% Replicating using stacking and tensordot
stack_XX = torch.stack((XX_0, XX_1), 0)
print("shape stack_XX:",stack_XX.shape)
stack_beta = torch.stack((beta_R, beta_F), 0)
print("shape stack_beta:", stack_beta.shape)
# dot product bewteen stack_XX and stack_beta along the first dimension
V_2 = torch.tensordot(stack_XX, stack_beta, dims=([0], [0]))
print("shape V_2:",V_2.shape)
# check if the two are equal
torch.all(V_1.eq(V_2))

#shape stack_XX: torch.Size([2, 8, 3, 5])
#shape stack_beta: torch.Size([2, 8, 3, 5])
#shape V_2: torch.Size([8, 3, 5, 8, 3, 5])
#tensor(False)

So I am basically trying to get tensor(True) when running torch.all(V_1.eq(V_2)).