Low GPU utilization for sequence task


(Klaus-Michael Lux) #1

Hi,

I implemented a neural model for text segmentation that feeds sequences of word embeddings into a bidirectional LSTM, extracts the forward and backward outputs, feeds each of them separately into a fully connected layer, and then computes the dot product between the two outputs at each timestep, in line with the ideas presented in this paper. The code is functional and I can overfit on training data, but GPU utilization is low (around 40% according to gpustat) and training is generally quite slow. Could this be due to the for-loops I use to compute the per-timestep dot products? If so, what would be a more efficient way to compute them? Any help would be greatly appreciated.

    def forward(self, input):
        # apply dropout for regularization
        input = self.dropout(input)

        bi_output, bi_hidden = self.lstm(input)

        # iterate over batch elements
        n_elem = bi_output.size()[0]
        outs = []

        for i in range(n_elem):
            forward_output, backward_output = bi_output[i, :, :self.hidden_size], bi_output[i, :, self.hidden_size:]
            # apply FC layer to both
            forward_hidden = self.fc(forward_output)
            backward_hidden = self.fc(backward_output)

            products = []
            # compute dot product of the FC outputs, comparing at every time step
            n_steps = forward_hidden.size()[0]
            for step in range(n_steps):
                dot_product = torch.dot(forward_hidden[step], backward_hidden[step])
                products.append(dot_product.unsqueeze(0))

            output = torch.cat(products)
            outs.append(output)

        # reshape into batchwise format
        total_output = torch.stack(outs)
        return total_output

(Klaus-Michael Lux) #2

Based on this thread, I found a way to eliminate the inner for loop using torch.bmm. Profiling indicates that this has removed a lot of work from the CPU (especially in the backward pass) and has resulted in a considerable speedup.
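
For anyone landing here later, this is a minimal sketch of what such a rewrite might look like. It is not the exact code I ended up with: the function names `stepwise_dot_bmm` and `stepwise_dot_sum` are made up for illustration, and it assumes the FC layer has already been applied to the whole batch at once, so that both inputs have shape (batch, steps, features). Since `nn.Linear` acts on the last dimension, that should also make the outer loop over batch elements unnecessary.

```python
import torch


def stepwise_dot_bmm(forward_hidden, backward_hidden):
    """Per-timestep dot products via a single torch.bmm call.

    Both inputs are assumed to have shape (batch, steps, features).
    Returns a tensor of shape (batch, steps).
    """
    batch, steps, feat = forward_hidden.shape
    # Treat every (batch element, timestep) pair as its own 1 x feat
    # times feat x 1 matrix product.
    left = forward_hidden.reshape(batch * steps, 1, feat)
    right = backward_hidden.reshape(batch * steps, feat, 1)
    return torch.bmm(left, right).view(batch, steps)


def stepwise_dot_sum(forward_hidden, backward_hidden):
    """Same result using an elementwise product and a sum over features."""
    return (forward_hidden * backward_hidden).sum(dim=-1)


if __name__ == "__main__":
    # Illustrative sizes only: 3 sequences, 5 timesteps, 4 features.
    f = torch.randn(3, 5, 4)
    b = torch.randn(3, 5, 4)
    print(stepwise_dot_bmm(f, b).shape)
    print(torch.allclose(stepwise_dot_bmm(f, b), stepwise_dot_sum(f, b), atol=1e-5))
```

In `forward`, this would replace both loops with something like `self.fc(bi_output[..., :self.hidden_size])` and `self.fc(bi_output[..., self.hidden_size:])` passed to either function. Which of the two variants is faster probably depends on the tensor sizes, so it is worth profiling both.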