I have a ResNet-50 model that outputs a class prediction (1, 2, or 3). Depending on the predicted class, I want to route each sample to a different follow-up layer/model that makes a second prediction.
This is what I have so far:
import torch


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Three stand-in "next" models, one per class, with weights set to 1
        self.model1 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model1.weight)
        self.model2 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model2.weight)
        self.model3 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model3.weight)

    def forward(self, x):
        # Get batch_size (x has shape (1, batch_size) here)
        batch_size = x.size(1)
        output = torch.zeros(batch_size, 1, device=x.device)
        # Loop over every value in batch
        for i in range(batch_size):
            value = x[:, i]
            if value == 1:
                output[i] = self.model1(value)
            elif value == 2:
                output[i] = self.model2(value)
            else:
                output[i] = self.model3(value)
        return output


model = SimpleModel()
output = model(torch.tensor([[1, 2, 3]], dtype=torch.float32))
output

tensor([[1.],
        [2.],
        [3.]], grad_fn=<CopySlices>)
My concern is that each iteration of the loop runs a forward pass for a single sample, which seems very inefficient. What happens if I increase the batch size to 64? Will the forward passes be computed in parallel?
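For comparison, here is a minimal sketch of a batched alternative to the loop, where samples are grouped by class with boolean masks and each sub-model is called once per group. It assumes the input has shape (batch_size, 1) rather than (1, batch_size), and the names `MaskedModel`, `branches`, `masked_model` and `masked_output` are hypothetical, not part of the code above. I have not benchmarked it.

```python
import torch


class MaskedModel(torch.nn.Module):
    """Hypothetical batched variant of SimpleModel: one call per branch."""

    def __init__(self):
        super().__init__()
        # Same three stand-in branches, weights initialised to one.
        self.branches = torch.nn.ModuleList(
            [torch.nn.Linear(1, 1, bias=False) for _ in range(3)]
        )
        for branch in self.branches:
            torch.nn.init.ones_(branch.weight)

    def forward(self, x):
        # Assumption: x has shape (batch_size, 1) and holds the class values.
        output = torch.zeros_like(x)
        classes = x[:, 0]
        # Same routing as the if/elif/else chain: 1, 2, everything else.
        masks = [classes == 1, classes == 2, (classes != 1) & (classes != 2)]
        for mask, branch in zip(masks, self.branches):
            if mask.any():
                # One forward pass over all samples routed to this branch.
                output[mask] = branch(x[mask])
        return output


masked_model = MaskedModel()
masked_output = masked_model(torch.tensor([[1.0], [2.0], [3.0]]))
```

With this layout the number of sub-model calls is at most three regardless of batch size, so a batch of 64 would not need 64 separate forward passes. Whether that is actually faster would depend on the size of the real sub-models and the device.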
Any thoughts/ideas would be appreciated.