I’m implementing a custom layer that needs to reshape two input tensors several times. The batch has shape R^(N, a, b, c), where N is the number of tensors in the batch. For each 3D tensor of size R^(a, b, c), I need to reshape it into a matrix of size (a*b, c) and then multiply the two matrices with `torch.mm`. However, I keep running into a size-mismatch error: `torch.mm` requires 2D inputs rather than 3D ones (when N > 1), while the custom layer reads the input as a 4D tensor (N, a, b, c). Any help will be highly appreciated.

Below you can find the code:

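```python
import torch


class mmm(torch.nn.Module):

    def __init__(self):
        super(mmm, self).__init__()

    def forward(self, x, y):
        # Works for a single (a, b, c) input: reshape it to (a*b, c).
        # With a 4D batch (N, a, b, c) these reshapes fail in general, because
        # the target shape (N*a, b) drops the elements of the last dimension.
        xp = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2]))
        yp = torch.reshape(y, (y.shape[0] * y.shape[1], y.shape[2]))
        # torch.mm accepts only 2D matrices: (a*b, c) @ (c, a*b)
        out = torch.mm(xp, yp.permute(1, 0))
        return out
```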

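You can try using `torch.matmul`, which treats the leading dimensions of its inputs as batch dimensions. If you reshape each input to (N, a*b, c) instead of (a*b, c), `torch.matmul` performs one matrix product per sample in the batch.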
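A minimal sketch of the batched layer, assuming both inputs share the batch size N and the last dimension c (for example, both of shape (N, a, b, c)) and that the goal is one (a*b, c) @ (c, a*b) product per sample. The class name `mmm` is kept from the question; the use of `flatten` and `transpose` is my own choice, not something from the original post.

```python
import torch


class mmm(torch.nn.Module):
    """Assumes x, y of shape (N, a, b, c); returns a tensor of shape (N, a*b, a*b)."""

    def forward(self, x, y):
        # Merge dims 1 and 2 of every sample while keeping the batch dim N:
        # (N, a, b, c) -> (N, a*b, c)
        xp = x.flatten(1, 2)
        yp = y.flatten(1, 2)
        # torch.matmul treats dim 0 as a batch dim, so this computes
        # (a*b, c) @ (c, a*b) for each sample -> (N, a*b, a*b)
        return torch.matmul(xp, yp.transpose(1, 2))


x = torch.randn(2, 3, 4, 5)
y = torch.randn(2, 3, 4, 5)
out = mmm()(x, y)
print(out.shape)  # torch.Size([2, 12, 12])
```

For strictly 3D operands, `torch.bmm(xp, yp.transpose(1, 2))` gives the same result. `torch.matmul` is more general because it also broadcasts any extra leading dimensions.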