[NEED HELP] RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one

I have already set find_unused_parameters=True, but the error is still raised.

The model structure is as follows: I implemented a model (model1) that learns weights (a mask) for the output of a ResNet model (model2) and uses a margin softmax loss for classification, so model1 calls model2. model1 also uses an MSE loss for feature-similarity comparison.

model1 is a Siamese network, so it takes two inputs, img1 and img2.

At first it reported an input/output size mismatch, but after some time-consuming changes to the input/output sizes it no longer reports that.

Now it reports the unused-parameters problem instead, and I have no idea what I should do.

The complete error report:

Traceback (most recent call last):
  File "train.py", line 159, in <module>
    main(args_)
  File "train.py", line 111, in main
    f_clean_masked, f_occ_masked, fc, fc_occ = backbone(img1, img2)
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 606, in forward
    if self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
terminate called without an active exception
Traceback (most recent call last):
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/distributed/launch.py", line 260, in <module>
    main()
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/distributed/launch.py", line 256, in main
    cmd=cmd)
subprocess.CalledProcessError: Command '['/home/user1/miniconda3/envs/py377/bin/python3', '-u', 'train.py', '--local_rank=3']' died with <Signals.SIGABRT: 6>.
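
One way to narrow down which parameters the error refers to is to run a single iteration without the DDP wrapper and list the parameters that received no gradient. Below is a minimal sketch on a toy model; TwoHeadToy and params_without_grad are hypothetical names used only for illustration and are not part of the code in this post.

```python
import torch
import torch.nn as nn


class TwoHeadToy(nn.Module):
    """Hypothetical toy model: two heads, but only one will feed the loss."""

    def __init__(self):
        super().__init__()
        self.used_head = nn.Linear(4, 2)
        self.unused_head = nn.Linear(4, 2)

    def forward(self, x):
        return self.used_head(x), self.unused_head(x)


def params_without_grad(model):
    """Names of trainable parameters whose .grad is still None after backward()."""
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]


torch.manual_seed(0)
toy = TwoHeadToy()
out_used, out_unused = toy(torch.randn(3, 4))
out_used.sum().backward()           # out_unused never reaches the loss
missing = params_without_grad(toy)
print(missing)                      # ['unused_head.weight', 'unused_head.bias']
```

The same helper could be called on the unwrapped backbone right after lossAll.backward() to see whether, for example, the fcMG parameters ever get a gradient.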

The weight model (model1) structure:


class MODEL1(nn.Module):
    def __init__(self,network,embedding_size,batch_size,dropout,fp16):
        super(MODEL1, self).__init__()
        self.batch_size = batch_size
        self.resnet = eval(network)(pretrained=False, num_features=embedding_size, dropout=dropout, fp16=fp16)
        self.features_shape = embedding_size

        # mask generator
        self.sia = nn.Sequential(
            # nn.BatchNorm2d(filter_list[4]),
            # conv1x1(self.inplanes, planes * block.expansion, stride),
            nn.Conv2d(self.features_shape, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PReLU(self.features_shape),
            nn.BatchNorm2d(self.features_shape),
            nn.Sigmoid(),
        )
        self.fcMG = nn.Sequential(
            # nn.BatchNorm1d(self.features_shape * 7 * 7),
            nn.BatchNorm1d(self.features_shape),
            # nn.Dropout(p=0),
            # nn.Linear(self.features_shape * 7 * 7, self.features_shape),
            nn.Linear(self.features_shape, self.features_shape),

            nn.BatchNorm1d(self.features_shape),
        )

        # Weight initialization
        # Weight initialization (the bias/BatchNorm branches belong inside the loop)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def getFeatures(self,batch):
        return self.resnet.getEmbedding(batch)

    def forward(self, source, target):
        # MG
        f_clean = self.getFeatures(source)
        f_occ = self.getFeatures(target)

        f_diff = torch.add(f_clean, f_occ, alpha=-1.0)
        f_diff = torch.abs(f_diff) # (batch_size, embedding_size)

        # f_diff shape should be 4d tensor
        f_diff = f_diff.unsqueeze(2).unsqueeze(3) # (batch_size, 512, 1, 1);
        # f_diff = f_diff.reshape(self.batch_size, self.features_shape, 7, 7)

        mask = self.sia(f_diff) # (batch_size, 512, 1, 1)
        # End Siamese branch

        mask = mask.reshape(self.batch_size, -1)

        f_clean_masked = f_clean * mask # (batch_size, embedding_size)
        f_occ_masked = f_occ * mask

        fc = f_clean_masked.view(f_clean_masked.size(0), -1)  # (batch_size, embedding_size)
        fc_occ = f_occ_masked.view(f_occ_masked.size(0), -1) # (batch_size, embedding_size)

        fc = self.fcMG(fc) # fcMG expects input of size embedding_size
        fc_occ = self.fcMG(fc_occ)

        return f_clean_masked, f_occ_masked, fc, fc_occ

The ResNet (model2) structure:



def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, num_features, dropout=0, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)

        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)


    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x) # [128, 25088]

        x = self.fc(x.float() if self.fp16 else x) # [128, 512]
        x = self.features(x)
        # axis = 1
        # norm = torch.norm(x,2,axis,True)
        # x = torch.div(x,norm)
        return x

    def getEmbedding(self,batch):
        features = self.forward(batch)
        return features

I am using distributed training in PyTorch. Here is train.py:


def main(args):
    world_size = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ['RANK'])
    dist_url = "tcp://{}:{}".format(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])
    dist.init_process_group(backend='nccl', init_method=dist_url, rank=rank, world_size=world_size)
    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)

    if not os.path.exists(cfg.output) and rank == 0:
        os.makedirs(cfg.output)
    else:
        time.sleep(2)

    log_root = logging.getLogger()
    init_logging(log_root, rank, cfg.output)
    trainset = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoaderX(
        local_rank=local_rank, dataset=trainset, batch_size=cfg.batch_size,
        sampler=train_sampler, num_workers=0, pin_memory=True, drop_last=True)

    dropout = 0.4 if cfg.dataset == "webface" else 0

    backbone = MODEL1(network=args.network, embedding_size=cfg.embedding_size, batch_size=cfg.batch_size,
                      dropout=dropout, fp16=cfg.fp16).to(local_rank)

    for ps in backbone.parameters():
        dist.broadcast(ps, 0)

    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank], find_unused_parameters=True)
    backbone.train()

    margin_softmax = eval("losses.{}".format(args.loss))()
    module_partial_fc = PartialFC(
        rank=rank, local_rank=local_rank, world_size=world_size, resume=args.resume,
        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)

    opt_backbone = torch.optim.SGD(
        params=[{'params': backbone.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)
    opt_pfc = torch.optim.SGD(
        params=[{'params': module_partial_fc.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)

    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_backbone, lr_lambda=cfg.lr_func)
    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_pfc, lr_lambda=cfg.lr_func)

    start_epoch = 0
    total_step = int(len(trainset) / cfg.batch_size / world_size * cfg.num_epoch)
    if rank == 0: logging.info("Total Step is: %d" % total_step)

    callback_verification = CallBackVerification(1000, rank, cfg.val_targets, cfg.rec) # 150 for debug, self.frequent = 1000
    callback_logging = CallBackLogging(50, rank, total_step, cfg.batch_size, world_size, None) # verbose = 50
    callback_checkpoint = CallBackModelCheckpoint(1000, rank, cfg.output)

    loss = AverageMeter()
    global_step = 0
    grad_scaler = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
    
    mmd_loss = nn.MSELoss(reduction="none")
    kl_loss = DistillationLoss(temp=3)

    for epoch in range(start_epoch, cfg.num_epoch):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            img1, img2 = img[:,:,:,:112], img[:,:,:,112:]

            global_step += 1

            f_clean_masked, f_occ_masked, fc, fc_occ = backbone(img1, img2)

            features1 = F.normalize(f_clean_masked)
            features2 = F.normalize(f_occ_masked)

            mmdLoss_v = mmd_loss(features1, features2)
            mmdLoss_v = torch.mean(mmdLoss_v)

            # fc7
            loss_v1 = module_partial_fc.forward_backward(label, fc, opt_pfc)
            loss_v2 = module_partial_fc.forward_backward(label, fc_occ, opt_pfc)

            # fc1
            lossAll = (loss_v1 + loss_v2  + mmdLoss_v).mean()
            lossAll.backward()

            clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
            opt_backbone.step()

            opt_pfc.step()
            module_partial_fc.update()
            opt_backbone.zero_grad()
            opt_pfc.zero_grad()
            loss.update(lossAll, 1)
            callback_logging(global_step, loss, epoch, cfg.fp16, grad_scaler)
            callback_verification(global_step, backbone)
            callback_checkpoint(global_step, backbone, module_partial_fc)
        scheduler_backbone.step()
        scheduler_pfc.step()
    dist.destroy_process_group()

Does anyone have a clue? Any help is really appreciated.
Dear PyTorch master, my old friend, @ptrblck, do you have any suggestions?
Maybe the code is too long?

My first suspicion was probably this, but it does look like all outputs are participating in the loss computation. Although, to double-check, can you share the code for PartialFC, since it is used in the loss computation?

Yes, of course.
Thank you so much for paying attention to my post.

The PartialFC in my code was copied from another popular open-source repository for face recognition.

In fact, the major part of my code is based on this PyTorch implementation of the ArcFace face recognition method. Readers who are not familiar with ArcFace can ignore the name, since it just implements a type of loss function.

I hope we can find the cause of this problem together. Thanks again for your time!

Could you temporarily get rid of loss_v1 and loss_v2, call backward only on mmdLoss_v, skip all lines related to module_partial_fc and opt_pfc, and see if the problem still persists?

I think the problem might be coming from this line: insightface/partial_fc.py at master · deepinsight/insightface · GitHub. dist.all_gather doesn't perform autograd recording, so from an autograd point of view features is not used to produce the loss, and this might be causing the issue. One way to validate this would be to initialize total_features with copies of features (so that it is recorded as part of autograd) and see whether that resolves the issue.
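
A minimal sketch of that idea, assuming the goal is only to keep the local slice of the gathered tensor connected to the autograd graph. The helper name all_gather_keep_local_grad is hypothetical, and the fallback branch exists only so the snippet also runs in a single process without an initialized process group.

```python
import torch
import torch.distributed as dist


def all_gather_keep_local_grad(features):
    """Gather `features` from all ranks, keeping this rank's slice differentiable."""
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        gathered = [torch.zeros_like(features) for _ in range(world_size)]
        # The collective itself is not recorded by autograd, so pass detached data.
        dist.all_gather(gathered, features.detach())
    else:
        # Single-process fallback so the sketch can be run without a process group.
        rank = 0
        gathered = [torch.zeros_like(features)]
    # Put the original, graph-connected tensor back in this rank's slot.
    gathered[rank] = features
    return torch.cat(gathered, dim=0)


local_features = torch.randn(4, 8, requires_grad=True)
total_features = all_gather_keep_local_grad(local_features)
total_features.sum().backward()
print(local_features.grad is not None)  # True
```

Note that with this approach only the local slice receives a gradient; the slices that came from other ranks are plain copies.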

Good insight!
May I ask how exactly to "initialize total_features with copies of features"? Simply using copy.deepcopy causes an error:

RuntimeError("Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment"
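
For reference, a minimal sketch of the difference between copy.deepcopy and Tensor.clone() on a tensor that is the result of an operation (a non-leaf). The variable names are made up for the example; the point is that clone() is an autograd-recorded copy, so gradients still flow back through it.

```python
import copy

import torch

w = torch.randn(3, requires_grad=True)   # leaf tensor created by the user
features = w * 2                          # non-leaf: produced by an operation

try:
    copy.deepcopy(features)
    deepcopy_failed = False
except RuntimeError:
    deepcopy_failed = True                # deepcopy rejected the non-leaf tensor

copied = features.clone()                 # clone() is recorded by autograd
copied.sum().backward()
print(deepcopy_failed)
print(w.grad)                             # tensor([2., 2., 2.])
```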

Another question: do none of the distributed operators record autograd, e.g. dist.all_reduce and so on? Where can I find information about this behavior (recorded or not)? I didn't see it in the official docs.

thank you!

Looking at it again, I am a bit confused about your code. The forward_backward function returns the pair (grad, loss), but you have treated the results (loss_v1 and loss_v2) as simple tensors. Maybe you pointed to the wrong commit of the insightface repository?
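
For context, when a function hands back a gradient with respect to features together with the loss value, that gradient is typically pushed into the graph with Tensor.backward(gradient) instead of calling backward() on the loss. A minimal sketch, assuming x_grad has the same shape as features; backbone_toy and the values in x_grad are placeholders and not taken from the repository.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone_toy = nn.Linear(4, 3)                  # hypothetical stand-in for the backbone
features = backbone_toy(torch.randn(2, 4))

# Placeholder for the gradient w.r.t. `features` that would be returned;
# the values are made up, only the shape matters for this sketch.
x_grad = torch.ones_like(features)

features.backward(x_grad)                       # propagates x_grad through the backbone
has_grad = all(p.grad is not None for p in backbone_toy.parameters())
print(has_grad)                                 # True
```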

Dear @mrzzd, thanks for checking so carefully. It's my fault, sorry: I forgot that I had made a small modification to the original partial_fc.py. I have pasted my partial_fc.py here:

If you discover anything new, please tell me. Thank you!

import logging
import os

import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn.functional import normalize, linear
from torch.nn.parameter import Parameter


class PartialFC(Module):
    """
    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
    Partial FC: Training 10 Million Identities on a Single Machine
    See the original paper:
    https://arxiv.org/abs/2010.05222
    """

    @torch.no_grad()
    def __init__(self, rank, local_rank, world_size, batch_size, resume,
                 margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
        super(PartialFC, self).__init__()
        #
        self.num_classes: int = num_classes
        self.rank: int = rank
        self.local_rank: int = local_rank
        self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
        self.world_size: int = world_size
        self.batch_size: int = batch_size
        self.margin_softmax: callable = margin_softmax
        self.sample_rate: float = sample_rate
        self.embedding_size: int = embedding_size
        self.prefix: str = prefix
        self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
        self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
        self.num_sample: int = int(self.sample_rate * self.num_local)

        self.weight_name = os.path.join(self.prefix, "rank:{}_softmax_weight.pt".format(self.rank))
        self.weight_mom_name = os.path.join(self.prefix, "rank:{}_softmax_weight_mom.pt".format(self.rank))

        if resume:
            try:
                self.weight: torch.Tensor = torch.load(self.weight_name)
                logging.info("softmax weight resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
                logging.info("softmax weight resume fail!")

            try:
                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
                logging.info("softmax weight mom resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
                logging.info("softmax weight mom resume fail!")
        else:
            self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
            self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
            logging.info("softmax weight init successfully!")
            logging.info("softmax weight mom init successfully!")
        self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)

        self.index = None
        if int(self.sample_rate) == 1:
            self.update = lambda: 0
            self.sub_weight = Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))

    def save_params(self):
        torch.save(self.weight.data, self.weight_name)
        torch.save(self.weight_mom, self.weight_mom_name)

    @torch.no_grad()
    def sample(self, total_label):
        index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
        total_label[~index_positive] = -1
        total_label[index_positive] -= self.class_start
        if int(self.sample_rate) != 1:
            positive = torch.unique(total_label[index_positive], sorted=True)
            if self.num_sample - positive.size(0) >= 0:
                perm = torch.rand(size=[self.num_local], device=self.device)
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1]
                index = index.sort()[0]
            else:
                index = positive
            self.index = index
            total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
            self.sub_weight = Parameter(self.weight[index])
            self.sub_weight_mom = self.weight_mom[index]

    def forward(self, total_features, norm_weight):
        torch.cuda.current_stream().wait_stream(self.stream)
        logits = linear(total_features, norm_weight)
        return logits

    @torch.no_grad()
    def update(self):
        self.weight_mom[self.index] = self.sub_weight_mom
        self.weight[self.index] = self.sub_weight

    def prepare(self, label, optimizer):
        with torch.cuda.stream(self.stream):
            total_label = torch.zeros(
                size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
            dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
            self.sample(total_label)
            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
            optimizer.param_groups[-1]['params'][0] = self.sub_weight
            optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
            norm_weight = normalize(self.sub_weight)
            return total_label, norm_weight

    def forward_backward(self, label, features, optimizer):
        total_label, norm_weight = self.prepare(label, optimizer)
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
        total_features.requires_grad = True

        logits = self.forward(total_features, norm_weight)
        logits = self.margin_softmax(logits, total_label)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            loss_v.requires_grad = True


        return loss_v
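
In the forward_backward above, loss_v is computed inside torch.no_grad() and then marked with requires_grad = True. A minimal sketch, independent of PartialFC and using made-up tensors, of what that combination means for autograd: the result is a new leaf with no grad_fn, so calling backward() on it does not reach the tensors it was computed from.

```python
import torch

w = torch.randn(3, requires_grad=True)

with torch.no_grad():
    loss = (w * 2).sum()       # computed without autograd recording
    loss.requires_grad = True  # makes `loss` a leaf; it is not reconnected to `w`

loss.backward()                # only populates loss.grad
print(loss.is_leaf, loss.grad_fn)  # True None
print(w.grad)                      # None
```

If the same holds for loss_v1 and loss_v2 in train.py, then fc and fc_occ would not receive any gradient from lossAll.backward(), which may be worth checking.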