import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # one independent linear layer per 4-feature input group
        self.ly_1 = torch.nn.Linear(in_features=4, out_features=1)
        self.ly_2 = torch.nn.Linear(in_features=4, out_features=1)
        self.ly_3 = torch.nn.Linear(in_features=4, out_features=1)

    def forward(self, x):
        # slice the 12 input features into three groups of 4
        x_1 = self.ly_1(x[:, :4])
        x_2 = self.ly_2(x[:, 4:8])
        x_3 = self.ly_3(x[:, 8:])
        # concatenate the three scalar outputs into a (batch, 3) tensor
        return torch.cat((x_1, x_2, x_3), dim=1)

a = Model()
a(torch.rand(16, 12)).shape  # torch.Size([16, 3])
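If you later need different group sizes, the same pattern generalizes with torch.split and nn.ModuleList. Here is a minimal sketch (the GroupedModel name and group_sizes argument are my own, not part of the code above):

# Hypothetical generalization (not in the original snippet): the same idea
# for arbitrary group sizes, using torch.split and nn.ModuleList.
class GroupedModel(torch.nn.Module):
    def __init__(self, group_sizes=(4, 4, 4)):
        super().__init__()
        self.group_sizes = group_sizes
        # one independent linear head per input group
        self.heads = torch.nn.ModuleList(
            torch.nn.Linear(in_features=n, out_features=1) for n in group_sizes
        )

    def forward(self, x):
        # split the feature dimension into the configured groups
        chunks = torch.split(x, self.group_sizes, dim=1)
        return torch.cat([head(c) for head, c in zip(self.heads, chunks)], dim=1)

m = GroupedModel()
m(torch.rand(16, 12)).shape  # torch.Size([16, 3])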
If you want to verify that the three layers really are independent, you can inspect their gradients. In the code below, the third label column is set equal to the model's third output, so that column contributes zero loss; ly_3 should therefore receive a zero gradient, while ly_1 and ly_2 receive non-zero ones.
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ly_1 = torch.nn.Linear(in_features=4, out_features=1)
        self.ly_2 = torch.nn.Linear(in_features=4, out_features=1)
        self.ly_3 = torch.nn.Linear(in_features=4, out_features=1)

    def forward(self, x):
        x_1 = self.ly_1(x[:, :4])
        x_2 = self.ly_2(x[:, 4:8])
        x_3 = self.ly_3(x[:, 8:])
        return torch.cat((x_1, x_2, x_3), dim=1)

a = Model()
optim = torch.optim.Adam(params=a.parameters())
loss_fn = torch.nn.MSELoss()

input_ = torch.rand(4, 12)
label = torch.rand(4, 3)
# make the third target column equal to the model's third output, so that
# column contributes zero loss (detach keeps the label out of the autograd graph)
label[:, 2:3] = a(input_)[:, 2:3].detach()

loss = loss_fn(a(input_), label)
optim.zero_grad()
loss.backward()

print(a.ly_1.weight.grad)  # e.g. tensor([[-0.1005, -0.1412, -0.1536, -0.1327]])
print(a.ly_2.weight.grad)  # e.g. tensor([[-0.2368, -0.3531, -0.3396, -0.2425]])
print(a.ly_3.weight.grad)  # tensor([[0., 0., 0., 0.]])
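As a further check (my own addition, under the same setup as above), torch.autograd.grad can confirm directly that the third output column depends only on ly_3's parameters:

out = a(input_)
grads = torch.autograd.grad(
    out[:, 2].sum(),                 # scalar built from the third column only
    [a.ly_1.weight, a.ly_3.weight],
    allow_unused=True,               # ly_1 is unused by this output, so its grad is None
)
print(grads[0])  # None: ly_1 has no influence on the third column
print(grads[1])  # non-zero (1, 4) gradient for ly_3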