How to use torch.cuda.streams?

Q1:
In PyTorch, if we don't use torch.cuda.streams explicitly, then PyTorch only uses one CUDA stream (the default stream), am I right?

Q2:
I want to use multiple CUDA streams so that different GPU tasks can run concurrently on the same GPU. I think this may improve the GPU's utilization rate.
Here is a CUDA copy task: "input_B.resize_(input_A.size()).copy_(input_A)". input_B is of type torch.cuda.FloatTensor, and input_A is of type torch.Tensor. How do I create a new CUDA stream and put this copy task on the new stream?

Thanks a lot. I only have one GPU in my computer.


torch.cuda.stream(your_stream) does what you want, I guess :wink:
And you can use multiprocessing to launch different processes with different streams. Note that multiprocessing with CUDA is only supported in Python 3.
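For the copy task in Q2, a minimal sketch could look like the following (untested; input_A and input_B here are just stand-ins for the tensors described in the question):

import torch

input_A = torch.randn(64, 128)             # CPU tensor (torch.Tensor)
input_B = torch.cuda.FloatTensor(64, 128)  # GPU tensor (torch.cuda.FloatTensor)

copy_stream = torch.cuda.Stream()          # create a new CUDA stream

with torch.cuda.stream(copy_stream):       # make it the current stream
    # the copy is queued on copy_stream instead of the default stream
    input_B.resize_(input_A.size()).copy_(input_A)

# make later work on the default stream wait for the copy to finish
torch.cuda.current_stream().wait_stream(copy_stream)

Note that, as far as I know, a host-to-device copy can only overlap with other work if the CPU tensor is in pinned memory (input_A.pin_memory()); otherwise the copy blocks the host anyway.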

Hi alexis-jacq. What is the argument we need to pass to torch.cuda.stream()? I am new to streams, so sorry for the stupid question. Thanks in advance too :slight_smile:

I am also trying to understand how to use streams. I believe that currently my model is getting much less out of the GPU than it could. (It is hard to tell where the bottlenecks are, but one tip-off is that nvidia-smi reports SM usage of only around 33%.)

A prior confusion I have about pytorch, before even getting to the topic of streams, is about when pytorch waits for kernels to finish running. I have tried line-profiling my code with the Python line profiler (https://github.com/rkern/line_profiler), and the per-line times roughly correspond to how long I would expect the corresponding computation to take on the GPU (though, as noted below, simple operations are not as much faster than complex ones as I might expect, and a further caveat is that the line profiler doesn't provide any kind of variance estimate). Seeing such numbers is good for finding bottlenecks, but it seems to imply that pytorch waits for each computation to finish after each line, not just when, say, I try to print out the result of a computation. And I would think that if pytorch just did an asynchronous kernel launch and immediately returned, it would be faster.
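For example, here is a rough sketch of the kind of check I have in mind, timing a large matmul with and without an explicit synchronization (if launches really are asynchronous, the first number should be tiny):

import time
import torch

x = torch.randn(4096, 4096).cuda()

t0 = time.time()
y = torch.mm(x, x)        # launch the kernel
t1 = time.time()

torch.cuda.synchronize()  # block until the GPU has actually finished
t2 = time.time()

print('launch only: %.6f s, launch + sync: %.6f s' % (t1 - t0, t2 - t0))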

If I am right that pytorch waits, that explains why my naive attempt to use streams below fails to improve performance.

What I tried doing was making a simple class to run code in parallel on different streams, like so:

class StreamSpreader():
    def __init__(self):
        self.streams = []

    def __call__(self, *tasks):
        # without CUDA there are no streams, so just run the tasks serially
        if not torch.cuda.is_available():
            return [t() for t in tasks]
        # lazily create one stream per task and reuse them across calls
        while len(self.streams) < len(tasks):
            self.streams.append(torch.cuda.Stream())
        ret = []
        for s, t in zip(self.streams, tasks):
            # run each task with its own stream as the current stream
            with torch.cuda.stream(s):
                ret.append(t())
        return ret

One question I have is: if this implementation doesn’t work because pytorch will wait for each operation to complete before launching the next, is it at least possible to make a working StreamSpreader class with the same API?
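For reference, here is one variant I have been sketching (not validated, and it assumes torch.cuda.current_stream() and Stream.wait_stream() do what I think they do): each side stream waits for work already queued on the caller's stream, and the caller's stream then waits for the side streams before the results are used.

import torch

class SyncedStreamSpreader():
    def __init__(self):
        self.streams = []

    def __call__(self, *tasks):
        if not torch.cuda.is_available():
            return [t() for t in tasks]
        while len(self.streams) < len(tasks):
            self.streams.append(torch.cuda.Stream())
        current = torch.cuda.current_stream()
        ret = []
        for s, t in zip(self.streams, tasks):
            # each side stream waits for work already queued by the caller
            s.wait_stream(current)
            with torch.cuda.stream(s):
                ret.append(t())
        # the caller's stream waits for every side stream before results are used
        for s in self.streams[:len(tasks)]:
            current.wait_stream(s)
        return ret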

I tried three implementations of an LSTM. In the first, I do separate matrix multiplies of the hidden state and the input and add the results; this is what pytorch does when not using cuDNN, I believe. (By the way, if you are wondering why I don't just use the built-in LSTM, it is because I actually want to use a somewhat different architecture that is not supported as a built-in.) In the second implementation, I use streams via my StreamSpreader class. In the third, I concatenate the hidden state and the input and do one matrix multiply. I found that the last approach improved performance significantly but that the stream approach actually decreased performance slightly.

if mode == 'baseline':
    # two separate matrix multiplies, then an elementwise add
    wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
    wi = torch.mm(input, self.weight_ih)
    preactivations = wh_b + wi
elif mode == 'multistream':
    # the same two matrix multiplies, launched on separate streams
    wh_b, wi = self.spreader(
        lambda: torch.addmm(bias_batch, h_0, self.weight_hh),
        lambda: torch.mm(input, self.weight_ih),
    )
    preactivations = wh_b + wi
elif mode == 'fused':
    # concatenate the inputs so a single matrix multiply suffices
    combined_inputs = torch.cat((h_0, input), 1)
    preactivations = torch.addmm(bias_batch, combined_inputs, self.combined_weights)

I was a bit surprised at the line-by-line timings though, which makes me wonder whether I have indeed misunderstood the pytorch execution model and it actually does something more asynchronous. In the baseline version, the two matrix multiplies together took an average of 208 microseconds (121 + 87), and adding them to compute the preactivations took an average of 58 microseconds; it surprises me that the simple add operation takes so long. In the fused version, concatenating the hidden state and the input takes an average of 65 microseconds, but the single matrix multiply takes only 117 microseconds. In the stream version, adding the preactivations took an average of 53 microseconds, but doing the matrix multiplies took 262 microseconds.
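To get timings that don't depend on where Python happens to block, I am considering switching to CUDA events, roughly like this (a sketch using the variable names from the baseline snippet above; I haven't wired it into the model yet):

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
wi = torch.mm(input, self.weight_ih)
end.record()

torch.cuda.synchronize()  # make sure both events have completed
print('GPU time for both matmuls: %.3f ms' % start.elapsed_time(end))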

This is probably obvious, but even if streams don’t make sense for the particular case of implementing an LSTM block, I am still really interested in learning to use them effectively – I just had to choose something as a test case, and this is what I chose.

I didn't check whether there were significant differences in backpropagation speed, by the way, but in general the model spends about half its time backpropagating. Depending on the answer to my "prior" confusion above, a further confusion I may have is how asynchronous backpropagation is and whether using multiple streams could be a good way to increase GPU utilization there.

By the way, I have also tried a little using the NVIDIA profiler. An issue I encountered is that my model segfaults after about 30 seconds when run under the profiler. Is there a known bug around this? I can get useful results by quitting the training before the segfault happens, but I am currently still having trouble understanding it due to unfamiliarity with the tool.
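As a workaround I have been trying to profile only a short window of training, roughly like this (assuming nvprof is launched with --profile-from-start off; train_one_step is just a placeholder for my training iteration, and I am not sure this actually avoids the segfault):

import torch.cuda.profiler as profiler

# warm up first, then profile a bounded number of iterations
for _ in range(10):
    train_one_step()

profiler.start()
for _ in range(20):
    train_one_step()
profiler.stop()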
