Determinism of gradient accumulation

Am I correct in assuming that adding CPU operations like the ones above makes gradient accumulation (for the tensor shared_ctx) non-deterministic?

I.e. latent1backward and latent2backward are serialized, since a single CUDA stream is used, but the CPU backward functions (and/or memory transfers) execute concurrently with them, so the summands for shared_ctx’s gradient can arrive in different orders.
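For concreteness, a minimal sketch of the kind of graph being described (the latent/shared_ctx names and operations here are illustrative assumptions, not the original code, and a CUDA device is assumed to be available):

import torch

# shared_ctx lives on the CPU and feeds branches on two devices.
shared_ctx = torch.rand(4, requires_grad=True)

# Two branches on the GPU (their backward functions run serialized on a single CUDA stream)...
latent1 = (shared_ctx.to("cuda") * 2).sum()
latent2 = (shared_ctx.to("cuda") * 3).sum()

# ...and two branches that stay on the CPU.
latent3 = (shared_ctx * 4).sum()
latent4 = (shared_ctx * 5).sum()

loss = latent1.cpu() + latent2.cpu() + latent3 + latent4
loss.backward()

# shared_ctx.grad is the sum of four contributions; the order in which the
# CPU and GPU contributions are added into it is not fixed.
print(shared_ctx.grad)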

Hi,

Yes, you are correct that when using multiple devices, the accumulation is not forced to run in a specific order.

A “simple” solution to fix that is to use a custom function that does the copy to the different devices in the forward and the accumulation in the backward (in a fixed order in your custom backward). Would that work for you?

Thanks, at least this explains the divergences I’m observing.

I’ll have to think about whether it is worth mitigating this with custom functions… One has to know all the “client” nodes for that…

Good to be aware of this in general anyway.

Or will it be enough to produce two tensors in an autograd.Function: shared_ctx_cpu and shared_ctx_gpu?
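A minimal sketch of that two-output idea, assuming the input lives on the CPU, both outputs end up being used in the graph, and a CUDA device is available (names are illustrative):

import torch
from torch.autograd import Function

class SplitCpuGpu(Function):
    @staticmethod
    def forward(ctx, tensor):
        # Remember the input's device so the gradient can be returned there.
        ctx.inp_device = tensor.device
        return tensor.to("cpu"), tensor.to("cuda")

    @staticmethod
    def backward(ctx, grad_cpu, grad_gpu):
        # Fixed accumulation order: CPU gradient first, then the GPU one.
        return grad_cpu.to(ctx.inp_device) + grad_gpu.to(ctx.inp_device)

shared_ctx = torch.rand(2, requires_grad=True)
shared_ctx_cpu, shared_ctx_gpu = SplitCpuGpu.apply(shared_ctx)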

Is the accumulation order of latent3backward and latent4backward deterministic?

You can write a very generic custom function to do that:

import torch
from torch.autograd import Function

class CustomScatter(Function):
    @staticmethod
    def forward(ctx, tensor, devices):
        # Remember where the input lives so the gradient can be sent back there.
        ctx.inp_device = tensor.device
        # One copy of the input per requested device.
        outs = tuple(tensor.to(d) for d in devices)
        return outs

    @staticmethod
    def backward(ctx, *grads):
        # Accumulate the incoming gradients in a fixed order (the order of
        # the outputs), so the result is deterministic.
        res = 0.
        for g in grads:
            if g is not None:
                res += g.to(ctx.inp_device)
        # None for the non-tensor `devices` argument.
        return res, None

torch.manual_seed(123)
a = torch.rand(2, requires_grad=True)
devices = ["cpu", "cpu", "cuda", "cuda"]

all_t = CustomScatter.apply(a, devices)

loss = 0.
for t in all_t:
    loss += a * t.to("cpu")
loss.sum().backward()
print(a.grad)
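Because the for-loop in backward adds the gradients into res in a fixed order (the order of the outputs), the accumulation no longer depends on which device delivers its gradient first.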

Is the accumulation order of latent3backward and latent4backward deterministic?

As long as they are “ready to run” (meaning that all their inputs are ready) at the same time and they ran in the same order during the forward pass, yes.
