[RESOLVED] Speeding up loss.backward() for a given network

Hi,

I have a scenario where my training data is huge (about 290 million rows). The problem is that the loss.backward() and optimizer.step() operations together take about 4.7 seconds per batch of 1000 rows, making the training process extremely slow.

The network is defined as follows:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Embedding
from torch.autograd import Variable

class NETWORK(nn.Module):
    def __init__(self, word_embedding, entity_embedding, negative_size, text_len, dim_size, cpu, W=None, b=None):
        super(NETWORK, self).__init__()

        self.loss_function = nn.MSELoss()
        self.negative_size = negative_size
        self.text_len = text_len
        self.word_size = word_embedding.shape[0]
        self.entity_size = entity_embedding.shape[0]

        # Inputs to these embedding layers are long tensors of length text_len (2000 here)
        self.word_embedding = Embedding(word_embedding.shape[0], word_embedding.shape[1], _weight=torch.from_numpy(word_embedding))
        self.entity_embedding = Embedding(entity_embedding.shape[0], entity_embedding.shape[1], _weight=torch.from_numpy(entity_embedding))

        if not cpu:
            self.W = nn.Parameter(torch.from_numpy(np.identity(dim_size)).float().cuda())
            self.b = nn.Parameter(torch.from_numpy(np.random.uniform(-0.05, 0.05, dim_size)).float().cuda())
        else:
            self.W = nn.Parameter(torch.from_numpy(np.identity(dim_size)).float())
            self.b = nn.Parameter(torch.from_numpy(np.random.uniform(-0.05, 0.05, dim_size)).float())

    # If batch_size is B, then text_input is a matrix of size Bx2000, entity_input is a matrix of size Bx50, and labels is a matrix of size Bx31 (1 positive and 30 negative samples)
    def forward(self, text_input, entity_input, labels):
        def text_transform(i, embedding, text_input):
            mask = torch.ne(text_input, 0.0).float()
            vec = torch.matmul(mask, embedding)
            vec = vec / torch.norm(vec).float()
            return torch.matmul(vec, self.W) + self.b

        text_transformed_input = torch.stack([text_transform(torch.tensor(i), self.word_embedding(text_input[i]), text_input[i]) for (i, x) in enumerate(text_input)])
        similarity = torch.stack([torch.matmul(text_transformed_input[i], torch.transpose(self.entity_embedding(entity_input)[i], 0, 1)) for i in range(text_input.shape[0])])
        predictions = F.softmax(similarity, dim=1)

        # return predictions
        return self.loss_function(predictions, Variable(labels.float())).unsqueeze(0)

Any advice on reducing computation time is highly appreciated. In parallel, any advice on reducing the memory usage of the input and the network, which would help increase the batch size, is also appreciated. I run on an Nvidia 1080 Ti GPU with 12 GB of GPU RAM, and a batch size of 1000 is the highest I can use without OOMing.

Thanks!

Do you need to use the whole dataset for training? Is there any reason you couldn't just train on 1 million rows? Or simply train for 1-2 epochs on the whole dataset. At 4.7 s per 1000-row batch, one epoch over 290 million rows is about 290,000 batches, i.e. roughly 16 days.

It looks like you are manually performing a linear (affine) transformation. Is there any reason you can't use the built-in PyTorch tools to do that?
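For example, nn.Linear could take the place of the hand-rolled W and b. A minimal sketch, assuming a placeholder dim_size (substitute whatever you construct the network with):

import torch
import torch.nn as nn

dim_size = 300  # placeholder; use the dim_size you pass to the network

# nn.Linear bundles the weight and bias and dispatches to a fused addmm
linear = nn.Linear(dim_size, dim_size)

# reproduce the original initialization: identity weight, small uniform bias
with torch.no_grad():
    linear.weight.copy_(torch.eye(dim_size))
    linear.bias.uniform_(-0.05, 0.05)

vec = torch.randn(4, dim_size)  # a small batch of transformed text vectors
out = linear(vec)               # computes vec @ weight.t() + bias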

One other thing that might be worth considering is decreasing your batch size and doing parallel data loading.
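For the parallel loading part, here is a minimal sketch using torch.utils.data.DataLoader with worker processes. The RowDataset class, the placeholder tensors, and the batch size are assumptions, not your actual pipeline:

import torch
from torch.utils.data import Dataset, DataLoader

class RowDataset(Dataset):
    # Placeholder dataset; substitute however you actually store your rows
    def __init__(self, text, entities, labels):
        self.text, self.entities, self.labels = text, entities, labels

    def __len__(self):
        return self.text.shape[0]

    def __getitem__(self, idx):
        return self.text[idx], self.entities[idx], self.labels[idx]

# placeholder data matching the shapes described in the question
text = torch.randint(0, 1000, (10_000, 2000))
entities = torch.randint(0, 500, (10_000, 50))
labels = torch.zeros(10_000, 31)

loader = DataLoader(RowDataset(text, entities, labels),
                    batch_size=500,    # smaller batches, loaded in parallel
                    shuffle=True,
                    num_workers=4,     # prepare batches in background processes
                    pin_memory=True)   # faster host-to-GPU copies

for text_batch, entity_batch, label_batch in loader:
    # move the batch to the GPU and run the usual forward / backward / optimizer.step() here
    pass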

Not sure about speeding up backward itself, but

(1) As Andrew Plassard @aplassard notes, you should probably use F.linear instead of torch.matmul(vec, self.W) + self.b. The fused addmm is optimized for doing, well, a matrix multiplication plus an addition, whereas matmul followed by + is not.
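To be concrete, F.linear(x, weight, bias) computes x @ weight.t() + bias in one fused call, so the weight goes in transposed relative to torch.matmul(vec, self.W). A small self-contained check, with the dimension as a placeholder:

import torch
import torch.nn.functional as F

dim_size = 300                                   # placeholder dimension
W = torch.eye(dim_size)                          # identity init, as in __init__
b = torch.empty(dim_size).uniform_(-0.05, 0.05)
vec = torch.randn(dim_size)

# F.linear expects the weight transposed relative to matmul(vec, W)
out = F.linear(vec, W.t(), b)
assert torch.allclose(out, torch.matmul(vec, W) + b)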

Then (2), define text_transform as its own method rather than a nested function. As written, it is redefined on every forward pass, and in my experience that can be a significant slowdown.

Then I also wonder if there's a way to batch-vectorize your first list comprehension, text_transform(tor...). That could go a long way for both the forward and backward passes.
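For what it's worth, here is a rough sketch of what a fully batched forward could look like; it also folds text_transform into the method itself, per (2). The candidate-entity count and the padding-id-0 assumption are read off your code, so please double-check the shapes against your data:

import torch
import torch.nn.functional as F

def forward(self, text_input, entity_input, labels):
    # (B, 2000, dim): embed every text in the batch at once
    word_vecs = self.word_embedding(text_input)
    # (B, 2000): mask out padding id 0, as in the original text_transform
    mask = torch.ne(text_input, 0).to(word_vecs.dtype)
    # (B, dim): sum of non-padding word vectors, one bmm for the whole batch
    vec = torch.bmm(mask.unsqueeze(1), word_vecs).squeeze(1)
    # normalize each example, as the original did per row
    vec = vec / vec.norm(dim=1, keepdim=True)
    # batched affine transform (same transpose caveat as with F.linear above)
    vec = F.linear(vec, self.W.t(), self.b)
    # (B, K, dim) -> (B, K): score each text against its K candidate entities
    entity_vecs = self.entity_embedding(entity_input)
    similarity = torch.bmm(entity_vecs, vec.unsqueeze(2)).squeeze(2)
    predictions = F.softmax(similarity, dim=1)
    return self.loss_function(predictions, labels.float()).unsqueeze(0)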

No, I do need the whole dataset for training, and yes, I will be training for one epoch only.

Thanks. The vectorization struck me as well. I have made that change and training is now a bit faster, but the biggest advantage was that I was able to increase my batch size by about 5x.