Forward stuck in DistributedDataParallel training

Hello. I am training a DDP model on one machine with two GPUs.
The DDP model hangs in forward on gpu:1 at the second iteration. I debugged it, and it turned out to be the self.reducer._rebuild_buckets() call in DistributedDataParallel.forward (torch/nn/parallel/distributed.py). Is there anybody who can help me? Thanks.
Torch 1.8+cu11 and 1.9+cu11 were both checked; neither works.
OS: Ubuntu 18.04, kernel 5.4.0-81-generic (#91~18.04.1)
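In case it helps anyone reproduce this: the distributed debug flags can be turned on before init_process_group to get more detail about where a rank is stuck. This is just a suggestion, not something from my original run; TORCH_DISTRIBUTED_DEBUG needs torch >= 1.9.

import os

os.environ['NCCL_DEBUG'] = 'INFO'                 # NCCL communicator / collective logging
os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'  # extra DDP consistency checks (torch 1.9+)
# ...then call dist.init_process_group("nccl", ...) as in setup() below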

Below are my code and the printed log.

def run_train(train_fn, world_size, args):
    mp.spawn(train_fn,
             args=(args, ),
             nprocs=world_size,
             join=True)


def train_init(rank, train_info: TrainInfo):
    setup(rank, train_info.world_size)

    if rank == 0:
        tracking_server = "http://192.168.35.10:5010"
        mlflow.set_tracking_uri(tracking_server)
        mlflow.set_experiment('Drawing')
        now_time = datetime.now()

    output_path = train_info.output_path
    logger, _ = getLogger(f'drawing_{rank}th', out_dir=output_path)
    logger.info(f'pre-settings Initialized.. Rank: {rank}')
    logger.info(f'model loading.. Rank: {rank}')
    model = Drawing(train_info.backbone).to(rank)
    if train_info.model_weight_path is not None:
        model.load_state_dict(torch.load(train_info.model_weight_path))
    ddp_model = DDP(model, device_ids=[rank])

    num_workers = 0
    logger.info(f'data loading.. with {num_workers} workers Rank: {rank}')
    train_dataset = VerificationDataset(os.path.join(train_info.root_path, 'train'),
                                   input_size=train_info.input_size)
    val_dataset = VerificationDataset(os.path.join(train_info.root_path, 'val'), is_validation=True,
                                 input_size=train_info.input_size)
    item_dataset = ItemDataset(os.path.join(train_info.root_path, 'topN'), input_size=train_info.input_size)
    drawing_dataset = DrawingDataset(os.path.join(train_info.root_path, 'topN'), input_size=train_info.input_size)

    train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=rank)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    item_sampler = DistributedSampler(item_dataset, shuffle=False)
    distributed_samplers = [train_sampler, val_sampler, item_sampler]

    train_dataloader = DataLoader(train_dataset, batch_size=train_info.batch_size, num_workers=num_workers,
                                  pin_memory=True, collate_fn=verification_collate_fn, sampler=train_sampler)
    val_dataloader = DataLoader(val_dataset, batch_size=train_info.batch_size, num_workers=num_workers,
                                pin_memory=True, collate_fn=verification_collate_fn, sampler=val_sampler)
    item_dataloader = DataLoader(item_dataset, batch_size=train_info.batch_size, num_workers=num_workers,
                                 pin_memory=True, sampler=item_sampler)
    drawing_dataloader = DataLoader(drawing_dataset, batch_size=train_info.batch_size, num_workers=num_workers,
                                    pin_memory=True)

    learning_rate = 0.001
    # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [100, 200], gamma=0.1)
    # lr_scheduler = None
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # weight = torch.tensor([0.1, 0.9]).to(device)
    criterion = torch.nn.CrossEntropyLoss()

    params = {'run_name': train_info.run_name, 'backbone': train_info.backbone, 'data_root_path': train_info.root_path,
              'num_workers': num_workers, 'model_weight_path': train_info.model_weight_path,
              "init_lr": learning_rate, 'lr_scheduler': lr_scheduler, 'weight': 'no-weight',
              'device': rank, 'world_size': train_info.world_size, 'number_of_drawings': len(set(drawing_dataset.all_drawing_reg_nums)),
              'optim': f'{optimizer.__class__}', 'train_data_count': len(train_dataset),
              'val_data_count': len(val_dataset), 'vectors_batch_size': train_info.vectors_batch_size,
              'loss_func': f'{criterion.__class__}', 'input_size': train_info.input_size,
              'batch_size': train_info.batch_size,
              'epochs': train_info.epochs}
    logger.info(f'params >> {params}')
    result_path = os.path.join('results', output_path, f'{rank}th')
    if rank == 0:
        with mlflow.start_run(run_name=train_info.run_name):
            mlflow.log_params(params)
            e, v_acc, t3_acc, t1_acc = train(ddp_model, train_dataloader, val_dataloader, item_dataloader, drawing_dataloader, distributed_samplers, optimizer, criterion,
                  train_info.epochs,
                  rank, logger, vectors_batch_size=train_info.vectors_batch_size,
                  lr_scheduler=lr_scheduler, save_path=result_path, print_iter=train_info.print_iter)
            slack.postMessage(
                f"Start-time:{now_time} \t End-time:{datetime.now()} \t Elapsed time:{datetime.now() - now_time}s "
                f"\n {params} \n best_epoch:{e} \t 1/1 verification best_acc:{v_acc} \t top3 best_acc:{t3_acc} top1 best_acc:{t1_acc}")
    else:
        train(ddp_model, train_dataloader, val_dataloader, item_dataloader, drawing_dataloader, distributed_samplers,
              optimizer, criterion,
              train_info.epochs,
              rank, logger, vectors_batch_size=train_info.vectors_batch_size,
              lr_scheduler=lr_scheduler, save_path=result_path, print_iter=train_info.print_iter)
    cleanup()


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = '192.168.35.1'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
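
cleanup() is not shown above; it is just the usual teardown. A minimal version, assuming it does nothing beyond destroying the process group, would be:

def cleanup():
    # tear down the default process group created in setup()
    dist.destroy_process_group()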

def train(model, train_dataloader, val_dataloader, item_dataloader, drawing_dataloader, distributed_samplers, optim, criterion, epochs,
          device, logger, vectors_batch_size=1024,
          lr_scheduler=None, save_path='', print_iter=1):

    acc_max = 0.0
    top1_acc_max = 0
    top3_acc_max = 0

    best_model = None
    t_len = len(train_dataloader)
    v_len = len(val_dataloader)
    logger.info('<<<<<< Train Start >>>>>>>')
    save_path = Path(save_path)
    # model.to(device)

    if device == 0:
        writer = SummaryWriter(str(save_path / 'tensorboard'))
        writer.add_graph(model, [torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)])   ### turned out to be the cause of the hang (see the update at the bottom)

    if not save_path.exists():
        Path(save_path).mkdir(parents=True)

    epoch_at_best = 0
    for e in range(epochs):
        for d_s in distributed_samplers:
            d_s.set_epoch(e)
        running_loss = 0.0
        global_step = 0
        model.train()
        ## Train
        for i, data in enumerate(train_dataloader):
            print(f"gpu{device} 1")
            img_x = data["imgs"]["data"].to(device)
            drawing_x = data["drawings"]["data"].to(device)
            y = data["is_sames"].to(device)

            print(f"gpu{device} 2")
            optim.zero_grad()
            print(f"gpu{device} 3")
            logit = model(img_x, drawing_x)   ### Hangs here in second iteration with gpu 1
            print(f"gpu{device} 4")
            loss = criterion(logit, y)
            print(f"gpu{device} 5")
            loss.backward()
            print(f"gpu{device} 6")
            optim.step()

            if lr_scheduler is not None:
                lr_scheduler.step()
            running_loss += loss.item()
            print(f"gpu{device} 7")
            if (i % print_iter) == 0:
                global_step = e * t_len + (i + 1)
                logger.info(f'\n Train device:{device} Epoch:{e} step[{global_step}/{epochs * t_len}] \t Train_loss avg : {running_loss/(i+1)}')
                if device == 0:
                    mlflow.log_metric('Train Loss', running_loss / (i+1), step=global_step)
                    # log_weights(model, writer, global_step)

class Drawing(torch.nn.Module):
	def __init__(self, backbone='resnet18', cls=1):
		super().__init__()
		if backbone.lower() == 'resnet18':
			self.imageModel = torch.nn.Sequential(*list(models.resnet18(pretrained=True).children())[:-1])
			ml = models.resnet18(pretrained=True)
			self.classifier = torch.nn.Linear(512 * 2, 2)
		else:
			self.imageModel = torch.nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-1])
			ml = models.resnet50(pretrained=False)
			# ml = torch.load('pre-train-models/drawings/drawing-pretrained_90.pt')
			# ml.requires_grad_(False)
			if cls == 1:
				self.classifier = nn.Sequential(nn.Linear(2048 * 2, 1024),
				                                nn.BatchNorm1d(1024),
				                                nn.ReLU(),
				                                nn.Linear(1024, 2)
				                                )
			elif cls == 2:
				self.classifier = nn.Sequential(
					            nn.Dropout(p=0.5),
					            nn.Linear(2048 * 2, 1024),
					            nn.BatchNorm1d(1024),
					            nn.ReLU(),

					            nn.Dropout(p=0.5),
					            nn.Linear(1024, 512),
					            nn.BatchNorm1d(512),
					            nn.Sigmoid(),
					            nn.Dropout(p=0.5),

					            nn.Linear(512, 2)
	                            )
			else:
				raise ValueError(f'classification layer error: check cls value! cls:{cls} ')

		self.drawingModel = torch.nn.Sequential(*list(ml.children())[:-1])
		self.flatten = torch.nn.Flatten()
		# self.init_weights()

	def forward(self, img_x, drawing_x):
		img_x = self.imageModel(img_x)
		drawing_x = self.drawingModel(drawing_x)
		img_x = F.normalize(img_x, p=2, dim=1)
		drawing_x = F.normalize(drawing_x, p=2, dim=1)
		concat_x = torch.cat([img_x, drawing_x], dim=1)
		concat_x = self.flatten(concat_x)
		out = self.classifier(concat_x)
		return out

Log output:

<<<<<< Train Start >>>>>>>
INFO:drawing_1th:<<<<<< Train Start >>>>>>>
<<<<<< Train Start >>>>>>>
INFO:drawing_0th:<<<<<< Train Start >>>>>>>
gpu1 1
gpu1 2
gpu1 3
gpu1 4
gpu1 5
gpu1 6
gpu0 1
gpu0 2
gpu0 3
gpu0 4
gpu0 5
gpu0 6
gpu1 7
gpu0 7

 Train device:1 Epoch:0 step[1/507600] 	 Train_loss avg : 0.7977195978164673
INFO:drawing_1th:
 Train device:1 Epoch:0 step[1/507600] 	 Train_loss avg : 0.7977195978164673

 Train device:0 Epoch:0 step[1/507600] 	 Train_loss avg : 0.8086593747138977
INFO:drawing_0th:
 Train device:0 Epoch:0 step[1/507600] 	 Train_loss avg : 0.8086593747138977
gpu1 1
gpu1 2
gpu1 3
gpu0 1
gpu0 2
gpu0 3
gpu0 4
gpu0 5
gpu0 6

Update: it was because of the writer.add_graph call on rank 0 in train().
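As far as I can tell, add_graph runs an extra forward pass through the DDP-wrapped model on rank 0 only, so rank 0's DDP collectives (including the bucket rebuild) get one step ahead of rank 1, which then blocks inside _rebuild_buckets on its second forward. Dropping the call avoids the hang. If you still want the graph in TensorBoard, a sketch of a workaround (my own suggestion, not part of the original run) is to trace the underlying module instead of the DDP wrapper, with the dummy inputs on the model's device, so no collectives are issued:

    if device == 0:
        writer = SummaryWriter(str(save_path / 'tensorboard'))
        # batch of 2 so BatchNorm layers don't complain while tracing in train mode
        dummy_img = torch.zeros(2, 3, 512, 512).to(device)
        dummy_drawing = torch.zeros(2, 3, 512, 512).to(device)
        # trace the plain nn.Module (model.module), not the DDP wrapper,
        # so rank 0 does not run DDP communication that rank 1 never matches
        writer.add_graph(model.module, [dummy_img, dummy_drawing])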