# How to multiply with trainable tensor?

I have a model that outputs a tensor of shape `[b,n,r,c]`, where `b` is the batch size. I want to multiply this tensor with a trainable tensor of shape `[b,n,1,1]`. I am confused about how to deal with the batch size, since we construct our models without knowing it. How should I go about this? How can I multiply a trainable tensor with a tensor of unknown shape?

```python
import torch
import torch.nn as nn


class myModel34(nn.Module):
    def __init__(self, features=None, num_classes=1000, **kwargs):
        super(myModel34, self).__init__()
        self.features = features
        self.conv13 = nn.Conv2d(1024, num_classes, kernel_size=1)
        # batch_size is not known when the model is constructed: this is the problem
        self.one_d = one_d_tensor(batch_size, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.conv13(x)
        x = self.one_d(x)
        return x


class one_d_tensor(nn.Module):
    def __init__(self, batch_size, num_classes):
        super(one_d_tensor, self).__init__()
        # one trainable weight per (batch element, class)
        self.W = torch.nn.Parameter(torch.randn(batch_size, num_classes, 1, 1))

    def forward(self, x):
        mul = torch.mul(x, self.W)
        return mul
```

Hi,

The main reason for “we construct our models without knowing the batch size” is that, usually, the model should handle each element in the batch as if it were independent.
Here you actually want your model to behave differently for different elements in the batch, so this no longer applies and you'll need to know the batch size beforehand.
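To make this concrete, here is a minimal sketch, assuming PyTorch (`PerSampleScale` is a hypothetical name, not from the thread), of a layer whose trainable weight has one entry per batch element. The first dimension of the parameter is fixed at construction time, so the layer only works for inputs with exactly that batch size.

```python
import torch
import torch.nn as nn


class PerSampleScale(nn.Module):
    """Hypothetical layer: one trainable weight per (batch element, class)."""

    def __init__(self, batch_size, num_classes):
        super().__init__()
        # The batch size is baked into the parameter's shape.
        self.W = nn.Parameter(torch.randn(batch_size, num_classes, 1, 1))

    def forward(self, x):
        # x: [b, n, r, c]; W: [batch_size, n, 1, 1] broadcasts over r and c,
        # but dim 0 only lines up if b == batch_size (or b == 1).
        return x * self.W


layer = PerSampleScale(batch_size=4, num_classes=10)
print(layer(torch.randn(4, 10, 7, 7)).shape)  # torch.Size([4, 10, 7, 7])

try:
    layer(torch.randn(3, 10, 7, 7))  # a different batch size
except RuntimeError as err:
    print("batch of 3 failed:", err)
```

Note that an input with batch size 1 would not fail but silently broadcast against all `batch_size` weights, and a `DataLoader` with `drop_last=False` can yield a smaller final batch that triggers the error shown in the sketch.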


Well, other than that, how should I multiply a trainable tensor with the output of a model? For example, a linear combination without the summation.

I'm not sure I understand the question. A regular multiplication will work. The fact that one of the Tensors is a Parameter does not change anything.
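A minimal sketch, assuming PyTorch and the `[b, n, r, c]` layout from the question (`PerClassScale` is a hypothetical name): an `nn.Parameter` of shape `[n, 1, 1]` is multiplied with `*` like any other tensor. Broadcasting aligns shapes from the right, so the weight is treated as `[1, n, 1, 1]` and the same per-class weights are reused for every batch element, whatever `b` is.

```python
import torch
import torch.nn as nn


class PerClassScale(nn.Module):
    """Hypothetical layer: scale each class channel of [b, n, r, c] by its own weight."""

    def __init__(self, num_classes):
        super().__init__()
        # One weight per class; shape [n, 1, 1] broadcasts as [1, n, 1, 1].
        self.W = nn.Parameter(torch.ones(num_classes, 1, 1))

    def forward(self, x):
        # Plain element-wise multiplication; the batch size is never needed.
        return x * self.W


scale = PerClassScale(num_classes=10)
for b in (1, 3, 8):
    x = torch.randn(b, 10, 7, 7)
    print(b, scale(x).shape)

# Gradients reach the Parameter like any other tensor in the graph.
out = scale(torch.randn(3, 10, 7, 7))
out.sum().backward()
print(scale.W.grad.shape)  # torch.Size([10, 1, 1])
```

Because the parameter is registered as a module attribute in `__init__`, it appears in `scale.parameters()` and is updated by an optimizer like any other weight.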


Can you edit the above code and show me how to do the multiplication with a trainable tensor in a module? (I'm just confused about how to do it.)

I apologize for not asking the question correctly. I actually do not want my model to behave differently for different elements in the batch, but differently for different classes.

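```python
import torch
import torch.nn as nn


class one_d_tensor(nn.Module):
    def __init__(self, num_classes):
        super(one_d_tensor, self).__init__()
        # one trainable weight per class, broadcast over the batch and spatial dims
        self.W = torch.nn.Parameter(torch.ones(num_classes, 1, 1))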