The efficiency of torch.cat

Hi, I implemented a network that uses torch.cat() to concatenate its inputs. When I train this network, it is very slow. Is torch.cat inefficient, and are there other ways to achieve the same result? The code follows; thank you.

import numpy as np
import torch
import torch.nn as nn


class AdaptationNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(AdaptationNN, self).__init__()
        # Weights get a scaled uniform initialization; biases start at zero.
        self.fc1 = nn.Linear(input_dim, 1000, bias=True)
        self.fc1.weight = torch.nn.Parameter(torch.Tensor(1000, input_dim).uniform_(
            -np.sqrt(0.01 / (input_dim + 1000)), np.sqrt(0.01 / (input_dim + 1000))))
        self.fc1.bias = torch.nn.Parameter(torch.zeros(1000))
        self.bn1 = nn.BatchNorm1d(1000, momentum=0.05)
        self.act1 = nn.ReLU()
        # The speaker code has size 50, hence the 1050-wide inputs below.
        self.fc2 = nn.Linear(1050, 1000, bias=True)
        self.fc2.weight = torch.nn.Parameter(torch.Tensor(1000, 1050).uniform_(
            -np.sqrt(0.01 / (1000 + 1050)), np.sqrt(0.01 / (1000 + 1050))))
        self.fc2.bias = torch.nn.Parameter(torch.zeros(1000))
        self.bn2 = nn.BatchNorm1d(1000, momentum=0.05)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Linear(1050, output_dim, bias=True)
        self.fc3.weight = torch.nn.Parameter(torch.Tensor(output_dim, 1050).uniform_(
            -np.sqrt(0.01 / (1050 + output_dim)), np.sqrt(0.01 / (1050 + output_dim))))
        self.fc3.bias = torch.nn.Parameter(torch.zeros(output_dim))

    def forward(self, speaker_codes, input):
        # The speaker code is re-concatenated onto the input of every layer.
        x = self.act1(self.bn1(self.fc1(torch.cat((speaker_codes, input), 1))))
        x = self.act2(self.bn2(self.fc2(torch.cat((speaker_codes, x), 1))))
        x = self.fc3(torch.cat((speaker_codes, x), 1))
        return x
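
As an aside, the repeated concatenation can be avoided altogether: because a weight matrix can be split column-wise, a linear layer applied to a concatenation equals the sum of two smaller linear layers applied to the pieces. A minimal sketch of the equivalence (not from the original post; it assumes the speaker code has width 50):

import torch
import torch.nn as nn

fc = nn.Linear(1050, 1000)                 # expects cat([codes (50) | x (1000)], dim 1)
fc_codes = nn.Linear(50, 1000, bias=True)
fc_x = nn.Linear(1000, 1000, bias=False)

with torch.no_grad():
    fc_codes.weight.copy_(fc.weight[:, :50])   # columns that multiply the codes
    fc_x.weight.copy_(fc.weight[:, 50:])       # columns that multiply x
    fc_codes.bias.copy_(fc.bias)               # keep the bias in one branch only

codes = torch.randn(8, 50)
x = torch.randn(8, 1000)
print(torch.allclose(fc(torch.cat((codes, x), 1)),
                     fc_codes(codes) + fc_x(x), atol=1e-5))  # True

This removes the per-layer memory copy that torch.cat performs, at the cost of one extra matrix multiply per layer.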

The following GitHub issue indicates that torch.cat is slow on CPU: https://github.com/pytorch/pytorch/issues/18634

I found that a custom cat function is slightly faster on CPU:

def customCat(allTensors):
    # Concatenate a list of 3-D tensors along dim 0 by copying each one
    # into a preallocated buffer (matching the inputs' dtype and device).
    totalSize = sum(t.shape[0] for t in allTensors)
    newTensor = torch.zeros(totalSize, allTensors[0].shape[1], allTensors[0].shape[2],
                            dtype=allTensors[0].dtype, device=allTensors[0].device)
    counter = 0
    for t in allTensors:
        newTensor[counter:counter + t.shape[0]] = t
        counter += t.shape[0]
    return newTensor

It definitely isn’t a generic replacement for cat, and I haven’t tested this code.
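
If you want to check the claim on your own machine, here is a rough micro-benchmark sketch; the shapes, list length, and repetition count are arbitrary assumptions, and customCat is the function defined above:

import time
import torch

tensors = [torch.randn(64, 128, 128) for _ in range(32)]  # arbitrary shapes

def bench(fn, reps=100):
    # Rough wall-clock timing; good enough for a CPU comparison.
    start = time.perf_counter()
    for _ in range(reps):
        fn(tensors)
    return (time.perf_counter() - start) / reps

print(f"torch.cat: {bench(torch.cat) * 1e3:.3f} ms per call")
print(f"customCat: {bench(customCat) * 1e3:.3f} ms per call")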

You also have the option of doing the concatenation in the DataLoader via a user-defined collate_fn. With multiple workers, the DataLoader can run the cat operations for upcoming batches in parallel with training, which may hide the latency you are seeing.
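
A minimal sketch of that idea, assuming each dataset item is a (speaker_code, features) pair; the cat_collate name and the dataset are hypothetical:

import torch
from torch.utils.data import DataLoader

def cat_collate(batch):
    # batch: list of (speaker_code, features) pairs from the dataset.
    codes = torch.stack([b[0] for b in batch])
    feats = torch.stack([b[1] for b in batch])
    # Build the first-layer input here, inside a loader worker; the codes
    # are also returned separately for the later layers' concatenations.
    return codes, torch.cat((codes, feats), 1)

# Hypothetical usage: with num_workers > 0, collation runs in worker
# processes and overlaps with the training loop.
# loader = DataLoader(dataset, batch_size=64, num_workers=4,
#                     collate_fn=cat_collate)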