Batch-wise complex operation

I’ve got a 4-dimensional torch tensor parameter defined like this:

nn.parameter.Parameter(data=torch.empty(13, 13, 13, 13), requires_grad=True)

and four tensors of shape (batch_size, 13) (or, equivalently, one tensor of shape (batch_size, 4, 13)). I’d like to get a tensor of shape (batch_size,) equal to the formula at the end of this picture:

If A is a 3-dimensional tensor, I can do it with:


But if A is a 4-dimensional tensor, I have no idea how to do it with torch functions.
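
For reference, here is a minimal sketch of one possible approach using torch.einsum. It assumes (this is a guess, since the formula is only given in the picture) that the target is the full contraction out[b] = sum over i, j, k, l of A[i, j, k, l] * x1[b, i] * x2[b, j] * x3[b, k] * x4[b, l]. The names A, x, x1..x4 and the batch size of 5 are illustrative, not taken from the original code.

```python
import torch
from torch import nn

batch_size, n = 5, 13  # illustrative sizes; n matches the 13 in the question

# 4-D parameter, randomly initialised for the demo
A = nn.Parameter(torch.randn(n, n, n, n))

# The four inputs stacked into one tensor of shape (batch_size, 4, n)
x = torch.randn(batch_size, 4, n)
x1, x2, x3, x4 = x.unbind(dim=1)  # each has shape (batch_size, n)

# Assumed formula: contract each axis of A with one input, keeping the batch axis b
out = torch.einsum("ijkl,bi,bj,bk,bl->b", A, x1, x2, x3, x4)
print(out.shape)  # torch.Size([5])
```

Under the same assumption, the 3-dimensional case would follow the same pattern with one index fewer ("ijk,bi,bj,bk->b"), and gradients flow back to A as usual since einsum is differentiable.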