How to multiply with a trainable tensor?

I have a model that outputs a tensor of shape [b, n, r, c], where b is the batch size. I want to multiply this tensor with a trainable tensor of shape [b, n, 1, 1]. I am confused about how to deal with the batch size, since we construct our models without knowing it. How do I go about it? How can I multiply a trainable tensor with a tensor of unknown shape?

import torch
import torch.nn as nn


class myModel34(nn.Module):
    def __init__(self, features=None, num_classes=1000, **kwargs):
        super(myModel34, self).__init__()
        self.features = features
        self.conv13 = nn.Conv2d(1024, num_classes, kernel_size=1)
        # batch_size is not known at construction time (this is the problem)
        self.one_d = one_d_tensor(batch_size, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.conv13(x)
        x = self.one_d(x)
        return x


class one_d_tensor(nn.Module):
    def __init__(self, batch_size, num_classes):
        super(one_d_tensor, self).__init__()
        # trainable tensor of shape [batch_size, num_classes, 1, 1]
        self.W = torch.nn.Parameter(torch.randn(batch_size, num_classes, 1, 1))
        self.W.requires_grad = True

    def forward(self, x):
        # element-wise multiplication with the trainable tensor
        mul = torch.mul(x, self.W)
        return mul

Hi,

The main reason for “we construct our models without knowing the batch size” is that, usually, the model should handle each element in the batch as if it were independent.
Here you actually want your model to behave differently for different elements in the batch, so this no longer applies and you’ll need to know the batch size beforehand.
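If you really wanted a separate trainable weight per batch element, you would have to fix the batch size up front. A minimal sketch of that (the module name PerSampleScale and the fixed batch_size argument are just illustrative, not part of your code):

import torch
import torch.nn as nn

class PerSampleScale(nn.Module):
    # hypothetical module: one trainable scalar per (sample, class) pair,
    # which only works because batch_size is fixed in advance
    def __init__(self, batch_size, num_classes):
        super().__init__()
        self.W = nn.Parameter(torch.randn(batch_size, num_classes, 1, 1))

    def forward(self, x):
        # x must always contain exactly batch_size samples
        return x * self.W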


Well, other than that, how should I multiply a trainable tensor with the output of a model? For example, a linear combination without the summation.

I’m not sure I understand the question. A regular multiplication will work. The fact that one of the Tensors is a Parameter does not change anything.
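For instance (a minimal sketch, with made-up shapes for the model output):

import torch
import torch.nn as nn

num_classes = 1000
W = nn.Parameter(torch.ones(num_classes, 1, 1))  # trainable by default

x = torch.randn(4, num_classes, 7, 7)  # made-up model output of shape [b, n, r, c]
out = x * W               # plain element-wise multiplication, broadcasts over the batch
out.sum().backward()      # gradients flow into W like for any other tensor
print(W.grad.shape)       # torch.Size([1000, 1, 1])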


Can you edit the above code and show me how to do the multiplication with a trainable tensor in a module? (I’m just confused about how to do it.)

I apologize for not asking the question correctly. I actually do not want my model to behave differently for different elements in the batch, but to behave differently for different classes:

class one_d_tensor(nn.Module):
    def __init__(self, num_classes):
        super(one_d_tensor, self).__init__()
        # one trainable scalar per class; broadcasts over batch, rows and columns
        self.W = torch.nn.Parameter(torch.ones(num_classes, 1, 1))
        self.W.requires_grad = True

    def forward(self, x):
        print("self.W: ", self.W)
        # x: [b, num_classes, r, c] * W: [num_classes, 1, 1]
        mul = torch.mul(x, self.W)
        return mul
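
With W of shape [num_classes, 1, 1] the multiplication broadcasts over the batch dimension, so the batch size never has to be known in advance. A quick check (made-up shapes):

m = one_d_tensor(num_classes=1000)
x = torch.randn(8, 1000, 7, 7)   # any batch size works here
out = m(x)
print(out.shape)                 # torch.Size([8, 1000, 7, 7])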

You were right! I was confused about how to deal with the unknown batch size… Simple multiplication solved it for me.