Applying torch.matmul along custom dimension

For an implementation of a graphical model I need to perform matmul along axes 1 and 2 of a four-dimensional input tensor; axes 0 and 3 should be broadcast. The current implementation of torch.matmul performs the matrix multiplication across the final two axes and broadcasts across all of the leading dimensions (input.shape[:-2]).
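For reference, this is the default behaviour (the shapes below are arbitrary, just to show it): matmul treats the last two axes as the matrix dimensions and broadcasts everything in front of them.

    import torch

    a = torch.randn(2, 3, 4, 5)   # batch dims (2, 3), matrix dims (4, 5)
    b = torch.randn(2, 3, 5, 6)   # batch dims (2, 3), matrix dims (5, 6)
    c = torch.matmul(a, b)        # contracts the last two axes, broadcasts the rest
    print(c.shape)                # torch.Size([2, 3, 4, 6])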

Currently I solve this by first transposing the input and then performing matmul.

    input_row = input_col.transpose(1, 3)             # swap axes 1 and 3 so the matmul axes (original 1 and 2) end up last
    message = torch.matmul(self.gaussian, input_row)  # matmul now contracts the intended axes

However, this is very slow. Profiling indicates that 67% of the computation time of the entire model is spent in a contiguous call inside matmul (see below). The only reason this contiguous call is needed is the transpose right before the matmul, since input_col itself is contiguous.
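To illustrate what happens (the shapes below are placeholders, not my real model sizes): transpose only permutes strides and returns a view, so the result is no longer contiguous, and matmul then has to copy it back into contiguous memory.

    import torch

    input_col = torch.randn(8, 16, 16, 32)   # freshly allocated, so contiguous
    input_row = input_col.transpose(1, 3)    # a view with permuted strides, no data copied
    print(input_col.is_contiguous())         # True
    print(input_row.is_contiguous())         # False -> matmul makes a contiguous copy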

Does anybody know how I can perform this operation without transposing input_col?

In your case you can try:

message = torch.matmul(input_col, self.gaussian)
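Whether this gives the same result as your transpose-based version depends on the shape of self.gaussian and on which axis you actually want to contract, so it is worth a quick check with dummy tensors first (the shapes below are guesses, not taken from your model):

    import torch

    gaussian = torch.randn(32, 16)            # hypothetical shape for self.gaussian
    input_col = torch.randn(8, 16, 16, 32)    # hypothetical 4-D input

    old = torch.matmul(gaussian, input_col.transpose(1, 3))
    new = torch.matmul(input_col, gaussian)
    print(old.shape)   # torch.Size([8, 32, 32, 16]) with these dummy shapes
    print(new.shape)   # torch.Size([8, 16, 16, 16]) -- check which one your model expects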