The PyTorch distributed Join context manager does not work properly when the model is run in evaluation mode

A minimal reproduction is below:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import nn
import datetime

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5
def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE,
                            timeout=datetime.timedelta(seconds=30))
    torch.cuda.set_device(rank)
    model = torch.nn.Linear(1000, 1000)
    model.cuda()
    model = DDP(model)
    loss_fn = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # Rank 1 gets 100 more training inputs than rank 0 (and 1000 more eval inputs below)
    num_inputs = 0
    for epoch in range(100):
        inputs = [torch.rand(1000, 1000).float().cuda() for _ in range(NUM_INPUTS + rank * int(1e2))]
        model.train()
        optimizer.zero_grad(set_to_none=True)
        ddp_loss = torch.zeros(2)
        with Join([model]):
            for input in inputs:
                num_inputs += 1
                model_output = model(input)
                loss = loss_fn(model_output, input)
                loss.backward()
                optimizer.step()
                ddp_loss[0] += loss.detach().item()
                ddp_loss[1] += 1
        print(f"EPOCH: {epoch} RANK {rank} train Loss: {ddp_loss[0] / ddp_loss[1]}")
        print(f"Rank {rank} has exhausted all {num_inputs} on epoch {epoch} of its inputs!")

        dist.barrier()
        model.eval()
        inputs = [torch.rand(1000, 1000).float().cuda() for _ in range(NUM_INPUTS + rank * int(1e3))]
        ddp_loss = torch.zeros(2)
        with torch.no_grad():
            # commenting out the following Join([model]) still reproduces the issue
            with Join([model]):
                for input in inputs:
                    num_inputs += 1
                    model_output = model(input)
                    loss = loss_fn(model_output, input)
                    ddp_loss[0] += loss.detach().item()
                    ddp_loss[1] += 1
        print(f"EPOCH: {epoch} RANK {rank} eval Loss: {ddp_loss[0] / ddp_loss[1]}")
        print(f"Rank {rank} has exhausted all {num_inputs} on epoch {epoch} of its inputs!")

    dist.destroy_process_group()

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES']="0,2"
    main()
Running the script hangs and eventually fails with an NCCL collective timeout. The watchdog output:

[rank0]:[E907 22:08:38.496678785 ProcessGroupNCCL.cpp:607] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=337, OpType=ALLREDUCE, NumelIn=1001000, NumelOut=1001000, Timeout(ms)=60000) ran for 60062 milliseconds before timing out.
[rank0]:[E907 22:08:38.497269909 ProcessGroupNCCL.cpp:1664] [PG 0 (default_pg) Rank 0] Exception (either an error or timeout) detected by watchdog at work: 337, last enqueued NCCL work: 338, last completed NCCL work: 336.
[rank0]:[E907 22:08:38.497285879 ProcessGroupNCCL.cpp:1709] [PG 0 (default_pg) Rank 0] Timeout at NCCL work: 337, last enqueued NCCL work: 338, last completed NCCL work: 336.
[rank0]:[E907 22:08:38.497293841 ProcessGroupNCCL.cpp:621] [Rank 0] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank0]:[E907 22:08:38.497301295 ProcessGroupNCCL.cpp:627] [Rank 0] To avoid data inconsistency, we are taking the entire process down.
[rank0]:[E907 22:08:38.498984868 ProcessGroupNCCL.cpp:1515] [PG 0 (default_pg) Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=337, OpType=ALLREDUCE, NumelIn=1001000, NumelOut=1001000, Timeout(ms)=60000) ran for 60062 milliseconds before timing out.
Exception raised from checkTimeout at /opt/conda/conda-bld/pytorch_1720538439675/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:609 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f53bea16f86 in /home/chenguang.wan/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7f53bfcfa0b2 in /home/chenguang.wan/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7f53bfd00af3 in /home/chenguang.wan/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f53bfd02edc in /home/chenguang.wan/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xdbbf4 (0x7f540f817bf4 in /home/chenguang.wan/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #5: <unknown function> + 0x81ca (0x7f5425a7b1ca in /lib64/libpthread.so.0)
frame #6: clone + 0x43 (0x7f5424f5de73 in /lib64/libc.so.6)

[rank1]:[E907 22:08:38.522399616 ProcessGroupNCCL.cpp:607] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=337, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=60000) ran for 60089 milliseconds before timing out.
[rank1]:[E907 22:08:38.522702696 ProcessGroupNCCL.cpp:1664] [PG 0 (default_pg) Rank 1] Exception (either an error or timeout) detected by watchdog at work: 337, last enqueued NCCL work: 1358, last completed NCCL work: 336.
[rank1]:[E907 22:08:38.522714916 ProcessGroupNCCL.cpp:1709] [PG 0 (default_pg) Rank 1] Timeout at NCCL work: 337, last enqueued NCCL work: 1358, last completed NCCL work: 336.
[rank1]:[E907 22:08:38.522719997 ProcessGroupNCCL.cpp:621] [Rank 1] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank1]:[E907 22:08:38.522726158 ProcessGroupNCCL.cpp:627] [Rank 1] To avoid data inconsistency, we are taking the entire process down.
[rank1]:[E907 22:08:38.524018517 ProcessGroupNCCL.cpp:1515] [PG 0 (default_pg) Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=337, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=60000) ran for 60089 milliseconds before timing out.
Exception raised from checkTimeout at /opt/conda/conda-bld/pytorch_1720538439675/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:609 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fb0adab5f86 in /home/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7fb0aed990b2 in /home/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7fb0aed9faf3 in /home/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7fb0aeda1edc in /home/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xdbbf4 (0x7fb0fe8b6bf4 in /home/anaconda3/envs/torch/lib/python3.12/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #5: <unknown function> + 0x81ca (0x7fb114b1a1ca in /lib64/libpthread.so.0)
frame #6: clone + 0x43 (0x7fb113ffce73 in /lib64/libc.so.6)

W0907 22:08:39.935000 140180629394432 torch/multiprocessing/spawn.py:146] Terminating process 2551623 via signal SIGTERM
Traceback (most recent call last):
  File "/home/Papers/Test/async_ckpt/uneven_rank.py", line 65, in <module>
    main()
  File "/home//Papers/Test/async_ckpt/uneven_rank.py", line 61, in main
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)
  File "/home//anaconda3/envs/torch/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 282, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home//anaconda3/envs/torch/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 238, in start_processes
    while not context.join():
              ^^^^^^^^^^^^^^
  File "/home/anaconda3/envs/torch/lib/python3.12/site-packages/torch/multiprocessing/spawn.py", line 170, in join
    raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGABRT
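
For reference, a possible interim workaround is sketched below. It is only a sketch, under the assumption that the hang is caused by the ranks running different numbers of evaluation iterations: drop Join from the eval loop and instead truncate every rank's eval inputs to the global minimum count, so all ranks execute the same number of forward passes. even_eval_inputs is a hypothetical helper, not part of the repro above.

def even_eval_inputs(inputs, device):
    # Hypothetical helper: truncate this rank's eval inputs to the minimum
    # count across all ranks so every rank runs the same number of forward passes.
    count = torch.tensor([len(inputs)], dtype=torch.int64, device=device)
    dist.all_reduce(count, op=dist.ReduceOp.MIN)
    return inputs[: int(count.item())]

Because the MIN all-reduce goes through the NCCL process group, the count tensor has to live on the rank's CUDA device (hence the device argument). With the counts evened out, the evaluation loop can run under torch.no_grad() without any Join context and no rank finishes its inputs before the others, but the original question about Join in eval mode still stands.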