Using utils.checkpoint with DDP reports an error

I use utils.checkpoint together with DistributedDataParallel to save CUDA memory.

When I use a single GPU to run my code, it works.

But if I run my code on multiple GPUs, it reports the following error:

RuntimeError: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by
making sure all forward function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn’t able to locate the output tensors in the return value of your module’s forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 2: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 …
In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
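
For reference, a minimal sketch of how the two suggestions in that message can be applied (nn.Linear and the single-process gloo group are stand-ins just so the snippet runs on its own):

import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Must be set before the process group / DDP module is created; with INFO or DETAIL
# the error above also prints the names of the parameters that received no gradient.
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(10, 10)  # stand-in for the real model
ddp_model = DDP(model, find_unused_parameters=True)

dist.destroy_process_group()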

But I checked my code on a single GPU, and all parameters get gradients there. I do not know what causes this error; from the message it looks as if none of my parameters received a gradient.

If I use use_reentrant=True, I get the same error but with different parameter indices:

Parameter indices which did not receive grad for rank 1: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 346 347 348 349 350 351 352 353 354 355 356 357 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385

This really confuses me.

BTW, I really don’t understand what reentrant means. If part of my model’s outputs will be the inputs for the next frame, should I pass True or False for the use_reentrant argument?

Is it possible to provide a minimal script to reproduce your error?

If your model’s computational graph does not change from iteration to iteration, can you try passing static_graph=True to the DistributedDataParallel constructor?
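
For example, something like this (a minimal sketch; the single-process gloo group and nn.Linear are stand-ins for your real setup):

import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# static_graph=True tells DDP that the training graph (including which parameters
# are used) does not change from iteration to iteration.
ddp_model = DDP(nn.Linear(10, 10), static_graph=True)

dist.destroy_process_group()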

My project is quite large, so I can’t find a way to provide a small script. I will try, but it won’t be easy.

My model is similar to megvii-research/MOTR: [ECCV2022] MOTR: End-to-End Multiple-Object Tracking with TRansformer (github.com). In their project, they do not use the official checkpoint but csrhddlam/pytorch-checkpoint (github.com) instead; I don’t know why. But I want to use the official checkpoint to keep my code more precise and maintainable.

And I can’t use static_graph=True because I have an if condition in my code: for the first frame some parameters are not used, but in the second frame all parameters are used. Each iteration contains at least two frames.
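
To illustrate the kind of branch I mean, here is a simplified sketch (the module names are made up, not my real model):

import torch
import torch.nn as nn

class FrameModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.det_query = nn.Parameter(torch.randn(5, 10))
        self.track_proj = nn.Linear(10, 10)  # only reached from the second frame on

    def forward(self, prev_tracks):
        if prev_tracks is None:
            # First frame: self.track_proj is skipped, so its parameters get no grad.
            query = self.det_query
        else:
            # Later frames: every parameter participates in the forward pass.
            query = torch.cat((self.det_query, self.track_proj(prev_tracks)), dim=0)
        return query

model = FrameModel()
first = model(None)    # uses only det_query
second = model(first)  # uses det_query and track_proj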

BTW, as mentioned above, I really don’t understand what reentrant means. If part of my model’s outputs will be the inputs for the next frame, should I pass True or False for the use_reentrant argument?

Thanks for your reply~

Let me check if we have activation checkpointing support for non-static graph models when using DDP and get back to you.

Regarding reentrant, there are two versions of activation checkpointing implemented in PyTorch today: one is so-called “reentrant” and the other is “non-reentrant”. At a high level, the non-reentrant one (checkpoint(use_reentrant=False)) is supposed to add support for additional use cases that the reentrant cannot handle, but there are some gaps in both versions at the moment. Your case of having a model’s outputs be inputs for the next frame should not affect which version to use.
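
For concreteness, both versions are selected through the same checkpoint call and differ only in the flag (toy module purely for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
x = torch.randn(2, 10, requires_grad=True)

# Reentrant version: re-runs the forward pass during backward inside a custom
# autograd Function; it does not support keyword arguments and needs at least one
# input that requires grad.
out_reentrant = checkpoint(block, x, use_reentrant=True)

# Non-reentrant version: recomputes via saved-tensor hooks and covers more cases.
out_non_reentrant = checkpoint(block, x, use_reentrant=False)

(out_reentrant.sum() + out_non_reentrant.sum()).backward()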

I truly appreciate your reply~

That clears things up for me. For a long, long time, I misunderstood reentrant and non-reentrant as referring to two different model types (like LSTM and CNN). :face_with_head_bandage:

I checked and non-reentrant checkpoint should support DDP and non-static graph models. Could you try to pass use_reentrant=False and see if that works?

I’ve tried this before; I get the same error mentioned above:

RuntimeError: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel , and by
making sure all forward function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn’t able to locate the output tensors in the return value of your module’s forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 2: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 …
In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error

The parameter indices are different from the use_reentrant=True case. But without DDP, on a single GPU this code works fine, which I believe means these parameters do receive gradients there.

I found a not-so-elegant workaround. My model looks like the following pseudo-code:

class Model(nn.Module):
    def forward(self, x1, x2):
        feature = self.backbone(x1)
        # [some other code]
        x2_tensor = modified_x2_to_tensor(x2)
        res = self.transformer(feature, x2_tensor)
        # [some other code]
        return res

Here x1 is a NestedTensor (the same as in Deformable DETR) and x2 is a list of custom objects. x2 is empty for the first frame; for the second frame it is built from part of the first frame’s res (the same model is applied to each frame within an iteration).

If I use checkpoint like this:

model = Model()
......
x2 = init_x2()
for x1 in frames:
    res = torch.utils.checkpoint.checkpoint(model, x1, x2, use_reentrant=False)
    x2 = generate_x2(res)
.......

It will raise the above error.

But if I use it in the forward function like this:

class Model(nn.Module):
    def forward(self, x1, x2):
        feature = torch.utils.checkpoint.checkpoint(self.backbone, x1, use_reentrant=False)
        # feature = self.backbone(x1)
        # [some other code]
        x2_tensor = modified_x2_to_tensor(x2)
        res = torch.utils.checkpoint.checkpoint(self.transformer, feature, x2_tensor, use_reentrant=False)
        # res = self.transformer(feature, x2_tensor)
        # [some other code]
        return res

model = Model()
......
x2 = init_x2()
for x1 in frames:
    res = model(x1, x2)
    x2 = generate_x2(res)
.......

It works! But I don’t know why. I’m not sure whether the above code helps you understand the details of my case.

I don’t know if the above pseudo-code confuses you; if so, I’m truly sorry that I can’t find a better way to show my code logic. :disappointed:

BTW, I will still try to make a demo to reproduce my case, but it may take some time because my code is complex and I have a lot of deadlines lately. :exploding_head:

Thank you very much for your patient replies.

@HELLORPG the reason this change works is that you are now checkpointing parts of the module wrapped by DDP (i.e. the local module) instead of the entire DDP-wrapped module, which changes the forward pass that is re-run.
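
To make the contrast concrete, here is a minimal single-process sketch (gloo backend and toy sub-modules stand in for your real multi-GPU setup):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(10, 10)
        self.head = nn.Linear(10, 10)

    def forward(self, x):
        # Checkpointing the sub-modules *inside* the local module's forward:
        # DDP's own forward logic is not part of what gets re-run during backward.
        h = checkpoint(self.backbone, x, use_reentrant=False)
        return checkpoint(self.head, h, use_reentrant=False)

ddp_model = DDP(Inner())
x = torch.randn(2, 10)

# The failing pattern was the other one: wrapping the *entire* DDP module, i.e.
# checkpoint(ddp_model, x, use_reentrant=False), so the DDP forward itself is re-run.
# Calling the DDP wrapper normally and checkpointing only inside it works:
ddp_model(x).sum().backward()

dist.destroy_process_group()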

Yes, I think so. But this change does not alter the code’s logic; it just splits the model into two checkpointed modules. I don’t understand the difference between these two implementations. :rofl:

Hi, I did make a demo. It does not raise the same error, but the error it raises is also one I have met.

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.checkpoint import checkpoint


class Tracks:
    # Carries track embeddings from one frame to the next.
    def __init__(self):
        self.embed = torch.zeros((1, 0, 10), dtype=torch.float)


def select_tracks(ts: Tracks):
    # Keep only a suffix of the tracks (a stand-in for track filtering).
    ts_len = len(ts.embed)
    idx = int(ts_len * 0.5)
    ts.embed = ts.embed[idx:]
    return ts


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, padding=1)
        self.query = nn.Parameter(torch.randn(5, 10))
        self.attn = nn.MultiheadAttention(embed_dim=10, num_heads=1, batch_first=True)

    def forward(self, f, t: Tracks):
        feature = self.conv(f).flatten(2).transpose(1, 2)
        # Concatenate the track embeddings carried over from the previous frame
        # with the learned detection queries.
        query = torch.cat((t.embed, self.query.repeat(1, 1, 1)), dim=1)
        r, _ = self.attn(query, feature, feature)
        return r


if __name__ == '__main__':
    image1 = torch.randn((1, 3, 10, 10))
    image2 = torch.randn((1, 3, 10, 10))
    images = [image1, image2]
    tracks = Tracks()
    model = Model()
    criterion = nn.L1Loss()
    optimizer = SGD(model.parameters(), lr=0.00001)
    loss = torch.zeros((1,), dtype=float)
    for frame in images:
        # The error reported below is raised when the model call is wrapped in
        # checkpoint, as it is here:
        # res = model(frame, tracks)
        res = checkpoint(model, frame, tracks, use_reentrant=False)
        gt = torch.zeros(res.shape, dtype=torch.float)
        loss += criterion(res, gt)
        # Feed this frame's outputs back as the next frame's track embeddings.
        tracks.embed = res
        tracks = select_tracks(tracks)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print("Finish")

When I run this code, it will raise an error like this:

Traceback (most recent call last):
  File "/home/gaoruopeng/Code/MeMOTR/demo.py", line 59, in <module>
    loss.backward()
  File "/home/gaoruopeng/anaconda3/envs/MeMOTR/lib/python3.10/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/gaoruopeng/anaconda3/envs/MeMOTR/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x8 and 9x10)

It’s not the same error we talked about before, but if I fix it, I think my code will work fine, because the error we discussed above happened when I tried to bypass this one.