Batch-wise complex operation

I have a 4-dimensional torch tensor parameter defined like this:

nn.parameter.Parameter(data=torch.empty(13, 13, 13, 13), requires_grad=True)

and four tensors of shape (batch_size, 13) (or, equivalently, one tensor of shape (batch_size, 4, 13)). I’d like to get a tensor of shape (batch_size,) equal to the formula at the end of this picture:

If A is a 3-dimensional tensor, I can do it with:

torch.bmm(torch.unsqueeze(z,2),torch.bmm(torch.unsqueeze(y,1),torch.transpose(torch.matmul(x,A),0,1))).sum(axis=2).sum(axis=1)

But if A is a 4-dimensional tensor, I have no idea how to do it with torch functions.
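One possible direction, as a sketch only: the picture with the formula is not reproduced here, so this assumes the target is the full multilinear contraction out[b] = sum over i, j, k, l of A[i,j,k,l] * x[b,i] * y[b,j] * z[b,k] * w[b,l]. Under that assumption, torch.einsum expresses it in one call. The function name contract4 and the names x, y, z, w for the four inputs are hypothetical, chosen for illustration.

```python
import torch
import torch.nn as nn


def contract4(A, x, y, z, w):
    """Batched contraction of a 4-D tensor with four batches of vectors.

    Assumed formula (not confirmed by the original post):
        out[b] = sum_{i,j,k,l} A[i,j,k,l] * x[b,i] * y[b,j] * z[b,k] * w[b,l]
    A has shape (n, n, n, n); x, y, z, w each have shape (batch_size, n).
    Returns a tensor of shape (batch_size,).
    """
    return torch.einsum('ijkl,bi,bj,bk,bl->b', A, x, y, z, w)


batch_size, n = 8, 13

# Randomly initialised parameter, so the example is reproducible in spirit.
A = nn.Parameter(torch.randn(n, n, n, n))

# Four separate inputs of shape (batch_size, n) ...
x, y, z, w = (torch.randn(batch_size, n) for _ in range(4))
out = contract4(A, x, y, z, w)
print(out.shape)  # torch.Size([8])

# ... or one stacked input of shape (batch_size, 4, n), split along dim 1.
stacked = torch.stack([x, y, z, w], dim=1)
out_stacked = contract4(A, *stacked.unbind(dim=1))

# Gradients flow back to the parameter as usual.
out.sum().backward()
print(A.grad.shape)  # torch.Size([13, 13, 13, 13])
```

The subscript string reads as: index A by (i, j, k, l), index each input by the batch index b and one of those four indices, and sum over everything except b. If the formula in the picture pairs the indices differently, only the subscript string should need to change.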