Performance improvement for Tree-LSTM

Hi there,

I’m using Riddhiman Dasgupta's PyTorch implementation of Tree-LSTM in my machine translation model.

I’m now having performance issues and would like to know if anyone has ideas for improving it.
The code is below; node_forward processes one token (one tree node) of a sentence.

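import torch


def node_forward(self, inputs, child_c, child_h):
    # inputs: [in_dim]; child_c, child_h: [num_children, mem_dim]
    child_h_sum = torch.sum(child_h, dim=0, keepdim=True)
    iou = self.ioux(inputs) + self.iouh(child_h_sum)
    # split the projection into input, output and update gates
    i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
    # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh
    i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)

    # one forget gate per child (broadcasting would also work without .repeat)
    f = torch.sigmoid(
        self.fh(child_h) +
        self.fx(inputs).repeat(len(child_h), 1)
    )
    fc = torch.mul(f, child_c)

    # new cell state: gated update plus the forget-gated sum of child cells
    c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
    h = torch.mul(o, torch.tanh(c))
    return c, h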

In my case:
inputs is a 300-dimensional tensor
child_c and child_h are tensors of size [1, 300]
ioux and iouh are nn.Linear layers with 300 in_features and 900 out_features
fh and fx are nn.Linear layers with 300 in_features and 300 out_features
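For anyone who wants to reproduce this, here is a minimal self-contained sketch of a cell with those shapes. The class name ChildSumTreeLSTMCell and the single all-zero child used for a leaf are assumptions for illustration, and it assumes inputs is a 1-D tensor of 300 features.

```python
import torch
import torch.nn as nn


class ChildSumTreeLSTMCell(nn.Module):
    """Hypothetical minimal cell mirroring node_forward (300-d sizes assumed)."""

    def __init__(self, in_dim: int = 300, mem_dim: int = 300):
        super().__init__()
        self.ioux = nn.Linear(in_dim, 3 * mem_dim)
        self.iouh = nn.Linear(mem_dim, 3 * mem_dim)
        self.fx = nn.Linear(in_dim, mem_dim)
        self.fh = nn.Linear(mem_dim, mem_dim)

    def node_forward(self, inputs, child_c, child_h):
        # inputs: [in_dim]; child_c, child_h: [num_children, mem_dim]
        child_h_sum = child_h.sum(dim=0, keepdim=True)
        iou = self.ioux(inputs) + self.iouh(child_h_sum)  # [1, 3 * mem_dim]
        i, o, u = iou.chunk(3, dim=1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        # one forget gate per child; fx(inputs) broadcasts over the child rows
        f = torch.sigmoid(self.fh(child_h) + self.fx(inputs))
        c = i * u + (f * child_c).sum(dim=0, keepdim=True)
        h = o * torch.tanh(c)
        return c, h


cell = ChildSumTreeLSTMCell()
x = torch.randn(300)
# assumed leaf convention: a single all-zero "child" state of size [1, 300]
zeros = torch.zeros(1, 300)
c, h = cell.node_forward(x, zeros, zeros)
print(c.shape, h.shape)  # torch.Size([1, 300]) torch.Size([1, 300])
```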

I’m running it on a GTX 1070; I also tried a GTX 1080 Ti, which gave no significant improvement.

The node_forward method takes 0.00135 s per call. My translation model has to run the encoder 4 times in each iteration. Suppose I train for 300k iterations with 50 lines per iteration, and every sentence has 50 tokens. Training would then take:

4 encodings × 300k iterations × 50 lines/iteration × 50 tokens/line × 0.00135 s/token = 4,050,000 s = 46.875 days
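The estimate can be reproduced in a few lines of Python (the function name is hypothetical); it assumes node_forward runs once per token, strictly one call after another:

```python
def estimated_training_seconds(encodes_per_iter=4, iterations=300_000,
                               lines_per_iter=50, tokens_per_line=50,
                               seconds_per_token=0.00135):
    """Back-of-the-envelope cost when every token is a separate, sequential call."""
    calls = encodes_per_iter * iterations * lines_per_iter * tokens_per_line
    return calls * seconds_per_token


secs = estimated_training_seconds()
print(f"{secs:,.0f} s = {secs / 86_400:.3f} days")  # 4,050,000 s = 46.875 days
```

Every factor in that product is a separate sequential call, so reducing the number of calls (for example, handling many nodes per call) changes the total far more than shaving time off a single call.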

Timing the individual lines:
iou = self.ioux(inputs) + self.iouh(child_h_sum) ----- 0.0003 s (22.22% of the method time)

i, o, u = F.sigmoid(i), F.sigmoid(o), F.tanh(u) ----- 0.00014 s (10.37% of the method time)

f = F.sigmoid(self.fh(child_h) + self.fx(inputs).repeat(len(child_h), 1)) ----- 0.00043 s (31.85% of the method time)

c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True) ----- 0.00014 s (10.37% of the method time)
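One caveat about per-line numbers on a GPU: CUDA kernels are launched asynchronously, so a plain timer around one line can charge it for work queued earlier, or miss its own kernel time. A sketch of a timing helper that synchronizes before reading the clock (time_op is a hypothetical name; the Linear(300, 900) layer just mirrors ioux):

```python
import time

import torch


def time_op(fn, *args, warmup=10, iters=100):
    """Average wall-clock seconds per call of fn(*args).

    Synchronizes on CUDA so queued kernels are counted where they belong.
    """
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):  # exclude one-time costs (allocator, library init)
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait for all launched kernels to finish
    return (time.perf_counter() - start) / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
lin = torch.nn.Linear(300, 900).to(device)
x = torch.randn(1, 300, device=device)
with torch.no_grad():
    print(f"Linear(300, 900) on one row: {time_op(lin, x) * 1e6:.1f} us")
```

With a single 1×300 row, each call does very little arithmetic, so the measured time is likely dominated by launch and framework overhead rather than by the matmul itself. In recent PyTorch versions, torch.utils.benchmark.Timer also handles warm-up and CUDA synchronization.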

Overall, this Tree-LSTM is about 10 times slower than PyTorch's native LSTM cell. The extra tree structure may contribute to the slowdown, but I doubt it accounts for a factor of 10 on its own.
I checked the implementation of PyTorch's native LSTM and GRU cells, and it seems PyTorch uses the cuDNN API for them. I’m wondering whether I could use that API to speed up my training, but I can’t find any documentation.
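As far as I know, cuDNN's RNN routines only cover chain-structured recurrences, which nn.LSTM exposes by taking a whole sequence per call; there is no tree-shaped variant to plug node_forward into. The sketch below (sizes are assumptions) illustrates the dispatch difference: one nn.LSTM call over 50 steps should match 50 separate nn.LSTMCell calls that use copied weights, so any speed gap between them comes from how the work is launched, not from the math.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, dim = 50, 8, 300  # assumed sizes
lstm = nn.LSTM(dim, dim)          # whole sequence per call (cuDNN-backed on CUDA)
cell = nn.LSTMCell(dim, dim)      # one time step per call

# Share weights so both compute the same recurrence (same gate order).
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(seq_len, batch, dim)
with torch.no_grad():
    out_seq, _ = lstm(x)  # a single call covers all 50 steps
    h = c = torch.zeros(batch, dim)
    outs = []
    for t in range(seq_len):  # 50 separate calls, like calling node_forward per token
        h, c = cell(x[t], (h, c))
        outs.append(h)
    out_loop = torch.stack(outs)

print(torch.allclose(out_seq, out_loop, atol=1e-4))
```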

Alternatively, is there anything I can do to speed up the torch operations in the lines listed above?
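For the lines listed, one option is to move the input-side projections out of the per-node loop. ioux(inputs) and fx(inputs) depend only on the token embedding, not on the children, so they can be computed for every token of a sentence in a single matmul before walking the tree, and the two layers can be stacked into one. The sketch below assumes a hypothetical fused_x layer built by concatenating the weights of ioux and fx; node_forward would then index into the precomputed results.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_dim, mem_dim, n_tokens = 300, 300, 50

ioux = nn.Linear(in_dim, 3 * mem_dim)
fx = nn.Linear(in_dim, mem_dim)

# Hypothetical fused layer: ioux and fx stacked row-wise (900 + 300 outputs).
fused_x = nn.Linear(in_dim, 4 * mem_dim)
with torch.no_grad():
    fused_x.weight.copy_(torch.cat([ioux.weight, fx.weight], dim=0))
    fused_x.bias.copy_(torch.cat([ioux.bias, fx.bias], dim=0))

embeddings = torch.randn(n_tokens, in_dim)  # every token of one sentence

# One matmul for all tokens, done once before walking the tree.
x_iou, x_f = fused_x(embeddings).split([3 * mem_dim, mem_dim], dim=1)

# node_forward for token k would then use x_iou[k] and x_f[k]
# instead of calling ioux(inputs) and fx(inputs) on every node.
k = 7
print(torch.allclose(x_iou[k], ioux(embeddings[k]), atol=1e-5),
      torch.allclose(x_f[k], fx(embeddings[k]), atol=1e-5))
```

The child-dependent terms (iouh and fh) still have to wait for the children, which is where batching across nodes comes in.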

Any input is much appreciated! Thanks in advance!

The problem is likely the (lack of) batching.
James Bradbury wrote an article on how to do that to speed things up (though it targets PyTorch 0.2 or so).
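One way to batch a child-sum Tree-LSTM is to process all nodes of the same height together, so the number of sequential steps is the tree height plus one rather than the number of nodes. The sketch below is my own illustration of that idea, not code from the article; the class name LevelBatchedTreeLSTM and the parent-list tree encoding (root has parent -1) are hypothetical.

```python
import torch
import torch.nn as nn


class LevelBatchedTreeLSTM(nn.Module):
    """Hypothetical child-sum Tree-LSTM that runs every node of one height as a batch."""

    def __init__(self, in_dim=300, mem_dim=300):
        super().__init__()
        self.mem_dim = mem_dim
        self.ioux = nn.Linear(in_dim, 3 * mem_dim)
        self.iouh = nn.Linear(mem_dim, 3 * mem_dim)
        self.fx = nn.Linear(in_dim, mem_dim)
        self.fh = nn.Linear(mem_dim, mem_dim)

    def forward(self, inputs, parents):
        """inputs: [n, in_dim]; parents[v] is v's parent index, -1 for the root."""
        n = len(parents)
        children = [[] for _ in range(n)]
        for v, p in enumerate(parents):
            if p >= 0:
                children[p].append(v)

        # Breadth-first order from the root (the list grows while we iterate it).
        bfs = [parents.index(-1)]
        for v in bfs:
            bfs.extend(children[v])
        # height: 0 for leaves, 1 + max child height otherwise (deepest nodes first).
        height = [0] * n
        for v in reversed(bfs):
            if parents[v] >= 0:
                height[parents[v]] = max(height[parents[v]], height[v] + 1)

        # Input projections do not depend on the recurrence: one matmul for all nodes.
        x_iou, x_f = self.ioux(inputs), self.fx(inputs)
        parent_idx = torch.tensor([max(p, 0) for p in parents])  # root entry unused
        c = inputs.new_zeros(n, self.mem_dim)
        h = inputs.new_zeros(n, self.mem_dim)

        for level in range(max(height) + 1):
            nodes = torch.tensor([v for v in range(n) if height[v] == level])
            kids = torch.tensor([v for v in range(n)
                                 if parents[v] >= 0 and height[parents[v]] == level],
                                dtype=torch.long)
            h_sum = inputs.new_zeros(n, self.mem_dim)   # leaves keep a zero child sum
            fc_sum = inputs.new_zeros(n, self.mem_dim)
            if len(kids) > 0:
                kid_parent = parent_idx[kids]
                f = torch.sigmoid(self.fh(h[kids]) + x_f[kid_parent])  # gate per child
                h_sum = h_sum.index_add(0, kid_parent, h[kids])
                fc_sum = fc_sum.index_add(0, kid_parent, f * c[kids])
            iou = x_iou[nodes] + self.iouh(h_sum[nodes])
            i, o, u = iou.chunk(3, dim=1)
            c_new = torch.sigmoid(i) * torch.tanh(u) + fc_sum[nodes]
            h_new = torch.sigmoid(o) * torch.tanh(c_new)
            # Out-of-place updates so autograd sees no in-place writes.
            c = c.index_copy(0, nodes, c_new)
            h = h.index_copy(0, nodes, h_new)
        return c, h


model = LevelBatchedTreeLSTM()
parents = [-1, 0, 0, 1, 1]  # root 0 with children 1, 2; node 1 with children 3, 4
c, h = model(torch.randn(5, 300), parents)
print(h.shape)  # torch.Size([5, 300])
```

The same idea extends to the 50 lines of an iteration: concatenating all trees into one forest (offsetting node indices and using one -1 per root would need small changes to the root lookup) lets each level step cover every sentence at once.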

Best regards

Thomas
