CUDA Streams & Gradient Manipulation

It appears PyTorch 1.10 has slightly changed the semantics of loss.backward() in conjunction with CUDA streams. The documentation is helpful but doesn’t precisely cover my use case, so I’d like to ensure I’m not introducing a subtle error. At a high level, I’d like to run something as follows:

class MyModule(nn.Module):
    def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
        with torch.cuda.stream(s1):
            out1 = self._module1(input1)
        with torch.cuda.stream(s2):
            out2 = self._module2(input2)
        torch.cuda.current_stream().wait_stream(s1)
        torch.cuda.current_stream().wait_stream(s2)
        return torch.cat([out1, out2], dim=1)

x1 = torch.rand(...)
x2 = torch.rand(...)
module = MyModule()
loss = module(x1, x2)
loss.backward()
clip_grad_norm_(...)

Given that the backward pass through self._module1 and self._module2 will occur in their respective streams, is a synchronization primitive required before clipping the gradients? Would the answer change if one (but not both) of self._module1 or self._module2 were executed in the default stream?

Unrelated to your question, there's an issue with your example: the side streams need to wait on the ambient stream that produced the input data before the side-stream work runs, in addition to the ambient stream waiting on the side streams afterwards. record_stream calls on tensors created on one stream and used on another also wouldn't hurt.

class MyModule(nn.Module):
    def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
        # Side streams must first wait for the ambient-stream work that produced the inputs.
        s1.wait_stream(torch.cuda.current_stream())
        s2.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s1):
            out1 = self._module1(input1)
        with torch.cuda.stream(s2):
            out2 = self._module2(input2)
        torch.cuda.current_stream().wait_stream(s1)
        torch.cuda.current_stream().wait_stream(s2)

        # Good practice: tells the caching allocator that memory created on one
        # stream is used on another, so it isn't reused prematurely.
        input1.record_stream(s1)
        input2.record_stream(s2)
        out1.record_stream(torch.cuda.current_stream())
        out2.record_stream(torch.cuda.current_stream())

        return torch.cat([out1, out2], dim=1)

x1 = torch.rand(...)
x2 = torch.rand(...)
module = MyModule()
loss = module(x1, x2)
loss.backward()
clip_grad_norm_(...)
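For reference, here is a minimal, self-contained version of the corrected pattern. The stream and submodule definitions, the tensor shapes, the .sum() reduction to obtain a scalar loss, and the max_norm value are illustrative assumptions, not part of the original post.

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

# Illustrative assumptions: concrete side streams, submodules, and shapes.
s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()


class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._module1 = nn.Linear(16, 8)
        self._module2 = nn.Linear(16, 8)

    def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
        # Side streams wait for the ambient-stream work that produced the inputs.
        s1.wait_stream(torch.cuda.current_stream())
        s2.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s1):
            out1 = self._module1(input1)
        with torch.cuda.stream(s2):
            out2 = self._module2(input2)
        # Ambient stream waits for the side-stream work before using out1/out2.
        torch.cuda.current_stream().wait_stream(s1)
        torch.cuda.current_stream().wait_stream(s2)

        # Caching-allocator safety for tensors that cross streams.
        input1.record_stream(s1)
        input2.record_stream(s2)
        out1.record_stream(torch.cuda.current_stream())
        out2.record_stream(torch.cuda.current_stream())

        return torch.cat([out1, out2], dim=1)


x1 = torch.rand(4, 16, device="cuda")
x2 = torch.rand(4, 16, device="cuda")
module = MyModule().cuda()
loss = module(x1, x2).sum()  # reduce to a scalar so backward() needs no explicit grad
loss.backward()
clip_grad_norm_(module.parameters(), max_norm=1.0)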

The backward call itself is fine as-is. The essence of backward()'s internal stream handling is:

The stream semantics of a backward call with respect to surrounding ops are the same as for any other call.

You have

loss = module(x1, x2)
loss.backward()
clip_grad_norm_(...)

in which the loss computation and clip_grad_norm_ take place on the same ambient stream as the backward() call, so backward()'s consumption of loss and production of gradients interact with them safely. Internally, backward() will still run the backward ops of self._module1 and self._module2 on their side streams, but the autograd engine inserts the necessary syncs, so backward()'s external interactions with anything on the same ambient stream are safe. In other words, no extra synchronization is needed before clipping the gradients here.
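To make the quoted rule concrete, here is a minimal sketch of the contrasting case (the stand-in model, shapes, and stream are illustrative assumptions, not from this thread): if the gradients were consumed on a different stream than the ambient stream backward() ran on, the usual stream rules apply and an explicit wait is needed.

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

# Hypothetical contrast: consuming gradients on a *different* stream than the
# ambient stream backward() ran on follows ordinary stream rules, so an explicit
# wait is required before touching .grad there.
model = nn.Linear(16, 8).cuda()          # illustrative stand-in model
x = torch.rand(4, 16, device="cuda")
side = torch.cuda.Stream()

loss = model(x).sum()
loss.backward()                          # grads are safe to use on the ambient stream

side.wait_stream(torch.cuda.current_stream())   # wait for grads to be populated
with torch.cuda.stream(side):
    clip_grad_norm_(model.parameters(), max_norm=1.0)

# Later ambient-stream work that reads the clipped grads (e.g. optimizer.step())
# would in turn need: torch.cuda.current_stream().wait_stream(side)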


Thanks for the clarification.