Hi, I implemented a network that uses torch.cat() to concatenate its inputs. When I train this network, it is very slow. Is torch.cat inefficient, and are there other ways to do the same thing?
Here is the code. Thank you!
class AdaptationNN(nn.Module):
    """Feed-forward adaptation network conditioned on a speaker code.

    The speaker code is concatenated along dim 1 to the input of every
    linear layer: ``fc1`` sees ``[speaker_codes, input]``; ``fc2`` and
    ``fc3`` see ``[speaker_codes, hidden]``. Each hidden layer is
    Linear -> BatchNorm1d(momentum=0.05) -> ReLU.

    Args:
        input_dim: ``in_features`` of ``fc1``. Because ``fc1`` is applied to
            ``torch.cat((speaker_codes, input), 1)``, this must equal
            ``speaker_code_dim`` plus the feature width of ``input``.
        output_dim: Width of the network output (``fc3`` has no BN/ReLU).
        speaker_code_dim: Width of ``speaker_codes`` (default 50, the value
            previously hard-coded).
        hidden_dim: Width of the two hidden layers (default 1000, the value
            previously hard-coded).
    """

    def __init__(self, input_dim, output_dim, *, speaker_code_dim=50, hidden_dim=1000):
        super().__init__()
        # Width of [speaker_codes, hidden] fed to fc2 and fc3 (1050 by default).
        cond_dim = hidden_dim + speaker_code_dim

        # fc1/fc2 keep a bias: the original passed bias=False but then
        # assigned a zero bias Parameter, which F.linear does use. Keeping it
        # preserves behaviour and state_dict keys. NOTE: a bias directly
        # before BatchNorm is redundant, since BN re-centres its input.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim, momentum=0.05)
        self.act1 = nn.ReLU()

        self.fc2 = nn.Linear(cond_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim, momentum=0.05)
        self.act2 = nn.ReLU()

        self.fc3 = nn.Linear(cond_dim, output_dim)

        for layer in (self.fc1, self.fc2, self.fc3):
            self._init_linear(layer)

    @staticmethod
    def _init_linear(layer):
        """Init weights ~ U(-b, b), b = sqrt(0.01 / (fan_in + fan_out)); zero the bias."""
        bound = (0.01 / (layer.in_features + layer.out_features)) ** 0.5
        # In-place init keeps the Parameter objects nn.Linear registered,
        # instead of replacing them with new ones.
        nn.init.uniform_(layer.weight, -bound, bound)
        nn.init.zeros_(layer.bias)

    def forward(self, speaker_codes, input):
        """Run the network.

        Args:
            speaker_codes: Speaker code tensor, concatenated along dim 1 to each
                layer's input. Assumed 2-D ``(batch, speaker_code_dim)``,
                since dim 1 must be the feature axis for the fc2/fc3 widths
                to match.
            input: Input features; ``cat((speaker_codes, input), 1)`` must
                have ``input_dim`` columns.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        # NOTE(review): each torch.cat allocates and copies a new tensor.
        # That cost is usually small next to the hidden_dim-wide matmuls.
        # Profile before assuming cat is the training bottleneck.
        x = self.act1(self.bn1(self.fc1(torch.cat((speaker_codes, input), 1))))
        x = self.act2(self.bn2(self.fc2(torch.cat((speaker_codes, x), 1))))
        return self.fc3(torch.cat((speaker_codes, x), 1))