Distributed CPU Training - (In-Place?) Error

I’m attempting to train a model using DDP/distributed training (‘gloo’ backend), but an error is thrown at the very first call to loss.backward().

The error is:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [256]] is at version 6; expected version 5 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

After enabling torch.autograd.set_detect_anomaly(True), the following additional information is provided:

/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/autograd/__init__.py:156: UserWarning: Error detected in NativeBatchNormBackward0. Traceback of forward call that caused the error:
  File "train_depth_concat_cpu.py", line 469, in <module>
    _, sup_pred_l = model(imgs, step=1)
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 888, in forward
    output = self.module(*inputs, **kwargs)
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/taha_a@WMGDS.WMG.WARWICK.AC.UK/Documents/ss_fsd/CPS/TorchSemiSeg/exp.city/city8.res50v3+.CPS+CutMix/concat/network_depth_concat.py", line 38, in forward
    return self.branch1(data)
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/taha_a@WMGDS.WMG.WARWICK.AC.UK/Documents/ss_fsd/CPS/TorchSemiSeg/exp.city/city8.res50v3+.CPS+CutMix/concat/network_depth_concat.py", line 71, in forward
    v3plus_feature = self.head(blocks, depth)      # (b, c, h, w)
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/taha_a@WMGDS.WMG.WARWICK.AC.UK/Documents/ss_fsd/CPS/TorchSemiSeg/exp.city/city8.res50v3+.CPS+CutMix/concat/network_depth_concat.py", line 241, in forward
    f4 = self.last_conv(f3)
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 179, in forward
    self.eps,
  File "/home/extraspace/anaconda3/envs/semiseg-test/lib/python3.6/site-packages/torch/nn/functional.py", line 2283, in batch_norm
    input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
 (Triggered internally at  ../torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

I’ve seen another thread where SyncBatchNorm may solve the error, but SyncBatchNorm is currently only supported for GPU distributed training, so it isn’t really an option here.
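For reference, this is roughly how that conversion would look (a minimal sketch, assuming a CUDA/NCCL process group, which is exactly what this gloo/CPU setup doesn’t have; num_classes, criterion and local_rank as defined in my training script):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# GPU-only sketch: convert_sync_batchnorm swaps every BatchNorm*d layer in
# the model for a SyncBatchNorm that syncs batch statistics across ranks.
model = Network(num_classes, criterion=criterion, norm_layer=nn.BatchNorm2d)
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model.cuda(local_rank), device_ids=[local_rank])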

For the code snippet causing the error:

f4 = self.last_conv(f3)

I attempted f4 = self.last_conv(f3.clone()) and f4 = self.last_conv(f3).clone(), but to no avail (presumably because the tensor flagged in the error is a [256]-shaped BatchNorm parameter or buffer inside last_conv, rather than the activation being cloned).

I’m also using a standard ResNet implementation as my backbone and made sure inplace is set to False on all blocks/layers; that made no difference either. I don’t think the backbone is the problem, though, since the failing snippet is the very first operation visited during backprop (immediately prior to the classifier), so autograd doesn’t even reach the backbone while propagating grads…
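For reference, one way to enforce this after construction (a minimal sketch using only standard module traversal, with model being the Network instance):

import torch.nn as nn

def disable_inplace(model: nn.Module) -> None:
    # Walk every submodule and switch any in-place ReLU to out-of-place.
    for m in model.modules():
        if isinstance(m, nn.ReLU):
            m.inplace = False

disable_inplace(model)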

Any help would be really appreciated 🙂

Is last_conv here an nn.Conv2d, or is it something else?


Yep.

       self.last_conv = nn.Sequential(nn.Conv2d(768, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_act(256, momentum=bn_momentum),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_act(256, momentum=bn_momentum),
                                       nn.ReLU(),
                                       )

The code is from this repo (though my version is slightly different).

Hey @smth,

Sorry to bother you with a direct tag, but I was wondering if you had any ideas as to what could be causing this issue, or any potential solutions?

Unfortunately, your code is not executable as some definitions are missing such as base_model and thus your custom ResNet implementation. Could you add the missing definitions so that we could run the code and reproduce the error?

Hey @ptrblck, thanks for replying. Sure, please find the code below (apologies for the excess comments; it’s a work in progress):

ResNet:

import functools
import torch.nn as nn

from utils.pyt_utils import load_model


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, norm_layer=None,
                 bn_eps=1e-5, bn_momentum=0.1, downsample=None, inplace=True):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=inplace)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum)
        self.downsample = downsample
        self.stride = stride
        self.inplace = inplace

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if self.inplace:
            out += residual
        else:
            out = out + residual

        out = self.relu_inplace(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1,
                 norm_layer=None, bn_eps=1e-5, bn_momentum=0.1,
                 downsample=None, inplace=True):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = norm_layer(planes * self.expansion, eps=bn_eps,
                              momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=inplace)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.inplace = inplace

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if self.inplace:
            out += residual
        else:
            out = out + residual
        out = self.relu_inplace(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, norm_layer=nn.BatchNorm2d, bn_eps=1e-5,
                 bn_momentum=0.1, deep_stem=False, stem_width=32, inplace=True):
        self.inplanes = stem_width * 2 if deep_stem else 64
        super(ResNet, self).__init__()
        if deep_stem:
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1,
                          bias=False),
                norm_layer(stem_width, eps=bn_eps, momentum=bn_momentum),
                nn.ReLU(inplace=inplace),
                nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1,
                          padding=1,
                          bias=False),
                norm_layer(stem_width, eps=bn_eps, momentum=bn_momentum),
                nn.ReLU(inplace=inplace),
                nn.Conv2d(stem_width, stem_width * 2, kernel_size=3, stride=1,
                          padding=1,
                          bias=False),
            )
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False)

        self.bn1 = norm_layer(stem_width * 2 if deep_stem else 64, eps=bn_eps,
                              momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=inplace)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, norm_layer, 64, layers[0],
                                       inplace,
                                       bn_eps=bn_eps, bn_momentum=bn_momentum)
        self.layer2 = self._make_layer(block, norm_layer, 128, layers[1],
                                       inplace, stride=2,
                                       bn_eps=bn_eps, bn_momentum=bn_momentum)
        self.layer3 = self._make_layer(block, norm_layer, 256, layers[2],
                                       inplace, stride=2,
                                       bn_eps=bn_eps, bn_momentum=bn_momentum)
        self.layer4 = self._make_layer(block, norm_layer, 512, layers[3],
                                       inplace, stride=2,
                                       bn_eps=bn_eps, bn_momentum=bn_momentum)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):#, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677

        for m in self.modules():
            if isinstance(m, Bottleneck):
                nn.init.constant_(m.bn3.weight, 0)
            elif isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, norm_layer, planes, blocks, inplace=True,
                    stride=1, bn_eps=1e-5, bn_momentum=0.1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion, eps=bn_eps,
                           momentum=bn_momentum),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, norm_layer, bn_eps,
                            bn_momentum, downsample, inplace))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                norm_layer=norm_layer, bn_eps=bn_eps,
                                bn_momentum=bn_momentum, inplace=inplace))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        blocks = []
        x = self.layer1(x)
        blocks.append(x)
        x = self.layer2(x)
        blocks.append(x)
        x = self.layer3(x)
        blocks.append(x)
        x = self.layer4(x)
        blocks.append(x)

        return blocks


def resnet18(pretrained_model=None, **kwargs):
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

    if pretrained_model is not None:
        model = load_model(model, pretrained_model)
    return model


def resnet34(pretrained_model=None, **kwargs):
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)

    if pretrained_model is not None:
        model = load_model(model, pretrained_model)
    return model


def resnet50(pretrained_model=None, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)

    if pretrained_model is not None:
        model = load_model(model, pretrained_model)
    return model


def resnet101(pretrained_model=None, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)

    if pretrained_model is not None:
        model = load_model(model, pretrained_model)
    return model


def resnet152(pretrained_model=None, **kwargs):
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)

    if pretrained_model is not None:
        model = load_model(model, pretrained_model)
    return model

Network (Calling ResNet):

# encoding: utf-8

import sys
import path

sys.path.append('../../../')

import torch
import torch.nn as nn
import torch.nn.functional as F

from config import config
from functools import partial
from collections import OrderedDict
from furnace.base_model import resnet50


from dataloader import get_train_loader, CityScape
from furnace.utils.init_func import init_weight, group_weight
from furnace.engine.lr_policy import WarmUpPolyLR
from furnace.engine.engine import Engine
from furnace.seg_opr.loss_opr import SigmoidFocalLoss, ProbOhemCrossEntropy2d, bce2d
#from furnace.seg_opr.sync_bn import DataParallelModel, Reduce, BatchNorm2d

class Network(nn.Module):
    def __init__(self, num_classes, criterion, norm_layer, pretrained_model=None):
        super(Network, self).__init__()
        self.branch1 = SingleNetwork(num_classes, criterion, norm_layer, pretrained_model)
        self.branch2 = SingleNetwork(num_classes, criterion, norm_layer, pretrained_model)

    def forward(self, data, step=1):
        if not self.training:
            pred1 = self.branch1(data)
            return pred1

        if step == 1:
            return self.branch1(data)
        elif step == 2:
            return self.branch2(data)

class SingleNetwork(nn.Module):
    def __init__(self, num_classes, criterion, norm_layer, pretrained_model=None):
        super(SingleNetwork, self).__init__()
        self.backbone = resnet50(pretrained_model, norm_layer=norm_layer,
                                  bn_eps=config.bn_eps,
                                  bn_momentum=config.bn_momentum,
                                  deep_stem=True, stem_width=64)   
        
        self.stem_width = 64

        self.dilate = 2
        # Module.apply expects a function with a single parameter (the module)
        # and calls it on every submodule. _nostride_dilate takes an extra
        # `dilate` argument, so functools.partial is used to bind it, leaving
        # a one-parameter function that apply can call. This explains apply:
        # https://stackoverflow.com/questions/55613518/how-does-the-applyfn-function-in-pytorch-work-with-a-function-without-return-s
        for m in self.backbone.layer4.children():
            m.apply(partial(self._nostride_dilate, dilate=self.dilate))
            self.dilate *= 2

        self.head = Head(num_classes, norm_layer, config.bn_momentum)
        self.business_layer = []
        self.business_layer.append(self.head)
        self.criterion = criterion

        # num_classes is the number of classes we're predicting; the input to
        # the classifier is a 256-channel feature map
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)
        self.business_layer.append(self.classifier)

    def forward(self, data):
        blocks = self.backbone(data)
        v3plus_feature = self.head(blocks)      # (b, c, h, w)
        b, c, h, w = v3plus_feature.shape

        pred = self.classifier(v3plus_feature)

        b, c, h, w = data.shape
        pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=True)
        
        if self.training:
            return v3plus_feature, pred
        return pred

    # @staticmethod
    def _nostride_dilate(self, m, dilate):
        if isinstance(m, nn.Conv2d):
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)


class ASPP(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 dilation_rates=(12, 24, 36),
                 hidden_channels=256,
                 norm_act=nn.BatchNorm2d,
                 pooling_size=None):
        super(ASPP, self).__init__()
        self.pooling_size = pooling_size


        self.pool_img = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels,hidden_channels, kernel_size=1, dilation=1,bias=False))
        self.map_convs = nn.ModuleList([
            nn.Conv2d(in_channels, hidden_channels, 1, dilation=1, bias=False),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[0],
                      padding=dilation_rates[0]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[1],
                      padding=dilation_rates[1]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilation_rates[2],
                      padding=dilation_rates[2])
        ])
        self.map_bn = norm_act(hidden_channels * 5)

        self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        self.global_pooling_bn = norm_act(hidden_channels)

        self.red_conv = nn.Conv2d(hidden_channels * 5, out_channels, 1, bias=False)
        self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
        self.red_bn = norm_act(out_channels)

        self.leak_relu = nn.LeakyReLU()

    def forward(self, x):
        # Map convolutions
        _, _, h, w = x.size()
        out1 = F.interpolate(
            self.pool_img(x), size=(h, w), mode="bilinear", align_corners=True
        )
        aspp_list = [m(x) for m in self.map_convs]
        aspp_list.insert(0,out1)
        out = torch.cat(aspp_list, dim=1)
        out = self.map_bn(out)
        out = self.leak_relu(out)       # add activation layer
        out = self.red_conv(out)
        out = self.leak_relu(out)

        # Global pooling
        #pool = self._global_pooling(x)
        #pool = self.global_pooling_conv(pool)
        #pool = self.global_pooling_bn(pool)

        #pool = self.leak_relu(pool)  # add activation layer

        #pool = self.pool_red_conv(pool)
        #if self.training or self.pooling_size is None:
        #    pool = pool.repeat(1, 1, x.size(2), x.size(3))

        #out += pool
        #out = self.red_bn(out)
        #out = self.leak_relu(out)  # add activation layer
        return out

    def _global_pooling(self, x):
        if self.training or self.pooling_size is None:
            pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
            pool = pool.view(x.size(0), x.size(1), 1, 1)
        else:
            # pooling_size is always None, so this branch does not play a
            # role here
            pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
                            min(try_index(self.pooling_size, 1), x.shape[3]))
            padding = (
                (pooling_size[1] - 1) // 2,
                (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
                (pooling_size[0] - 1) // 2,
                (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
            )

            pool = nn.functional.avg_pool2d(x, pooling_size, stride=1)
            pool = nn.functional.pad(pool, pad=padding, mode="replicate")
        return pool

class Head(nn.Module):
    def __init__(self, classify_classes, norm_act=nn.BatchNorm2d, bn_momentum=0.0003):
        super(Head, self).__init__()

        self.classify_classes = classify_classes
        self.aspp = ASPP(2048, 256, [12, 18, 24], norm_act=norm_act)

        self.reduce = nn.Sequential(
            nn.Conv2d(256, 256, 1, bias=False),
            norm_act(256, momentum=bn_momentum),
            nn.ReLU(),
        )
        self.last_conv = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_act(256, momentum=bn_momentum),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       norm_act(256, momentum=bn_momentum),
                                       nn.ReLU(),
                                       )

    def forward(self, f_list):
        f = f_list[-1]
        f1 = self.aspp(f)
        low_level_features = f_list[0]
        low_h, low_w = low_level_features.size(2), low_level_features.size(3)
        low_level_features = self.reduce(low_level_features)
        f2 = F.interpolate(f1, size=(low_h, low_w), mode='bilinear', align_corners=True)
        f3 = torch.cat((f2, low_level_features), dim=1) #concatenate depth dimension here?
        f4 = self.last_conv(f3)

        return f4


def count_params(model):
    return sum(p.numel() for p in model.parameters())# if p.requires_grad)

if __name__ == '__main__':

    criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7,
                                       min_kept=50000, use_weight=False)

    # change the number of classes to free space only: 2 classes, free space
    # and non-free-space
    model = Network(5, pretrained_model=config.pretrained_model, criterion=criterion,
                    norm_layer=nn.BatchNorm2d)
            
    #print(model)

    #for module in model.branch1.modules():
    #    print(f"Module:", module, "\n")

    print("Number of Params", count_params(model.branch1))

    '''
    model = Network(40, criterion=nn.CrossEntropyLoss(),
                    pretrained_model=None,
                    norm_layer=nn.BatchNorm2d)
    left = torch.randn(2, 3, 128, 128)
    right = torch.randn(2, 3, 128, 128)

    print(model.backbone)

    out = model(left)
    print(out.shape)
    '''

Your code is still using undefined classes. After I removed them and also removed failing lines of code (e.g., your model does not contain a .backbone attribute), a few iterations work for me:

if __name__ == '__main__':
    criterion = nn.CrossEntropyLoss()
    model = Network(5, pretrained_model=None, criterion=criterion,
                    norm_layer=nn.BatchNorm2d)
   
    model = Network(40, criterion=nn.CrossEntropyLoss(),
                    pretrained_model=None,
                    norm_layer=nn.BatchNorm2d)
    left = torch.randn(2, 3, 128, 128)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(10):
        optimizer.zero_grad()
        out = model(left)
        loss = out[0].mean() + out[1].mean()
        loss.backward()
        optimizer.step()
        print('epoch {}, loss {}'.format(epoch, loss.item()))

Do you use torchrun or torch.distributed.launch? And what version of PyTorch are you using?
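For reference, the two launchers wire up the ranks differently (a sketch of both invocations; the script name is taken from your traceback and the process count is just an example):

# torchrun exports WORLD_SIZE / RANK / LOCAL_RANK as environment variables
torchrun --nproc_per_node=4 train_depth_concat_cpu.py

# the older torch.distributed.launch passes --local_rank as a CLI argument
python -m torch.distributed.launch --nproc_per_node=4 train_depth_concat_cpu.py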

Also, I’m slightly confused by the comment about the backbone: it’s defined in the SingleNetwork class… unless I’ve misunderstood? For further clarity, I’ve included my Engine (the runner wrapper that defines and parses the training args) and my training script below (the latter in another message, since this one exceeds the post character limit).

Engine (furnace.engine.engine):

#!/usr/bin/env python3
# encoding: utf-8
# @Time    : 2018/8/2 3:23 PM
# @Author  : yuchangqian
# @Contact : changqian_yu@163.com
# @File    : engine.py

from datetime import datetime
import os
import os.path as osp
import time
import argparse

import shutil
import torch
import torch.distributed as dist
import _datetime as dt

from .logger import get_logger
from .version import __version__
from utils.pyt_utils import load_model, parse_devices, extant_file, link_file, \
    ensure_dir

logger = get_logger()


class State(object):
    def __init__(self):
        self.epoch = 0
        self.iteration = 0
        self.dataloader = None
        self.model = None
        self.optimizer = None
        self.optimizer_l = None
        self.optimizer_r = None
        #self.loss = 0

    def register(self, **kwargs):
        for k, v in kwargs.items():
            # assert k in ['epoch', 'iteration', 'dataloader', 'model',
            #              'optimizer']
            setattr(self, k, v)

# Engine object has a state object associated with it that stores
# information about the model and training set up


class Engine(object):
    def __init__(self, custom_parser=None):  # define args parser when the class is called
        self.version = __version__  # version stored in version.py file
        logger.info(
            "PyTorch Version {}, Furnace Version {}".format(torch.__version__,
                                                            self.version))
        self.state = State()  # initialise a new State object
        #self.devices = None
        #self.distributed = False

        if custom_parser is None:
            # use default Python arg parser if Argument Parses is not defined
            # when initialising class
            self.parser = argparse.ArgumentParser()
        else:
            # if defined, make sure the parser defined is an instance of the
            # default Python Arg Parser
            assert isinstance(custom_parser, argparse.ArgumentParser)
            self.parser = custom_parser

        self.inject_default_parser()  # some .add_argument statements defined in this function
        self.args = self.parser.parse_args()  # parse the args and store in self.args

        # one of the arguments from the command line args is continue_fpath
        # (defined in inject_default_parser) and states whether to continue
        # from one certain checkpoint
        if self.args.continue_fpath is not None and os.path.exists(
                self.args.continue_fpath):
            self.continue_state_object = self.args.continue_fpath
        else:
            self.continue_state_object = None
        print('continue_state_object: ', self.continue_state_object)

        print(str(os.environ['CPU_DIST_ONLY']))

        if str(os.environ['CPU_DIST_ONLY']) == 'False':
            self.cpu_only = False
        else:
            self.cpu_only = True

        self.world_size = int(os.environ['WORLD_SIZE'])
        # print(self.world_size)

        if self.world_size > 1:  # checking if there is an ENVIRONMENT variable called WORLD_SIZE
            self.distributed = True  # returns bool
            #os.environ['OMP_NUM_THREADS'] = '4'
            self.local_rank = self.args.local_rank
            self.world_size = int(os.environ['WORLD_SIZE'])
            if not bool(os.environ['CPU_DIST_ONLY']):
                torch.cuda.set_device(self.local_rank)
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = self.args.port
            os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'
            if not bool(os.environ['CPU_DIST_ONLY']):
                dist.init_process_group(backend="nccl", init_method='env://')
            else:
                dist.init_process_group(
                    backend="gloo",
                    rank=self.local_rank,
                    world_size=self.world_size)
                print('gloo initialised')

            # creates a list which just outputs  [0, 1, 2, 3, 4, ...] for
            # number of devices
            self.devices = [i for i in range(self.world_size)]

        else:
            self.distributed = False
            self.devices = 0  # parse_devices(self.args.devices)
            self.local_rank = 0

    def inject_default_parser(self):
        p = self.parser  # assigning the parser set above to the variable p
        p.add_argument('-d', '--devices', default='',
                       help='set data parallel training')
        # p.add_argument('-c', '--continue', type=extant_file,
        #                metavar="FILE",
        #                dest="continue_fpath",
        #                help='continue from one certain checkpoint')
        p.add_argument('-c', '--continue', type=str,
                       dest="continue_fpath",
                       help='continue from one certain checkpoint')
        p.add_argument('--local_rank', default=0, type=int,
                       help='process rank on node')
        p.add_argument('-p', '--port', type=str,
                       default='12355',
                       dest="port",
                       help='port for init_process_group')
        p.add_argument('--debug', default=0, type=int,
                       help='whether to use the debug mode')

    def register_state(self, **kwargs):
        self.state.register(**kwargs)

    def update_iteration(self, epoch, iteration):
        self.state.epoch = epoch
        self.state.iteration = iteration

    def save_checkpoint(self, path):
        logger.info("Saving checkpoint to file {}".format(path))
        t_start = time.time()

        state_dict = {}

        from collections import OrderedDict
        new_state_dict = OrderedDict()

        for k, v in self.state.model.state_dict().items():
            key = k
            if k.split('.')[0] == 'module':
                key = k[7:]
                print('key', key)
            new_state_dict[key] = v
        state_dict['model'] = new_state_dict

        if self.state.optimizer is not None:
            state_dict['optimizer'] = self.state.optimizer.state_dict()
        if self.state.optimizer_l is not None:
            state_dict['optimizer_l'] = self.state.optimizer_l.state_dict()
        if self.state.optimizer_r is not None:
            state_dict['optimizer_r'] = self.state.optimizer_r.state_dict()
        state_dict['epoch'] = self.state.epoch
        state_dict['iteration'] = self.state.iteration
        #state_dict['loss'] = self.state.loss

        t_iobegin = time.time()
        torch.save(state_dict, path)
        del state_dict
        del new_state_dict
        t_end = time.time()
        logger.info(
            "Save checkpoint to file {}, "
            "Time usage:\n\tprepare snapshot: {}, IO: {}".format(
                path, t_iobegin - t_start, t_end - t_iobegin))

    def link_tb(self, source, target):
        ensure_dir(source)
        ensure_dir(target)
        link_file(source, target)

    def save_and_link_checkpoint(
            self,
            snapshot_dir,
            log_dir,
            log_dir_link,
            epoch,
            name=None):
        ensure_dir(snapshot_dir)
        dt = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if not osp.exists(log_dir_link):  # test whether a path exists
            link_file(log_dir, log_dir_link)
        if name is None:
            current_epoch_checkpoint = osp.join(
                snapshot_dir, '{}-epoch-{}.pth'.format(dt, epoch))
        else:
            current_epoch_checkpoint = osp.join(snapshot_dir, '{}.pth'.format(
                name))

        ''' If the old file exists, delete it first '''
        if os.path.exists(current_epoch_checkpoint):
            os.remove(current_epoch_checkpoint)

        self.save_checkpoint(current_epoch_checkpoint)
        last_epoch_checkpoint = osp.join(snapshot_dir,
                                         f'{dt}_epoch-last.pth')
        # link_file(current_epoch_checkpoint, last_epoch_checkpoint)
        try:
            shutil.copy(current_epoch_checkpoint, last_epoch_checkpoint)
        except BaseException:
            pass

    def restore_checkpoint(self):
        t_start = time.time()
        if self.distributed:
            tmp = torch.load(self.continue_state_object,
                             map_location=lambda storage, loc: storage.cuda(
                                 self.local_rank))
        else:
            tmp = torch.load(self.continue_state_object)
        t_ioend = time.time()

        self.state.model = load_model(self.state.model, tmp['model'],
                                      True)
        if 'optimizer_l' in tmp:
            self.state.optimizer_l.load_state_dict(tmp['optimizer_l'])
        if 'optimizer_r' in tmp:
            self.state.optimizer_r.load_state_dict(tmp['optimizer_r'])
        if 'optimizer' in tmp:
            self.state.optimizer.load_state_dict(tmp['optimizer'])
        self.state.epoch = tmp['epoch'] + 1
        self.state.iteration = tmp['iteration']
        del tmp
        t_end = time.time()
        logger.info(
            "Load checkpoint from file {}, "
            "Time usage:\n\tIO: {}, restore snapshot: {}".format(
                self.continue_state_object,
                t_ioend - t_start,
                t_end - t_ioend))

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        torch.cuda.empty_cache()
        if type is not None:
            logger.warning(
                "A exception occurred during Engine initialization, "
                "give up running process")
            return False

Training Script:

from __future__ import division
from custom_collate import SegCollate
from tensorboardX import SummaryWriter
from matplotlib import pyplot as plt
from furnace.seg_opr.metric import hist_info, compute_score, recall_and_precision
from furnace.engine.evaluator import Evaluator
from furnace.utils.pyt_utils import load_model
from furnace.seg_opr.loss_opr import SigmoidFocalLoss, ProbOhemCrossEntropy2d, bce2d
from furnace.engine.engine import Engine
from furnace.engine.lr_policy import WarmUpPolyLR
from furnace.utils.visualize import print_iou, show_img
from furnace.utils.init_func import init_weight, group_weight
from furnace.utils.img_utils import generate_random_uns_crop_pos
import random
import cv2
from eval_depth_concat import SegEvaluator
from dataloader_depth_concat import CityScape
from network_depth_concat import Network, count_params
from dataloader_depth_concat import get_train_loader
from config import config
from matplotlib import colors
from PIL import Image
from dataloader_depth_concat import TrainValPre
from torch.utils import data
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.functional as F
import torch.nn as nn
import torchvision.utils
import torch
import numpy as np
from datetime import datetime as dt
from tqdm import tqdm
import argparse
import math
import time
import uuid
import os
import os.path as osp
import sys
sys.path.append('../../../')
sys.path.append('../')

experiment_name = str(config.nepochs) + 'E_SS' + str(config.labeled_ratio) + '_L' + \
    str(config.lr) + '_ConcatD' + '_CPUDist_' + str(config.image_height) + 'size'


if os.getenv('debug') is not None:
    is_debug = os.environ['debug']
else:
    is_debug = False


def set_random_seed(seed, deterministic=False):
    """Set random seed."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not engine.cpu_only:
        torch.cuda.manual_seed_all(seed)
        if deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False


def compute_metric(results):
    hist = np.zeros((config.num_classes, config.num_classes))
    correct = 0
    labeled = 0
    count = 0
    for d in results:
        hist += d['hist']
        correct += d['correct']
        labeled += d['labeled']
        count += 1

    iu, mean_IU, _, mean_pixel_acc = compute_score(hist, correct,
                                                   labeled)
    # changed from the variable dataset to the class directly so this function
    # can now be called without first initialising the eval file
    print(len(CityScape.get_class_names()))

  


def viz_image(imgs, gts, pred, step, epoch, name, step_test=None):
    pass
         

collate_fn = SegCollate()

if config.depthmix:
    import mask_gen_depth
    mask_generator = mask_gen_depth.BoxMaskGenerator(
        prop_range=config.depthmix_mask_prop_range,
        n_boxes=config.depthmix_boxmask_n_boxes,
        random_aspect_ratio=not config.depthmix_boxmask_fixed_aspect_ratio,
        prop_by_area=not config.depthmix_boxmask_by_size,
        within_bounds=not config.depthmix_boxmask_outside_bounds,
        invert=not config.depthmix_boxmask_no_invert)
    mask_collate_fn = SegCollate(batch_aug_fn=None)
else:
    import mask_gen
    mask_generator = mask_gen.BoxMaskGenerator(
        prop_range=config.cutmix_mask_prop_range,
        n_boxes=config.cutmix_boxmask_n_boxes,
        random_aspect_ratio=not config.cutmix_boxmask_fixed_aspect_ratio,
        prop_by_area=not config.cutmix_boxmask_by_size,
        within_bounds=not config.cutmix_boxmask_outside_bounds,
        invert=not config.cutmix_boxmask_no_invert)
    add_mask_params_to_batch = mask_gen.AddMaskParamsToBatch(
        mask_generator
    )
    mask_collate_fn = SegCollate(batch_aug_fn=add_mask_params_to_batch)


parser = argparse.ArgumentParser()

with Engine(custom_parser=parser) as engine:
    args = parser.parse_args()

    if not engine.cpu_only:
        cudnn.benchmark = True  # Changed to False due to error

    seed = config.seed

    # if engine.distributed:
    #    seed = engine.local_rank

    set_random_seed(seed)

    pin_memory_flag = not engine.cpu_only

    # data loader + unsupervised data loader
    train_loader, train_sampler = get_train_loader(engine, CityScape, train_source=config.train_source,
                                                   unsupervised=False, collate_fn=collate_fn, pin_memory_flag=pin_memory_flag)
    unsupervised_train_loader_0, unsupervised_train_sampler_0 = get_train_loader(
        engine, CityScape, train_source=config.unsup_source, unsupervised=True, collate_fn=mask_collate_fn, pin_memory_flag=pin_memory_flag)
    unsupervised_train_loader_1, unsupervised_train_sampler_1 = get_train_loader(
        engine, CityScape, train_source=config.unsup_source_1, unsupervised=True, collate_fn=collate_fn, pin_memory_flag=pin_memory_flag)

    if engine.local_rank == 0:
        # + '/{}'.format(experiment_name) + '/{}'.format(time.strftime("%b%d_%d-%H-%M", time.localtime()))                 #Tensorboard log dir
        tb_dir = config.tb_dir
        logger = SummaryWriter(
            log_dir=tb_dir +
            '/' +
            experiment_name +
            '_' +
            time.strftime(
                "%b%d_%d-%H-%M",
                time.localtime()),
            comment=experiment_name)
        #engine.link_tb(tb_dir, generate_tb_dir)

    #experiment_name = "Road_Only"
    #run_id = f"{dt.now().strftime('%d-%h_%H-%M')}-nodebs{config.batch_size}-tep{config.nepochs}-lr{config.lr}-wd{config.weight_decay}-{uuid.uuid4()}"
    #name = f"{experiment_name}_{run_id}"
    #wandb.init(project=PROJECT, name=name, tags='Road Only', entity = "alitaha")

    path_best = osp.join(config.snapshot_dir, 'epoch-best_loss.pth')

    # config network and criterion
    # this is used for the min_kept variable in the OHEM cross-entropy loss,
    # basically saying keep at least 5000 valid targets per image (but summing
    # them up since the loss for an entire minibatch is computed at once).
    # The number was changed from 50000 to 5000 due to the reduction in the
    # number of labels, since only road labels are valid now.
    pixel_num = 5000 * config.batch_size // engine.world_size
    criterion = ProbOhemCrossEntropy2d(ignore_label=100, thresh=0.7,
                                       min_kept=pixel_num, use_weight=False)
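    # (hypothetical example: with batch_size=8 and world_size=2,
    #  pixel_num = 5000 * 8 // 2 = 20000 kept pixels in total per iteration)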
    criterion_cps = nn.CrossEntropyLoss(reduction='mean', ignore_index=100)

    if engine.distributed and not engine.cpu_only:
        BatchNorm2d = nn.SyncBatchNorm
    else:
        BatchNorm2d = nn.BatchNorm2d

    # WHERE WILL THE DEPTH VALUES BE APPENDED, RESNET ALREADY PRE-TRAINED WITH

    # change the number of classes to free space only. norm_layer needed to be
    # changed to nn.BatchNorm2d since the BatchNorm2d derived from the furnace
    # package doesn't seem to work; it's only needed for syncing batches
    # across multiple GPUs, so it may be needed later. To change it back to
    # the author's original, change nn.BatchNorm2d to the BatchNorm2d
    # referenced in the import statement above.
    model = Network(config.num_classes, criterion=criterion,
                    pretrained_model=config.pretrained_model,
                    norm_layer=BatchNorm2d)
    init_weight(model.branch1.business_layer, nn.init.kaiming_normal_,
                BatchNorm2d, config.bn_eps, config.bn_momentum,
                mode='fan_in', nonlinearity='relu')
    init_weight(model.branch2.business_layer, nn.init.kaiming_normal_,
                BatchNorm2d, config.bn_eps, config.bn_momentum,
                mode='fan_in', nonlinearity='relu')

    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}

    trainval_pre = TrainValPre(
        config.image_mean,
        config.image_std,
        config.dimage_mean,
        config.dimage_std)
    test_dataset = CityScape(data_setting, 'trainval', trainval_pre)

    test_loader = data.DataLoader(test_dataset,
                                  batch_size=1,
                                  num_workers=config.num_workers,
                                  drop_last=True,
                                  shuffle=False,
                                  pin_memory=pin_memory_flag,
                                  sampler=train_sampler)

    base_lr = config.lr
    if engine.distributed:
        base_lr = config.lr

    params_list_l = []
    # params grouped into two groups, one with weight decay and without weight
    # decay, then returned in a 2 length list and each elements in the list
    # are the parameters with the associated learning rate (this is how
    # PyTorch is designed to handle different learning rates for different
    # module parameters)
    params_list_l = group_weight(
        params_list_l,
        model.branch1.backbone,
        BatchNorm2d,
        base_lr)
    for module in model.branch1.business_layer:
        params_list_l = group_weight(params_list_l, module, BatchNorm2d,
                                     base_lr)        # head lr * 10

    optimizer_l = torch.optim.SGD(params_list_l,
                                  lr=base_lr,
                                  momentum=config.momentum,
                                  weight_decay=config.weight_decay)

    params_list_r = []
    params_list_r = group_weight(params_list_r, model.branch2.backbone,
                                 BatchNorm2d, base_lr)
    for module in model.branch2.business_layer:
        params_list_r = group_weight(params_list_r, module, BatchNorm2d,
                                     base_lr)        # head lr * 10

    optimizer_r = torch.optim.SGD(params_list_r,
                                  lr=base_lr,
                                  momentum=config.momentum,
                                  weight_decay=config.weight_decay)

    # config lr policy
    total_iteration = config.nepochs * config.niters_per_epoch
    lr_policy = WarmUpPolyLR(
        base_lr,
        config.lr_power,
        total_iteration,
        config.niters_per_epoch *
        config.warm_up_epoch)

    if engine.distributed:
        print('distributed !!')
        if not engine.cpu_only:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device("cpu")

        model = DDP(model).to(device)  # , device_ids=[rank])
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # model = DataParallelModel(model, device_ids=engine.devices)
        # #should I comment out since only one GPU?
        model.to(device)

    engine.register_state(
        dataloader=train_loader,
        model=model,
        optimizer_l=optimizer_l,
        optimizer_r=optimizer_r,
        loss=0)

    if engine.continue_state_object:
        engine.restore_checkpoint()     # it will change the state dict of optimizer also

    print("Number of Params", count_params(model))

    def val_step(engine, batch):
        model.eval()
        loss_sup_test = 0
        with torch.no_grad():  # Testing
            imgs_test = batch['data'].to(device)
            gts_test = batch['label'].to(device)
            pred_test = model.branch1(imgs_test)
            loss_sup_test = loss_sup_test + criterion(pred_test, gts_test)
        return pred_test, gts_test

    """
    evaluator = i_engine.Engine(val_step)
    cm = ConfusionMatrix(num_classes=2)

    val_metrics = {
    "pixel accuracy": Accuracy(),
    "average precision": Precision(average=True),
    "average recall": Recall(average=True),
    "IoU": IoU(cm),
    "average iou": mIoU(cm),
    "F1 Score": DiceCoefficient(cm), #ignore index needs to be changed to 2 and set for all iou losses and F1
    "Loss": Loss(criterion)
    }

    for name, metric in val_metrics.items():
        metric.attach(evaluator, name)


    def log_val_results(evaluator):
        state = evaluator.run(test_loader)
        return state.metrics

    """

    #model = load_model(model, '/media/taha_a/T7/Datasets/cityscapes/outputs/city/snapshot/snapshot/epoch-18.pth')

    is_debug = False
    step = 0
    iu_last = 0
    mean_IU_last = 0
    mean_pixel_acc_last = 0
    loss_sup_test_last = 0
    model.train()
    print('begin train')

    for epoch in range(engine.state.epoch, config.nepochs):
        if engine.distributed:
            train_sampler.set_epoch(epoch)
            unsupervised_train_sampler_0.set_epoch(epoch)
            unsupervised_train_sampler_1.set_epoch(epoch)
        bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'

        # generate unsupervsied crops
        if config.depthmix:
            uns_crops = []
            for _ in range(config.max_samples):
                uns_crops.append(
                    generate_random_uns_crop_pos(
                        config.img_shape_h,
                        config.img_shape_w,
                        config.image_height,
                        config.image_width))
            unsupervised_train_loader_0.dataset.uns_crops = uns_crops
            unsupervised_train_loader_1.dataset.uns_crops = uns_crops

        if is_debug:
            # the tqdm function will invoke file.write(whatever) to write
            # progress of the bar and here using sys.stdout.write will write to
            # the command window
            pbar = tqdm(range(10), file=sys.stdout, bar_format=bar_format)
        else:
            pbar = tqdm(
                range(
                    config.niters_per_epoch),
                file=sys.stdout,
                bar_format=bar_format)

        #wandb.log({"Epoch": epoch}, step=step)
        if engine.local_rank == 0:
            logger.add_scalar('Epoch', epoch, step)

        # The dataloaders are wrapped in iterators and advanced manually with
        # next() inside each training loop, because the traditional
        # 'for batch in dataloader' pattern can't be used together with the
        # tqdm progress bar set up above.
        dataloader = iter(train_loader)
        unsupervised_dataloader_0 = iter(unsupervised_train_loader_0)
        unsupervised_dataloader_1 = iter(unsupervised_train_loader_1)

        torch.autograd.set_detect_anomaly(True)

        sum_loss_sup = 0
        sum_loss_sup_r = 0
        sum_cps = 0

        ''' supervised part '''
        for idx in pbar:

            optimizer_l.zero_grad()
            optimizer_r.zero_grad()
            engine.update_iteration(epoch, idx)
            start_time = time.time()

            minibatch = next(dataloader)
            unsup_minibatch_0 = next(unsupervised_dataloader_0)
            unsup_minibatch_1 = next(unsupervised_dataloader_1)

            imgs = minibatch['data']
            gts = minibatch['label']
            unsup_imgs_0 = unsup_minibatch_0['data']
            unsup_imgs_1 = unsup_minibatch_1['data']

            if config.depthmix:
                mask_params = mask_generator.generate_depth_masks(
                    imgs.shape[0], imgs.shape[2:4], imgs[:, 3:, :, :], unsup_imgs_0, unsup_imgs_1)
                mask_params = torch.from_numpy(mask_params).long()

            else:
                mask_params = unsup_minibatch_0['mask_params']

            imgs = imgs.to(device)  # .cuda(non_blocking=True)
            gts = gts.to(device)  # .cuda(non_blocking=True)
            unsup_imgs_0 = unsup_imgs_0.to(device)  # .cuda(non_blocking=True)
            unsup_imgs_1 = unsup_imgs_1.to(device)  # .cuda(non_blocking=True)
            mask_params = mask_params.to(device)  # .cuda(non_blocking=True)

            #if step==0: logger.add_graph(model.branch1, imgs[:1,:,:,:])

            # unsupervised loss on model/branch#1
            # batch_mix_masks is a mask the size of the image, with values of
            # either 1 or 0
            batch_mix_masks = mask_params
            # augment the unsupervised 0 images into the 1 images - this is
            # why we have two unsup loaders, so that we get different random
            # images to augment in every iteration
            unsup_imgs_mixed = unsup_imgs_0 * \
                (1 - batch_mix_masks) + unsup_imgs_1 * batch_mix_masks

            '''

            viz_mix = (unsup_imgs_mixed[0,:3,:,:].squeeze().numpy()*np.expand_dims(np.expand_dims(config.image_std,axis=1),axis=2)+np.expand_dims(np.expand_dims(config.image_mean,axis=1),axis=2))
            viz_mix = viz_mix.transpose(1,2,0)

            fig = plt.figure(figsize=(20, 14))
            fig.add_subplot(1, 3, 1)
            plt.imshow((unsup_imgs_0[0,:3,...].squeeze().numpy()*np.expand_dims(np.expand_dims(config.image_std,axis=1),axis=2)+np.expand_dims(np.expand_dims(config.image_mean,axis=1),axis=2)).transpose(1,2,0))
            fig.add_subplot(1, 3, 2)
            plt.imshow((unsup_imgs_1[0,:3,...].squeeze().numpy()*np.expand_dims(np.expand_dims(config.image_std,axis=1),axis=2)+np.expand_dims(np.expand_dims(config.image_mean,axis=1),axis=2)).transpose(1,2,0))
            fig.add_subplot(1, 3, 3)
            plt.imshow(viz_mix)
            #fig.add_subplot(2, 2, 4)
            #plt.imshow(viz_mix)
            plt.show()

            '''

            with torch.no_grad():
                # Estimate the pseudo-label with branch#1 & supervise branch#2
                # step defines which branch we use
                _, logits_u0_tea_1 = model(unsup_imgs_0, step=1)
                _, logits_u1_tea_1 = model(unsup_imgs_1, step=1)
                logits_u0_tea_1 = logits_u0_tea_1.detach()
                logits_u1_tea_1 = logits_u1_tea_1.detach()
                # Estimate the pseudo-label with branch#2 & supervise branch#1
                _, logits_u0_tea_2 = model(unsup_imgs_0, step=2)
                _, logits_u1_tea_2 = model(unsup_imgs_1, step=2)
                logits_u0_tea_2 = logits_u0_tea_2.detach()
                logits_u1_tea_2 = logits_u1_tea_2.detach()

            # Mix teacher predictions using same mask
            # It makes no difference whether we do this with logits or probabilities as
            # the mask pixels are either 1 or 0
            logits_cons_tea_1 = logits_u0_tea_1 * \
                (1 - batch_mix_masks) + logits_u1_tea_1 * batch_mix_masks
            # getting the pseudo label since it will be the max value
            # probability
            _, ps_label_1 = torch.max(logits_cons_tea_1, dim=1)
            ps_label_1 = ps_label_1.long()
            logits_cons_tea_2 = logits_u0_tea_2 * \
                (1 - batch_mix_masks) + logits_u1_tea_2 * batch_mix_masks
            _, ps_label_2 = torch.max(logits_cons_tea_2, dim=1)
            ps_label_2 = ps_label_2.long()

            unsup_imgs_mixed = unsup_imgs_mixed.to(device)
            # Get student#1 prediction for mixed image
            _, logits_cons_stu_1 = model(unsup_imgs_mixed, step=1)
            # Get student#2 prediction for mixed image
            _, logits_cons_stu_2 = model(unsup_imgs_mixed, step=2)

            cps_loss = criterion_cps(
                logits_cons_stu_1, ps_label_2) + criterion_cps(logits_cons_stu_2, ps_label_1)
            dist.all_reduce(cps_loss, dist.ReduceOp.SUM)
            cps_loss = cps_loss / engine.world_size
            cps_loss = cps_loss * config.cps_weight

            # supervised loss on both models
            _, sup_pred_l = model(imgs, step=1)
            _, sup_pred_r = model(imgs, step=2)

            loss_sup = criterion(sup_pred_l, gts)
            dist.all_reduce(loss_sup, dist.ReduceOp.SUM)
            loss_sup = loss_sup / engine.world_size

            loss_sup_r = criterion(sup_pred_r, gts)
            dist.all_reduce(loss_sup_r, dist.ReduceOp.SUM)
            loss_sup_r = loss_sup_r / engine.world_size
            current_idx = epoch * config.niters_per_epoch + idx
            lr = lr_policy.get_lr(current_idx)

            optimizer_l.param_groups[0]['lr'] = lr
            optimizer_l.param_groups[1]['lr'] = lr
            for i in range(2, len(optimizer_l.param_groups)):
                optimizer_l.param_groups[i]['lr'] = lr
            optimizer_r.param_groups[0]['lr'] = lr
            optimizer_r.param_groups[1]['lr'] = lr
            for i in range(2, len(optimizer_r.param_groups)):
                optimizer_r.param_groups[i]['lr'] = lr

            loss = loss_sup + loss_sup_r + cps_loss
            loss.backward()

            optimizer_l.step()
            optimizer_r.step()
            step = step + 1

            '''
            To obtain label from network (for batch size of 1 - if more than one then dimensions would shift):
            1- Permute the image or prediction first (1,2,0)
            2- torch.argmax(2) - for a batch size of 1 that is the last dimension for label - so c x h x w becomes h x w x c

            gts (label) from data loader is right size ( H x W )
            image from data loader is (C x H x W) -  reason being while image is loaded as H x W x C, it is permuted in the TrainPre Class of the dataloader
            pred from data loader is ( N x C x H x W)
            pred after permute (if extracting one sample only) is ( H x W x C )
            pred after permute and argmax(2) is ( H x W )

            *- Label doesn't need to be permuted from data loader, however either
            '''

            print_str = 'Epoch{}/{}'.format(epoch, config.nepochs) \
                        + ' Iter{}/{}:'.format(idx + 1, config.niters_per_epoch) \
                        + ' lr=%.2e' % lr \
                        + ' loss_sup=%.2f' % loss_sup.item() \
                        + ' loss_sup_r=%.2f' % loss_sup_r.item() \
                        + ' loss_cps=%.4f' % cps_loss.item()

            sum_loss_sup += loss_sup.item()
            sum_loss_sup_r += loss_sup_r.item()
            sum_cps += cps_loss.item()

            if engine.local_rank == 0 and step % 20 == 0:
                #wandb.log({f"Train/Loss_Sup_R": loss_sup_r}, step=step)
                #wandb.log({f"Train/Loss_Sup_L": loss_sup}, step=step)
                #wandb.log({f"Train/Loss_CPS": cps_loss}, step=step)
                #wandb.log({f"Train/Total Loss": loss}, step=step)
                logger.add_scalar('train_loss_sup', loss_sup, step)
                logger.add_scalar('train_loss_sup_r', loss_sup_r, step)
                logger.add_scalar('train_loss_cps', cps_loss, step)

                if step % 100 == 0:
                    viz_image(
                        imgs,
                        gts,
                        sup_pred_l,
                        step,
                        epoch,
                        minibatch['fn'][0],
                        None)
                    #images = wandb.Image(image_array, caption="Top: Output, Bottom: Input")
                    # wandb.log({"examples": images}

            if step % config.validate_every == 0 or (
                    is_debug and step % config.validate_every % 10 == 0):
                all_results = []
                prec_road = []
                prec_non_road = []
                recall_road = []
                recall_non_road = []
                mean_prec = []
                mean_recall = []
                prec_recall_metrics = [
                    prec_road,
                    prec_non_road,
                    recall_road,
                    recall_non_road,
                    mean_prec,
                    mean_recall]
                names = [
                    'prec_road',
                    'prec_non_road',
                    'recall_road',
                    'recall_non_road',
                    'mean_prec',
                    'mean_recall']
                model.eval()
                loss_sup_test = 0
                step_test = 0
                with torch.no_grad():
                    for batch_test in tqdm(
                            test_loader,
                            desc=f"Epoch: {epoch + 1}/{config.nepochs}. Loop: Validation",
                            total=len(test_loader)):

                        step_test = step_test + 1

                        imgs_test = batch_test['data'].to(device)
                        gts_test = batch_test['label'].to(device)
                        pred_test = model.branch1(imgs_test)
                        loss_sup_test = loss_sup_test + \
                            criterion(pred_test, gts_test)
                        pred_test_max = torch.argmax(
                            pred_test[0, :, :, :], dim=0).long().cpu().numpy()  # .permute(1, 2, 0)
                        #pred_test_max = pred_test.argmax(2).cpu().numpy()

                        # limit to only one machine computing, every machine to
                        # add their results to shared lists..rank 0 to compute
                        # and store results
                        hist_tmp, labeled_tmp, correct_tmp = hist_info(
                            config.num_classes, pred_test_max, gts_test[0, :, :].cpu().numpy())
                        p, mean_p, r, mean_r = recall_and_precision(
                            pred_test_max, gts_test[0, :, :].cpu().numpy(), config.num_classes)
                        prec_recall_metrics[0].append(p[0])
                        prec_recall_metrics[1].append(p[1])
                        prec_recall_metrics[2].append(r[0])
                        prec_recall_metrics[3].append(r[1])
                        prec_recall_metrics[4].append(mean_p)
                        prec_recall_metrics[5].append(mean_r)
                        results_dict = {
                            'hist': hist_tmp,
                            'labeled': labeled_tmp,
                            'correct': correct_tmp}
                        all_results.append(results_dict)

                        if epoch + 1 > 20:
                            if step_test % 20 == 0:
                                viz_image(
                                    imgs_test,
                                    gts_test,
                                    pred_test,
                                    step,
                                    epoch,
                                    batch_test['fn'][0],
                                    step_test)
                        elif step_test % 50 == 0:
                            viz_image(
                                imgs_test,
                                gts_test,
                                pred_test,
                                step,
                                epoch,
                                batch_test['fn'][0],
                                step_test)

                # compute the metrics on every rank; the logging below uses
                # iu / mean_IU / mean_pixel_acc, which would otherwise be
                # undefined on non-zero ranks
                iu, mean_IU, _, mean_pixel_acc = compute_metric(all_results)
                loss_sup_test = loss_sup_test / len(test_loader)

                # only rank 0 writes the best-model checkpoint
                if engine.local_rank == 0:
                    if mean_IU > mean_IU_last and loss_sup_test < loss_sup_test_last:
                        if os.path.exists(path_best):
                            os.remove(path_best)
                        engine.save_checkpoint(path_best)

                mean_IU_last = mean_IU
                mean_pixel_acc_last = mean_pixel_acc
                loss_sup_test_last = loss_sup_test
                #metrics = log_val_results(evaluator)
                #pa = metrics["pixel accuracy"]
                #ap = metrics["average precision"]
                #ar = metrics["average recall"]
                #miou = metrics["average iou"]
                #f1 = metrics["F1 Score"]
                #iou = metrics["IoU"]
                #loss = metrics["Loss"]
                #loss_average = loss / len(test_loader)
                #logger.add_scalar('Val/Pixel_Accuracy', pa, step)
                #logger.add_scalar('Val/Average_Precision', ap, step)
                #logger.add_scalar('Val/Average_Recall', ar, step)
                #logger.add_scalar('Val/mIoU', miou, step)
                #logger.add_scalar('Val/F1_Score', (f1[0]+f1[1])/len(f1), step)
                print('Supervised Training Validation Set Loss', loss_sup_test)
                #print(f"Validation Metrics after {step} steps: \nPixel Accuracy {pa}\nAverage Precision {ap}\nAverage Recall {ar}\nmIoU {miou}\nIoU {iou}\nF1 Score {f1}")
                _ = print_iou(iu, mean_pixel_acc,
                              CityScape.get_class_names(), True)
                logger.add_scalar('trainval_loss_sup', loss_sup_test, step)
                logger.add_scalar(
                    'Val/Mean_Pixel_Accuracy',
                    mean_pixel_acc * 100,
                    step)
                logger.add_scalar('Val/Mean_IoU', mean_IU * 100, step)
                logger.add_scalar('Val/IoU_Road', iu[0] * 100, step)
                logger.add_scalar('Val/IoU_NonRoad', iu[1] * 100, step)

                for i, n in enumerate(prec_recall_metrics):
                    prec_recall_metrics[i] = sum(n) / len(n)
                    logger.add_scalar(
                        f'Val/{names[i]}',
                        round(prec_recall_metrics[i] * 100, 2),
                        step)
                f1_score = (2 * prec_recall_metrics[4] * prec_recall_metrics[5]) / (
                    prec_recall_metrics[4] + prec_recall_metrics[5])
                logger.add_scalar('Val/F1 Score', round(f1_score, 2), step)
                # note: add_scalar's third argument is the global step, so this
                # plots mean precision against mean recall as the x-axis
                logger.add_scalar(
                    'Val/Precision vs Recall',
                    round(prec_recall_metrics[4] * 100, 2),
                    round(prec_recall_metrics[5] * 100, 2))

                model.train()

            pbar.set_description(print_str, refresh=False)

            end_time = time.time()

        # if engine.distributed and (engine.local_rank == 0):
        #logger.add_scalar('train_loss_sup', sum_loss_sup / len(pbar), epoch)
        #logger.add_scalar('train_loss_sup_r', sum_loss_sup_r / len(pbar), epoch)
        #logger.add_scalar('train_loss_cps', sum_cps / len(pbar), epoch)

        '''
        if azure and engine.local_rank == 0:
            run.log(name='Supervised Training Loss', value=sum_loss_sup / len(pbar))
            run.log(name='Supervised Training Loss right', value=sum_loss_sup_r / len(pbar))
            run.log(name='Supervised Training Loss CPS', value=sum_cps / len(pbar))
        '''

        # if (epoch % config.snapshot_iter == 0) or (epoch == config.nepochs - 1):
        # (the `(epoch > config.nepochs // 6) and` guard was itself commented out)

        # rank 0 saves in the distributed case; the single-process case always saves
        if (engine.distributed and engine.local_rank == 0) or not engine.distributed:
            engine.save_and_link_checkpoint(config.snapshot_dir,
                                            config.log_dir,
                                            config.log_dir_link, epoch)

Sorry, I might not have been clear enough. In order to help debug your issue, I need to see instructions and be able to copy/paste and execute the code. Right now I have to rip out undefined classes, and I don't know how you are launching the workload or how to reproduce the issue, so I cannot help.

Sorry, yes, that makes sense. I've provided the training script and the engine file which sets up the distributed workers. I've also included the shell script below which I use to launch training. (I'm aware torchrun is the recommended launcher and torch.distributed.launch is going to be deprecated, but torchrun produces an 'Address Already in Use' error which torch.distributed.launch does not. I tried multiple addresses and made sure they were unoccupied, but it still produced the error, so that's why I'm using torch.distributed.launch.)

Shell Script:

nvidia-smi

export NGPUS=1
export learning_rate=0.002
export batch_size=2
export snapshot_iter=2
export epochs=35
export ratio=16
export WORLD_SIZE=$batch_size
export CPU_DIST_ONLY='True'

export volna="/home/extraspace/Datasets/Datasets/cityscapes/city/"
export OUTPUT_PATH="/home/extraspace/Runs/CPS/Semi/1-$ratio/"
# braces needed: "$learning_rate_$epochs" would be parsed as ${learning_rate_}${epochs}
export snapshot_dir="/home/extraspace/Runs/CPS/Semi/1-$ratio/depth_concat/${learning_rate}_${epochs}"


python -m torch.distributed.launch --nnodes=1 --nproc_per_node=$WORLD_SIZE train_depth_concat_cpu.py 
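
One variant I haven't verified on this machine (so treat it as a suggestion only) is letting torchrun's c10d rendezvous pick a free port instead of binding to a fixed master address:

torchrun --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 --nproc_per_node=$WORLD_SIZE train_depth_concat_cpu.py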

Config File (config.py):

# encoding: utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path as osp
import sys
import time
import numpy as np
from easydict import EasyDict as edict
import argparse


class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


C = edict()
config = C
cfg = C

C.seed = 12345

remoteip = os.popen('pwd').read()
if os.getenv('volna') is not None:
    C.volna = os.environ['volna']
else:
    # the path to the data dir (fallback path omitted in this post, so
    # `volna` must be set in the environment; None keeps the file parseable)
    C.volna = None

"""please config ROOT_dir and user when u first using"""
C.repo_name = 'TorchSemiSeg'
C.abs_dir = osp.realpath(".")
C.this_dir = C.abs_dir.split(osp.sep)[-1]


#C.root_dir = C.abs_dir[:C.abs_dir.index(C.repo_name) + len(C.repo_name)]
C.root_dir = '/Documents/ss_fsd/CPS/TorchSemiSeg'
C.log_dir = '/home/extraspace/Logs'
C.tb_dir = C.log_dir  # osp.abspath(osp.join(C.log_dir, "tb"))

C.log_dir_link = osp.join(C.abs_dir, 'log')

# snapshot dir that stores checkpoints
if os.getenv('snapshot_dir'):
    C.snapshot_dir = osp.join(os.environ['snapshot_dir'], "snapshot")
else:
    C.snapshot_dir = osp.abspath(osp.join(C.log_dir, "snapshot"))

exp_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
C.log_file = C.log_dir + '/log_' + exp_time + '.log'
C.link_log_file = C.log_file + '/log_last.log'
C.val_log_file = C.log_dir + '/val_' + exp_time + '.log'
C.link_val_log_file = C.log_dir + '/val_last.log'

"""Data Dir and Weight Dir"""
C.dataset_path = C.volna  # changed so path is the external drive
C.img_root_folder = C.dataset_path
C.gt_root_folder = C.dataset_path
C.pretrained_model = '/Documents/ss_fsd/CPS/TorchSemiSeg/DATA/pytorch-weight/resnet50_v1c.pth'


"""Path Config"""


def add_path(path):
    if path not in sys.path:
        sys.path.insert(0, path)


add_path(osp.join(C.root_dir, 'furnace'))


''' Experiments Setting '''
C.labeled_ratio = int(os.environ['ratio'])
C.train_source = osp.join(
    C.dataset_path, "config_new/subset_train/train_aug_labeled_1-{}.txt".format(C.labeled_ratio))
C.unsup_source = osp.join(
    C.dataset_path, "config_new/subset_train/train_aug_unlabeled_1-{}.txt".format(C.labeled_ratio))
C.unsup_source_1 = osp.join(
    C.dataset_path,
    "config_new/subset_train/train_aug_unlabeled_1-{}_shuffle.txt".format(
        C.labeled_ratio))
C.eval_source = osp.join(C.dataset_path, "config_new/val.txt")
C.test_source = osp.join(C.dataset_path, "config_new/test.txt")
C.demo_source = osp.join(C.dataset_path, "config_new/demo.txt")

C.is_test = False
C.fix_bias = True
C.bn_eps = 1e-5
C.bn_momentum = 0.1

C.cps_weight = 5

"""Cutmix Config"""
C.cutmix_mask_prop_range = (0.25, 0.5)
C.cutmix_boxmask_n_boxes = 1
C.cutmix_boxmask_fixed_aspect_ratio = False
C.cutmix_boxmask_by_size = False
C.cutmix_boxmask_outside_bounds = False
C.cutmix_boxmask_no_invert = False

"""DepthMix Config"""
C.depthmix = False
C.depthmix_mask_prop_range = (0.99, 1.0)
C.depthmix_boxmask_n_boxes = 1
C.depthmix_boxmask_fixed_aspect_ratio = False
C.depthmix_boxmask_by_size = False
C.depthmix_boxmask_outside_bounds = False
C.depthmix_boxmask_no_invert = False

"""Image Config"""
C.num_classes = 2                              # need to change for training free space detection
C.background = 100  # background changed to the class for the padding when cropping
C.image_mean = np.array([0.485, 0.456, 0.406])  # 0.485, 0.456, 0.406
C.image_std = np.array([0.229, 0.224, 0.225])
C.dimage_mean = 22.779  # Update this based on addition of val and test images
C.dimage_std = 19.110
C.image_height = 600
C.image_width = 600
# if ratio is 8, becomes 371 (// returns the int of the division)
C.num_train_imgs = 2975 // C.labeled_ratio
C.num_eval_imgs = 500
C.num_unsup_imgs = 2975 - C.num_train_imgs  # if ratio is 8, becomes 2604
C.crop_pos = None
C.img_shape_h = 1024
C.img_shape_w = 2048

"""Train Config"""
if os.getenv('learning_rate'):
    C.lr = float(os.environ['learning_rate'])
else:
    C.lr = 0.002

if os.getenv('batch_size'):
    C.batch_size = int(os.environ['batch_size'])
else:
    C.batch_size = 4

C.lr_power = 0.9
C.momentum = 0.9
C.weight_decay = 1e-4

# 35 #122.8 epochs to equal number of iterations for supervised baseline.
# original - 137
C.nepochs = int(os.environ['epochs'])
C.max_samples = max(C.num_train_imgs, C.num_unsup_imgs)
C.cold_start = 0
C.niters_per_epoch = C.max_samples // C.batch_size  # 2604 / 2 for me
C.fully_sup_iters = C.num_train_imgs // C.batch_size
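# e.g. with ratio=8 and batch_size=2: num_train_imgs = 2975 // 8 = 371,
# num_unsup_imgs = 2604, and niters_per_epoch = 2604 // 2 = 1302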

C.num_workers = 8
print(
    bcolors.WARNING +
    f'\n\n\n-------NUMBER OF WORKERS SET TO {C.num_workers}!!!!! CHANGE BACK FOR GPU TRAINING IF 0-------\n\n\n' +
    bcolors.ENDC)
# [1, 1.5, 1.75, 2.0]#[0.5, 0.75, 1, 1.5, 1.75, 2.0]
C.train_scale_array = None

"""Eval Config"""
C.eval_iter = 30
C.eval_stride_rate = 2 / 3
C.eval_scale_array = [1, ]  # 0.5, 0.75, 1, 1.5, 1.75
C.eval_flip = False
C.eval_base_size = 800
C.eval_crop_size = [1024, 2048]

"""Display Config"""
if os.getenv('snapshot_iter'):
    C.snapshot_iter = int(os.environ['snapshot_iter'])
else:
    C.snapshot_iter = 2
C.record_info_iter = 20
C.display_iter = 50
C.warm_up_epoch = 0  # experiment with warm up epoch
C.validate_every = 550  # C.max_samples

Dataloader:

import random
import os
import sys

# the parent directories must be on sys.path before the furnace imports below
for n in range(1, 4):
    m = '../'
    sys.path.append(m * n)

from PIL import Image
from furnace.utils.visualize import print_iou, show_img
from furnace.datasets.BaseDataset import BaseDataset
from furnace.utils.img_utils import generate_random_crop_pos, random_crop_pad_to_shape
from config import config as config
from torch.utils import data
import numpy as np
import torch
import cv2


'''
Functions below called by dataloader for transformations like:
- Normalisation ( transforming image from 0-255 into 0-1, changing mean and std_dev )
- Random mirroring of the image
- Random scaling of the image
- Semantic Edge Detector: Gets the GT image, masks all pixels with a value of 255 and converts to 0, then identifies edges of remaining pixels and thickens them.
        final output is thickened edges of ("wanted") semantic objects

'''


def normalize(img, mean, std):
    # pytorch pretrained model need the input range: 0-1
    img = img.astype(np.float32) / 255.0
    img = img - mean
    img = img / std

    return img


def random_mirror(img, gt=None):
    if random.random() >= 0.5:
        img = cv2.flip(img, 1)
        if gt is not None:
            gt = cv2.flip(gt, 1)

    return img, gt


def random_scale(img, gt=None, scales=None):
    scale = random.choice(scales)
    # scale = random.uniform(scales[0], scales[-1])
    sh = int(img.shape[0] * scale)
    sw = int(img.shape[1] * scale)
    img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR)

    if gt is not None:
        gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST)

    return img, gt, scale


def SemanticEdgeDetector(gt):
    # Since only condition is given, a tuple is returned with the indices
    # where the condition is True.
    id255 = np.where(gt == 255)
    no255_gt = np.array(gt)  # converting ground truth image into numpy array
    # all indices where gt has a value of 255 are assigned a new value of 0
    no255_gt[id255] = 0
    # Canny edge detector finds the image edges; 5, 5 are the two hysteresis
    # thresholds
    cgt = cv2.Canny(no255_gt, 5, 5, apertureSize=7)
    # Since both thresholds are the same, no hysteresis actually happens:
    # gradients above 5 are edges and those below are not. apertureSize is the
    # size of the Sobel kernel used for the edge detection; horizontal and
    # vertical gradients are computed independently, then combined with
    # sqrt(G_horiz^2 + G_vert^2)
    edge_radius = 7
    # Creates a 7x7 structuring element to be used for dilation below
    edge_kernel = cv2.getStructuringElement(
        cv2.MORPH_RECT, (edge_radius, edge_radius))
    # After edges identified using the edge detector, they are enlarged using
    # the dilation function
    cgt = cv2.dilate(cgt, edge_kernel)
    # print(cgt.max(), cgt.min())
    # following the dilation, any non-zero pixel is set to 1: we only want a
    # binary edge mask, not varying edge intensities
    cgt[cgt > 0] = 1
    return cgt
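
# Usage sketch: for a uint8 label map `gt`, SemanticEdgeDetector(gt) returns a
# binary H x W mask with 1 on the (dilated) semantic boundaries, 0 elsewhere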


# This class preprocesses the images using the functions above, for TRAINING only
class TrainPre(object):
    def __init__(self, img_mean, img_std):
        self.img_mean = img_mean
        self.img_std = img_std

    def __call__(self, img, gt=None):
        # gt = gt - 1     # label 0 is invalid, this operation transfers label
        # 0 to label 255
        img, gt = random_mirror(img, gt)
        if config.train_scale_array is not None:
            img, gt, scale = random_scale(img, gt, config.train_scale_array)

        # Need to experiment with whether using mean and std dev for the depth
        # images adds value
        img = normalize(img, self.img_mean, self.img_std)

        if gt is not None:
            cgt = SemanticEdgeDetector(gt)
        else:
            cgt = None

        crop_size = (config.image_height, config.image_width)
        crop_pos = generate_random_crop_pos(img.shape[:2], crop_size)

        p_img, _ = random_crop_pad_to_shape(img, crop_pos, crop_size, 0)
        if gt is not None:
            # ignore label for image padding changed to 100, all other labels
            # that are 255 are now used for training
            p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 100)
            # ignore label for image padding changed to 100, all other labels
            # that are 255 are now used for training
            p_cgt, _ = random_crop_pad_to_shape(cgt, crop_pos, crop_size, 100)
        else:
            p_gt = None
            p_cgt = None

        # colour channel moved to first dimension (from H x W x C to C x H x W)
        p_img = p_img.transpose(2, 0, 1)

        extra_dict = {}

        return p_img, p_gt, p_cgt, extra_dict


class ValPre(object):  # Validation runs pre processing
    def __call__(self, img, gt):
        # gt = gt - 1
        extra_dict = {}
        return img, gt, None, extra_dict


class TrainValPre(object):  # Validation while Training runs pre processing
    def __init__(self, img_mean, img_std):
        self.img_mean = img_mean
        self.img_std = img_std

    def __call__(self, img, gt=None):
        img = normalize(img, self.img_mean, self.img_std,)
        img = img.transpose(2, 0, 1)
        # gt = gt - 1
        extra_dict = {}
        return img, gt, None, extra_dict  # None here is placeholder for cgt


'''
get_train_loader creates the training DataLoader - parameters passed in are:
- TrainPre (training images pre-processing)
- data setting dictionary with paths to images and GT
- Dataset object

Returns the training DataLoader object (and train sampler but that's only valid for distributed training)

'''


# Data setting is dictionary with some parameters,
def get_train_loader(
        engine,
        dataset,
        train_source,
        unsupervised=False,
        collate_fn=None,
        fully_supervised=False):
    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std)

    if unsupervised is False:
        if fully_supervised is False:
            train_dataset = dataset(data_setting, "train", train_preprocess,
                                    config.max_samples, unsupervised=False)
        else:
            train_dataset = dataset(data_setting, "train", train_preprocess,
                                    config.num_train_imgs, unsupervised=False)
    else:
        train_dataset = dataset(data_setting, "train", train_preprocess,
                                config.max_samples, unsupervised=True)

    train_sampler = None
    is_shuffle = False
    batch_size = config.batch_size

    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=True,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler,
                                   collate_fn=collate_fn)

    return train_loader, train_sampler
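
# A hedged usage sketch (call sites assumed, since the training script's
# loader-construction lines are not shown above):
# train_loader, train_sampler = get_train_loader(engine, CityScape, config.train_source)
# unsup_loader, unsup_sampler = get_train_loader(engine, CityScape, config.unsup_source, unsupervised=True)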


class CityScape(BaseDataset):

    # trans_labels = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27,
    #                28, 31, 32, 33]

    trans_labels = [7, 35]  # 35 is random label used as non road label

    def __init__(self, setting, split_name, preprocess=None,
                 file_length=None, training=True, unsupervised=False):
        super(CityScape, self).__init__(
            setting, split_name, preprocess, file_length)
        self._split_name = split_name
        self._img_path = setting['img_root']
        self._gt_path = setting['gt_root']
        self._train_source = setting['train_source']
        self._eval_source = setting['eval_source']
        self._file_names = self._get_file_names(split_name)
        self._file_length = file_length  # file_length here is the config.max_samples
        # Train and Val preprocess classes above, passed in by the DataLoader
        # when it calls the dataset
        self.preprocess = preprocess
        self.training = training
        self.unsupervised = unsupervised

    def __getitem__(self, index):
        if self._file_length is not None:
            # this bit is inefficient since the filenames are reconstructed
            # everytime the dataloader is called.
            names = self._construct_new_file_names(self._file_length)[index]
        else:
            names = self._file_names[index]
        # changed to .jpg since the images have a .jpg extension, not .png
        # (was: os.path.join(self._img_path, names[0]))
        img_path = self._img_path + names[0].split('.')[0] + '.jpg'
        # os.path.join(self._gt_path, names[1])
        gt_path = self._gt_path + names[1]
        item_name = names[1].split("/")[-1].split(".")[0]

        if not self.unsupervised:
            # Image opened using cv2.imread
            img, gt = self._fetch_data(img_path, gt_path)
        else:
            img, gt = self._fetch_data(img_path, None)

        img = img[:, :, ::-1]  # reverse the channel axis: cv2 loads BGR, the model expects RGB

        if self.preprocess is not None:
            img, gt, edge_gt, extra_dict = self.preprocess(img, gt)
        if gt is not None:
            # collapse all labels apart from road (0) into class 1, making the
            # task binary road / non-road
            for i in range(1, 19):
                gt[np.where(gt == i)] = 1
            gt[np.where(gt == 255)] = 1

        if self._split_name in [
            'train',
            'trainval',
            'train_aug',
                'trainval_aug']:
            # image converted to torch array
            img = torch.from_numpy(np.ascontiguousarray(img)).float()
            if gt is not None:
                # contiguous: This function returns an array with at least
                # one-dimension (1-d) so it will not preserve 0-d arrays.
                gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
                if self._split_name != 'trainval':
                    # edge_gt is the semantic edge detector function
                    edge_gt = torch.from_numpy(
                        np.ascontiguousarray(edge_gt)).long()

            # converting items in the dictionary to torch tensors (empty dict
            # is returned from the TrainPre function so should not be None)
            if self.preprocess is not None and extra_dict is not None:
                for k, v in extra_dict.items(
                ):  # iterating thru extra_dict key, value pairs
                    extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v))
                    if 'label' in k:
                        extra_dict[k] = extra_dict[k].long()
                    if 'img' in k:
                        extra_dict[k] = extra_dict[k].float()

        # output_dict is initialised with 3 items: the image, its file name and
        # the length of the file list
        output_dict = dict(data=img, fn=str(item_name),
                           n=len(self._file_names))
        if gt is not None:
            # if a label exists, also appends another value to the dictionary
            extra_dict['label'] = gt

        if self.preprocess is not None and extra_dict is not None:
            # appending the extra_dict (gt) to output_dict
            output_dict.update(**extra_dict)

        return output_dict

    def _fetch_data(self, img_path, gt_path=None, dtype=None):
        img = self._open_image(img_path)

        if gt_path is not None:
            gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype)
            return img, gt

        return img, None

    @classmethod
    def get_class_colors(*args):

        # reduced to the two road / non-road classes
        return [[128, 64, 128], [0, 70, 255]]

        '''
        , [244, 35, 232], [70, 70, 70],
                [102, 102, 156], [190, 153, 153], [153, 153, 153],
                [250, 170, 30], [220, 220, 0], [107, 142, 35],
                [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
                [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
                [0, 0, 230], [119, 11, 32]]
        '''

    @classmethod
    def get_class_names(*args):
        return ['road', 'not_road']

        '''
        , 'sidewalk', 'building', 'wall', 'fence', 'pole',
                'traffic light', 'traffic sign',
                'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
                'truck', 'bus', 'train', 'motorcycle', 'bicycle']
        '''

    @classmethod
    def transform_label(cls, pred, name):
        label = np.zeros(pred.shape)
        # np.unique returns one instance of every label present in the image
        ids = np.unique(pred)
        print(ids)  # sanity check that the model only returns values between 0 and n_classes - 1; this function then converts them to the actual dataset labels
        for id in ids:
            # converting the labels from sequential numbers to the actual
            # labels as per the cityscapes dataset spec
            label[np.where(pred == id)] = cls.trans_labels[id]

        new_name = (name.split('.')[0]).split('_')[:-1]
        new_name = '_'.join(new_name) + '.png'

        return label, new_name

Base Dataset

#!/usr/bin/env python3
# encoding: utf-8
# @Time    : 2017/12/16 8:41 PM
# @Author  : yuchangqian
# @Contact : changqian_yu@163.com
# @File    : BaseDataset.py

import os
import time
import cv2
import torch
import numpy as np

import torch.utils.data as data


class BaseDataset(data.Dataset):
    def __init__(self, setting, split_name, preprocess=None,
                 file_length=None):
        super(BaseDataset, self).__init__()
        self._split_name = split_name
        self._img_path = setting['img_root']
        self._gt_path = setting['gt_root']
        self._train_source = setting['train_source']
        self._eval_source = setting['eval_source']
        self._file_names = self._get_file_names(split_name)
        self._file_length = file_length
        self.preprocess = preprocess

    def __len__(self):
        if self._file_length is not None:
            return self._file_length
        return len(self._file_names)

    def __getitem__(self, index):
        if self._file_length is not None:
            # depending on file_length, a new (possibly longer) file list is
            # constructed before indexing
            names = self._construct_new_file_names(self._file_length)[index]
        else:
            names = self._file_names[index]
        img_path = os.path.join(self._img_path, names[0])
        gt_path = os.path.join(self._gt_path, names[1])
        # the actual file name (not path) of the sample, e.g. 'jena_000114_000019_gtFine'
        item_name = names[1].split("/")[-1].split(".")[0]

        img, gt = self._fetch_data(img_path, gt_path)
  
        img = img[:, :, ::-1]  # reverse the channel axis: cv2 loads BGR, the model expects RGB
        if self.preprocess is not None:
            img, gt, extra_dict = self.preprocess(img, gt)

        if self._split_name == 'train':  # `is` compares identity; strings need ==
            img = torch.from_numpy(np.ascontiguousarray(img)).float()
            gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
            if self.preprocess is not None and extra_dict is not None:
                for k, v in extra_dict.items():
                    extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v))
                    if 'label' in k:
                        extra_dict[k] = extra_dict[k].long()
                    if 'img' in k:
                        extra_dict[k] = extra_dict[k].float()

        output_dict = dict(data=img, label=gt, fn=str(item_name),
                           n=len(self._file_names))
        if self.preprocess is not None and extra_dict is not None:
            output_dict.update(**extra_dict)

        return output_dict

    def _fetch_data(self, img_path, gt_path, dtype=None):
        img = self._open_image(img_path)
        gt = self._open_image(gt_path, cv2.IMREAD_GRAYSCALE, dtype=dtype)

        return img, gt

    def _get_file_names(self, split_name, train_extra=False):
        assert split_name in ['train', 'val', 'trainval']
        source = self._train_source
        if split_name == 'val' or split_name == 'trainval':
            source = self._eval_source


        file_names = []
        with open(source) as f:
            files = f.readlines()

        for item in files:
            img_name, gt_name = self._process_item_names(item)
            file_names.append([img_name, gt_name])

        if train_extra:
            file_names2 = []
            source2 = self._train_source.replace('train', 'train_extra')
            with open(source2) as f:
                files2 = f.readlines()

            for item in files2:
                img_name, gt_name = self._process_item_names(item)
                file_names2.append([img_name, gt_name])

            return file_names, file_names2

        return file_names

    def _construct_new_file_names(self, length):
        # length is config.max_samples, i.e. whichever of the labelled /
        # unlabelled lists is longer
        assert isinstance(length, int)
        files_len = len(self._file_names)       # length of the original file list

        # use only a small portion of the data
        if length < files_len:
            return self._file_names[:length]

        # repeat the list as many whole times as it fits into `length`
        new_file_names = self._file_names * (length // files_len)

        # random order of indices from 0 to files_len - 1
        rand_indices = torch.randperm(files_len).tolist()
        # the remaining `length % files_len` samples are drawn at random
        new_indices = rand_indices[:length % files_len]

        new_file_names += [self._file_names[i] for i in new_indices]

        return new_file_names

        # This is needed because every batch feeds in an equal number of
        # labelled and unlabelled samples, but the labelled ratio means the two
        # loaders have different lengths. The smaller dataset is artificially
        # enlarged to match: e.g. with 200 supervised and 700 unsupervised
        # samples, length is 700, so the supervised list is repeated
        # 700 // 200 = 3 times (600 names), and the outstanding 700 % 200 = 100
        # names are drawn at random from the list.

    @staticmethod
    def _process_item_names(item):
        item = item.strip()
        item = item.split('\t')
        img_name = item[0]

        if len(item) == 1:
            gt_name = None
        else:
            gt_name = item[1]

        return img_name, gt_name
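
        # e.g. (sketch): the line "leftImg8bit/train/a.png\tgtFine/train/a.png\n"
        # yields ('leftImg8bit/train/a.png', 'gtFine/train/a.png'); a line with
        # no tab yields (img_name, None), as used for unsupervised file lists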

    def get_length(self):
        return self.__len__()

    @staticmethod
    def _open_image(filepath, mode=cv2.IMREAD_COLOR, dtype=None):
        # cv2: B G R
        # h w c
        img = np.array(cv2.imread(filepath, mode), dtype=dtype)

        return img

    @classmethod
    def get_class_colors(*args):
        raise NotImplementedError

    @classmethod
    def get_class_names(*args):
        raise NotImplementedError


if __name__ == "__main__":
    data_setting = {'img_root': '',
                    'gt_root': '',
                    'train_source': '',
                    'eval_source': ''}
    bd = BaseDataset(data_setting, 'train', None)
    print(bd.get_class_names())

img_utils:

import cv2
import numpy as np
import numbers
import random
import collections.abc  # collections.Iterable was removed in Python 3.10+


def get_2dshape(shape, *, zero=True):
    if not isinstance(shape, collections.abc.Iterable):
        shape = int(shape)
        shape = (shape, shape)
    else:
        h, w = map(int, shape)
        shape = (h, w)
    if zero:
        minv = 0
    else:
        minv = 1

    assert min(shape) >= minv, 'invalid shape: {}'.format(shape)
    return shape


def random_crop_pad_to_shape(img, crop_pos, crop_size, pad_label_value):
    h, w = img.shape[:2]  # get size of first two dimensions
    start_crop_h, start_crop_w = crop_pos
    assert ((start_crop_h < h) and (start_crop_h >= 0))
    assert ((start_crop_w < w) and (start_crop_w >= 0))

    crop_size = get_2dshape(crop_size)
    crop_h, crop_w = crop_size

    img_crop = img[start_crop_h:start_crop_h + crop_h,
                   start_crop_w:start_crop_w + crop_w, ...]

    img_, margin = pad_image_to_shape(img_crop, crop_size, cv2.BORDER_CONSTANT,
                                      pad_label_value)

    return img_, margin


def generate_random_uns_crop_pos(h, w, crop_h, crop_w):

    pos_h, pos_w = 0, 0
    if h > crop_h:
        pos_h = random.randint(0, h - crop_h + 1)
    if w > crop_w:
        pos_w = random.randint(0, w - crop_w + 1)

    return pos_h, pos_w


def generate_random_crop_pos(
        ori_size,
        crop_size,
        unsupervised=False,
        index=None,
        uns_crops=None):

    ori_size = get_2dshape(ori_size)
    h, w = ori_size  # get original image size

    crop_size = get_2dshape(crop_size)  # get crop size
    crop_h, crop_w = crop_size

    pos_h, pos_w = 0, 0

    if h > crop_h:
        pos_h = random.randint(0, h - crop_h + 1)

    if w > crop_w:
        pos_w = random.randint(0, w - crop_w + 1)

    return pos_h, pos_w


def pad_image_to_shape(img, shape, border_mode, value):
    margin = np.zeros(4, np.uint32)
    shape = get_2dshape(shape)
    pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0
    pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0

    margin[0] = pad_height // 2
    margin[1] = pad_height // 2 + pad_height % 2
    margin[2] = pad_width // 2
    margin[3] = pad_width // 2 + pad_width % 2

    img = cv2.copyMakeBorder(img, margin[0], margin[1], margin[2], margin[3],
                             border_mode, value=value)

    return img, margin
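
# Worked example (my own numbers): a 512x512 crop padded to shape (600, 600)
# gives pad_height = pad_width = 88, so margin = [44, 44, 44, 44]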


def pad_image_size_to_multiples_of(img, multiple, pad_value):
    h, w = img.shape[:2]
    d = multiple

    def canonicalize(s):
        v = s // d
        return (v + (v * d != s)) * d

    th, tw = map(canonicalize, (h, w))

    return pad_image_to_shape(img, (th, tw), cv2.BORDER_CONSTANT, pad_value)


def resize_ensure_shortest_edge(img, edge_length,
                                interpolation_mode=cv2.INTER_LINEAR):
    assert isinstance(edge_length, int) and edge_length > 0, edge_length
    h, w = img.shape[:2]
    if h < w:
        ratio = float(edge_length) / h
        th, tw = edge_length, max(1, int(ratio * w))
    else:
        ratio = float(edge_length) / w
        th, tw = max(1, int(ratio * h)), edge_length
    # interpolation must be passed by keyword; the third positional argument
    # of cv2.resize is dst
    img = cv2.resize(img, (tw, th), interpolation=interpolation_mode)

    return img


def random_scale(img, gt, scales):
    scale = random.choice(scales)
    sh = int(img.shape[0] * scale)
    sw = int(img.shape[1] * scale)
    img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR)
    gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST)

    return img, gt, scale


def random_scale_with_length(img, gt, length):
    size = random.choice(length)
    sh = size
    sw = size
    img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR)
    gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_NEAREST)

    return img, gt, size


def random_mirror(img, gt):
    if random.random() >= 0.5:
        img = cv2.flip(img, 1)
        gt = cv2.flip(gt, 1)

    return img, gt,


def random_rotation(img, gt):
    angle = random.random() * 20 - 10
    h, w = img.shape[:2]
    rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
    img = cv2.warpAffine(img, rotation_matrix, (w, h), flags=cv2.INTER_LINEAR)
    gt = cv2.warpAffine(gt, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST)

    return img, gt


def random_gaussian_blur(img):
    gauss_size = random.choice([1, 3, 5, 7])
    if gauss_size > 1:
        # do the gaussian blur
        img = cv2.GaussianBlur(img, (gauss_size, gauss_size), 0)

    return img


def center_crop(img, shape):
    h, w = shape[0], shape[1]
    y = (img.shape[0] - h) // 2
    x = (img.shape[1] - w) // 2
    return img[y:y + h, x:x + w]


def random_crop(img, gt, size):
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        size = size

    h, w = img.shape[:2]
    crop_h, crop_w = size[0], size[1]

    if h > crop_h:
        x = random.randint(0, h - crop_h + 1)
        img = img[x:x + crop_h, :, :]
        gt = gt[x:x + crop_h, :]

    if w > crop_w:
        x = random.randint(0, w - crop_w + 1)
        img = img[:, x:x + crop_w, :]
        gt = gt[:, x:x + crop_w]

    return img, gt


def normalize(img, mean, std):
    # pytorch pretrained model need the input range: 0-1
    img = img.astype(np.float32) / 255.0
    img = img - mean
    img = img / std

    return img

metric.py

# encoding: utf-8

import numpy as np

np.seterr(divide='ignore', invalid='ignore')


# voc cityscapes metric
def hist_info(n_cl, pred, gt):
    assert (pred.shape == gt.shape)
    # boolean mask, same shape as gt, that is True where the label is valid;
    # this excludes the ignore index from the labelled count below
    k = (gt >= 0) & (gt < n_cl)
    # k has the same shape as the image, e.g. 1024 x 2048 for whole-image eval
    labeled = np.sum(k)
    correct = np.sum((pred[k] == gt[k]))
    # n_cl * gt + pred gives every (gt, pred) pair a unique index; bincount
    # counts the occurrences of each index and the reshape turns the counts
    # into an n_cl x n_cl confusion matrix (rows = gt, cols = pred)
    return np.bincount(n_cl * gt[k].astype(int) + pred[k].astype(int),
                       minlength=n_cl ** 2).reshape(n_cl, n_cl), labeled, correct

    # To see why this works: multiplying gt by n_cl scales the ground-truth
    # labels before the predictions are added, so each (gt, pred) combination
    # maps to exactly one index. With two classes, each value in {0, 1, 2, 3}
    # can only arise from one combination; a 3, for instance, only arises when
    # gt and pred are both 1. bincount counts each index, and reshape fills the
    # matrix row by row, so index 3 (gt=1, pred=1) lands at the end of the
    # diagonal, exactly where a confusion matrix puts it.
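
    # Full numeric trace (my own toy example, n_cl = 2):
    #   gt   = [0, 1, 1, 0], pred = [0, 1, 0, 0]
    #   2 * gt + pred = [0, 3, 2, 0]  ->  bincount = [2, 0, 1, 1]
    #   reshape(2, 2) -> [[2, 0], [1, 1]], labeled = 4, correct = 3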


def compute_score(hist, correct, labeled):
    # np.diag extracts the diagonal elements; for a confusion matrix those are
    # the correctly labelled pixels per class
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    # rows of hist are the ground truth, columns are the predictions;
    # hist.sum(1) and hist.sum(0) give per-class totals over gt and pred, and
    # subtracting the diagonal removes the true positives that would otherwise
    # be counted twice in the union
    mean_IU = np.nanmean(iu)
    # mean IoU excluding the first class (road in this case)
    mean_IU_no_back = np.nanmean(iu[1:])
    freq = hist.sum(1) / hist.sum()
    freq_IU = (iu[freq > 0] * freq[freq > 0]).sum()
    mean_pixel_acc = correct / labeled

    return iu, mean_IU, mean_IU_no_back, mean_pixel_acc
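
# Continuing the toy confusion matrix above, hist = [[2, 0], [1, 1]]:
#   iu = diag / (row_sum + col_sum - diag) = [2/3, 1/2], so mean_IU = 7/12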


def recall_and_precision(pred, gt, n_cl):
    assert (pred.shape == gt.shape)
    # boolean mask of valid (non-ignore) pixels, same shape as gt
    k = (gt >= 0) & (gt < n_cl)
    precision = [0] * n_cl
    recall = [0] * n_cl
    mean_prec = 0
    mean_recall = 0
    for i in range(n_cl):
        tp = np.sum((pred[k] == i) & (gt[k] == i))
        fp = np.sum((pred[k] == i) & (gt[k] != i))
        #tn = np.sum((pred[k] != i) & (gt[k] != i))
        fn = np.sum((pred[k] != i) & (gt[k] == i))
        # to avoid NaNs if class doesn't exist in picture or not classified
        recall[i] = tp / ((tp + fn) + 1e-10)
        precision[i] = tp / ((tp + fp) + 1e-10)
    mean_prec = sum(precision) / len(precision)
    mean_recall = sum(recall) / len(recall)
    return precision, mean_prec, recall, mean_recall
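
# Same toy example: class 0 has tp=2, fp=1, fn=0 -> precision 2/3, recall 1.0;
# class 1 has tp=1, fp=0, fn=1 -> precision 1.0, recall 0.5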

# ade metric


def meanIoU(area_intersection, area_union):
    iou = 1.0 * np.sum(area_intersection, axis=1) / np.sum(area_union, axis=1)
    meaniou = np.nanmean(iou)
    meaniou_no_back = np.nanmean(iou[1:])

    return iou, meaniou, meaniou_no_back


def intersectionAndUnion(imPred, imLab, numClass):
    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    imPred = imPred * (imLab >= 0)

    # Compute area intersection:
    intersection = imPred * (imPred == imLab)
    (area_intersection, _) = np.histogram(intersection, bins=numClass,
                                          range=(1, numClass))

    # Compute area union:
    (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
    (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
    area_union = area_pred + area_lab - area_intersection

    return area_intersection, area_union


def mean_pixel_accuracy(pixel_correct, pixel_labeled):
    mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (
        np.spacing(1) + np.sum(pixel_labeled))

    return mean_pixel_accuracy


def pixelAccuracy(imPred, imLab):
    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    pixel_labeled = np.sum(imLab >= 0)
    pixel_correct = np.sum((imPred == imLab) * (imLab >= 0))
    pixel_accuracy = 1.0 * pixel_correct / pixel_labeled

    return pixel_accuracy, pixel_correct, pixel_labeled

init_func.py

#!/usr/bin/env python3
# encoding: utf-8
# @Time    : 2018/9/28 12:13 PM
# @Author  : yuchangqian
# @Contact : changqian_yu@163.com
# @File    : init_func.py
import math
import warnings  # needed by _no_grad_trunc_normal_ below

import torch
import torch.nn as nn
from seg_opr.conv_2_5d import Conv2_5D_depth, Conv2_5D_disp


def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                  **kwargs):
    for name, m in feature.named_modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            conv_init(m.weight, **kwargs)
        elif isinstance(m, Conv2_5D_depth):
            conv_init(m.weight_0, **kwargs)
            conv_init(m.weight_1, **kwargs)
            conv_init(m.weight_2, **kwargs)
        elif isinstance(m, Conv2_5D_disp):
            conv_init(m.weight_0, **kwargs)
            conv_init(m.weight_1, **kwargs)
            conv_init(m.weight_2, **kwargs)
        elif isinstance(m, norm_layer):
            m.eps = bn_eps
            m.momentum = bn_momentum
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                **kwargs):
    if isinstance(module_list, list):
        for feature in module_list:
            __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                          **kwargs)
    else:
        __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                      **kwargs)


def group_weight(weight_group, module, norm_layer, lr):
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, Conv2_5D_depth):
            group_decay.append(m.weight_0)
            group_decay.append(m.weight_1)
            group_decay.append(m.weight_2)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, Conv2_5D_disp):
            group_decay.append(m.weight_0)
            group_decay.append(m.weight_1)
            group_decay.append(m.weight_2)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) \
                or isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.Parameter):
            group_decay.append(m)
        elif isinstance(m, nn.Embedding):
            group_decay.append(m)
        # else:
        #     print(m, norm_layer)
    # print(module.modules)
    # print( len(list(module.parameters())) , 'HHHHHHHHHHHHHHHHH',  len(group_decay) + len(
    #    group_no_decay))
    # making sure the returned list of weight decay and non weight decayed
    # parameters matches the number of all trainable parameters in the model
    assert len(list(module.parameters())) == len(
        group_decay) + len(group_no_decay)
    weight_group.append(dict(params=group_decay, lr=lr))
    weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
    return weight_group
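
# A hedged usage sketch (model/optimizer names assumed, not taken verbatim from
# the training script):
# params_list = group_weight([], model.branch1, nn.BatchNorm2d, config.lr)
# optimizer_l = torch.optim.SGD(params_list, lr=config.lr,
#                               momentum=config.momentum,
#                               weight_decay=config.weight_decay)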


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on
    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.", stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

Last one needed, I think, but please let me know if there are still missing class definitions:

loss_opr.py

import numpy as np
import scipy.ndimage as nd

import torch
import torch.nn as nn
import torch.nn.functional as F

from engine.logger import get_logger

logger = get_logger()


class FocalLoss2d(nn.Module):
    def __init__(
            self,
            gamma=0,
            weight=None,
            reduction='mean',
            ignore_index=255):
        super(FocalLoss2d, self).__init__()
        self.gamma = gamma
        if weight:
            self.loss = nn.NLLLoss(
                weight=torch.from_numpy(
                    np.array(weight)).float(),
                reduction=reduction,
                ignore_index=ignore_index)
        else:
            self.loss = nn.NLLLoss(
                reduction=reduction,
                ignore_index=ignore_index)

    def forward(self, input, target):
        # note: the focusing exponent here is hard-coded to 2; self.gamma is
        # stored but never used
        return self.loss((1 - F.softmax(input, 1))**2 *
                         F.log_softmax(input, 1), target)


class FocalMSE(nn.Module):
    def __init__(self, gamma=2):
        super(FocalMSE, self).__init__()
        self.gamma = gamma

        self.loss = nn.MSELoss(reduction='none')

    def forward(self, pred, target):
        loss_no_reduction = self.loss(pred, target)
        weight = (1 - pred)**self.gamma
        weighted_loss = torch.mean(loss_no_reduction * weight)
        return weighted_loss


class RCELoss(nn.Module):
    def __init__(
            self,
            ignore_index=255,
            reduction='mean',
            weight=None,
            class_num=37,
            beta=0.01):
        super(RCELoss, self).__init__()
        self.beta = beta
        self.class_num = class_num
        self.ignore_label = ignore_index
        self.reduction = reduction
        self.criterion = nn.NLLLoss(
            reduction=reduction,
            ignore_index=ignore_index,
            weight=weight)
        self.criterion2 = nn.NLLLoss(
            reduction='none',
            ignore_index=ignore_index,
            weight=weight)

    def forward(self, pred, target):
        b, c, h, w = pred.shape
        max_pred, max_id = torch.max(pred, dim=1)		# pred (b, h, w)
        target_flat = target.view(b, 1, h, w)
        mask = (target_flat.ne(self.ignore_label)).float()
        target_flat = (mask * target_flat.float()).long()
        # convert to onehot
        label_pred = torch.zeros(
            b, self.class_num, h, w).cuda().scatter_(
            1, target_flat, 1)
        # print(label_pred.shape, max_id.shape)

        prob = torch.exp(pred)
        prob = F.softmax(prob, dim=1)      # i add this

        weighted_pred = F.log_softmax(pred, dim=1)
        loss1 = self.criterion(weighted_pred, target)

        label_pred = torch.clamp(label_pred, min=1e-9, max=1.0 - 1e-9)

        label_pred = torch.log(label_pred)
        loss2 = self.criterion2(label_pred, max_id)
        loss2 = torch.mean(loss2 * mask.squeeze(1))   # squeeze mask (b, 1, h, w) to match loss2 (b, h, w)
        # print(loss1, loss2)
        loss = loss1 + self.beta * loss2
        # print(loss1, loss2)
        # print(loss)
        return loss
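
# Hypothetical CPU smoke test (illustration only): with the one-hot tensor
# allocated on pred's device, RCELoss runs under the gloo/CPU setup.
# class_num must match the channel count (19 assumed, as for Cityscapes).
def _smoke_test_rce_cpu():
    torch.manual_seed(0)
    rce = RCELoss(class_num=19, ignore_index=255)
    pred = torch.randn(2, 19, 8, 8, requires_grad=True)
    target = torch.randint(0, 19, (2, 8, 8))
    loss = rce(pred, target)
    loss.backward()     # gradient flows through the NLL term (loss1)
    return loss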


class BalanceLoss(nn.Module):
    def __init__(self, ignore_index=255, reduction='mean', weight=None):
        super(BalanceLoss, self).__init__()
        self.ignore_label = ignore_index
        self.reduction = reduction
        self.criterion = nn.NLLLoss(
            reduction=reduction,
            ignore_index=ignore_index,
            weight=weight)

    def forward(self, pred, target):
        # prob = torch.exp(pred)
        # # prob = F.softmax(prob, dim=1)      # i add this
        # weighted_pred = pred * (1 - prob) ** 2
        # loss = self.criterion(weighted_pred, target)

        prob = torch.exp(pred)
        prob = F.softmax(prob, dim=1)      # I added this
        weighted_pred = F.log_softmax(pred, dim=1) * (1 - prob) ** 2
        loss = self.criterion(weighted_pred, target)
        return loss
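
# Hypothetical toy run (illustration only): BalanceLoss takes raw logits and
# re-weights the log-probabilities by (1 - prob)**2 before the NLL term.
def _demo_balance_loss():
    torch.manual_seed(0)
    bal = BalanceLoss(ignore_index=255)
    pred = torch.randn(2, 19, 8, 8)
    target = torch.randint(0, 19, (2, 8, 8))
    return bal(pred, target)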


class berHuLoss(nn.Module):
    def __init__(self, delta=0.2, ignore_index=0, reduction='mean'):
        super(berHuLoss, self).__init__()
        self.delta = delta
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, pred, target):
        # ~ inverts the bool mask (subtracting a bool tensor from 1 raises an
        # error in recent PyTorch versions)
        valid_mask = (~target.eq(self.ignore_index)).float()
        valid_delta = torch.abs(pred - target) * valid_mask
        max_delta = torch.max(valid_delta)
        delta = self.delta * max_delta

        f_mask = (~torch.gt(target, delta)).float() * valid_mask
        s_mask = (1 - f_mask) * valid_mask
        f_delta = valid_delta * f_mask
        s_delta = ((valid_delta ** 2) + delta ** 2) / (2 * delta) * s_mask

        loss = torch.mean(f_delta + s_delta)
        return loss
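
# For comparison (illustration only): a minimal sketch of the textbook berHu
# (reverse Huber) penalty, which thresholds the residual |pred - target|,
# whereas f_mask above compares target itself against delta.
def _berhu_reference(pred, target, c):
    # L(x) = |x| if |x| <= c, else (x**2 + c**2) / (2 * c)
    x = (pred - target).abs()
    return torch.where(x <= c, x, (x ** 2 + c ** 2) / (2 * c)).mean()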


class SigmoidFocalLoss(nn.Module):
    def __init__(self, ignore_label, gamma=2.0, alpha=0.25,
                 reduction='mean'):
        super(SigmoidFocalLoss, self).__init__()
        self.ignore_label = ignore_label
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, pred, target):
        b, h, w = target.size()
        pred = pred.view(b, -1, 1)
        pred_sigmoid = pred.sigmoid()
        target = target.view(b, -1).float()
        mask = (target.ne(self.ignore_label)).float()
        target = mask * target
        onehot = target.view(b, -1, 1)

        # log-sum-exp trick: max_val + log(exp(-max_val) + exp(-x - max_val))
        # below is a numerically stable log(1 + exp(-x)), x = pred_sigmoid here
        max_val = (-pred_sigmoid).clamp(min=0)

        pos_part = (1 - pred_sigmoid) ** self.gamma * (
            pred_sigmoid - pred_sigmoid * onehot)
        neg_part = pred_sigmoid ** self.gamma * (max_val + (
            (-max_val).exp() + (-pred_sigmoid - max_val).exp()).log())

        loss = -(self.alpha * pos_part + (1 - self.alpha) * neg_part).sum(
            dim=-1) * mask
        if self.reduction == 'mean':
            loss = loss.mean()

        return loss
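
# Hypothetical toy run (illustration only): SigmoidFocalLoss expects one
# logit per pixel, viewed internally as (b, -1, 1), and (b, h, w) targets.
def _demo_sigmoid_focal():
    torch.manual_seed(0)
    sfl = SigmoidFocalLoss(ignore_label=255, gamma=2.0, alpha=0.25)
    pred = torch.randn(2, 8, 8).view(2, -1)    # one logit per pixel
    target = torch.randint(0, 2, (2, 8, 8))
    return sfl(pred, target)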

class ProbOhemCrossEntropy2d(nn.Module):
    def __init__(
            self,
            ignore_label,
            reduction='mean',
            thresh=0.6,
            min_kept=256,
            down_ratio=1,
            use_weight=False):
        super(ProbOhemCrossEntropy2d, self).__init__()
        self.ignore_label = ignore_label
        self.thresh = float(thresh)
        self.min_kept = int(min_kept)
        self.down_ratio = down_ratio
        if use_weight:
            # per-class weights used to balance unbalanced datasets
            # CHANGE: make the weighted sums of the free-space and
            # non-free-space classes equal within the dataset - the
            # Cityscapes paper lists pixel counts per class; could also
            # write a script for this
            weight = torch.FloatTensor(
                [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
                 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
                 1.0865, 1.1529, 1.0507])
            self.criterion = torch.nn.CrossEntropyLoss(
                reduction=reduction, weight=weight, ignore_index=ignore_label)
        else:
            self.criterion = torch.nn.CrossEntropyLoss(
                reduction=reduction, ignore_index=ignore_label)

    def forward(self, pred, target):
        b, c, h, w = pred.size()
        # flattens tensor to one dimension -
        # https://stackoverflow.com/questions/50792316/what-does-1-mean-in-pytorch-view
        target = target.view(-1)
        # .ne is not equals to, https://pytorch.org/docs/stable/generated/torch.ne.html
        valid_mask = target.ne(self.ignore_label)
        # zero out the invalid (ignore-label) targets so the indexing below
        # stays within the class range
        target = target * valid_mask.long()
        num_valid = valid_mask.sum()   # count of valid (non-255) target pixels

        #prob_bins = torch.argmax(pred, dim=1).cpu()
        # print(prob_bins)
        #prob_bins = torch.bincount(prob_bins.view(-1).long())
        prob = F.softmax(pred, dim=1)

        prob = (prob.transpose(0, 1)).reshape(c, -1)

        if self.min_kept > num_valid:
            logger.info('Labels: {}'.format(num_valid))
        elif num_valid > 0:
            # masked_fill_ sets elements where the mask is True to the given
            # value, so with ~valid_mask the invalid positions are set to 1:
            # https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html
            # (note: masked_fill_ is an in-place op on a tensor that requires
            # grad, the kind of operation the anomaly trace above flags)
            prob = prob.masked_fill_(~valid_mask, 1)
            mask_prob = prob[
                target, torch.arange(len(target), dtype=torch.long)]
            threshold = self.thresh
            if self.min_kept > 0:
                index = mask_prob.argsort()
                threshold_index = index[min(len(index), self.min_kept) - 1]
                if mask_prob[threshold_index] > self.thresh:
                    threshold = mask_prob[threshold_index]
                # keep only the pixels whose predicted probability is at or
                # below the threshold, i.e. mine the hard examples
                kept_mask = mask_prob.le(threshold)
                target = target * kept_mask.long()
                valid_mask = valid_mask * kept_mask
                # logger.info('Valid Mask: {}'.format(valid_mask.sum()))

        # ~ flips the mask: positions outside valid_mask are filled, in place,
        # with the ignore label (255) so they drop out of the loss
        target = target.masked_fill_(~valid_mask, self.ignore_label)
        target = target.view(b, h, w)

        return self.criterion(pred, target)
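
# Hypothetical toy run (illustration only): with min_kept smaller than the
# number of valid pixels, the easy pixels (probability above the threshold)
# are relabelled to ignore_label and drop out of the cross-entropy.
def _demo_ohem_cpu():
    torch.manual_seed(0)
    ohem = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.6, min_kept=2)
    pred = torch.randn(1, 3, 4, 4, requires_grad=True)
    target = torch.randint(0, 3, (1, 4, 4))
    loss = ohem(pred, target)
    loss.backward()
    return loss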


def bce2d(input, target):
    b, c, h, w = input.size()

    log_p = input.permute(0, 2, 3, 1).contiguous(
    ).view(-1)      # (b, c, h, w) ==> (b, h, w, c) ==> flattened to (b*h*w*c,)
    target = target.view(-1)

    pos_index = (target == 1)
    neg_index = (target == 0)
    ignore_index = (target > 1)

    weight = torch.zeros(log_p.size(), device=input.device)   # float32 zeros on input's device (CPU-safe)
    pos_num = pos_index.sum().float()
    neg_num = neg_index.sum().float()
    sum_num = pos_num + neg_num
    weight[pos_index] = neg_num * 1.0 / sum_num
    weight[neg_index] = pos_num * 1.0 / sum_num

    weight[ignore_index] = 0
    # print(weight.max(), pos_num, neg_num)

    loss = F.binary_cross_entropy_with_logits(
        log_p, target.float(), weight, reduction='mean')
    return loss
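
# Hypothetical usage sketch (illustration only): binary edge targets, with
# anything above 1 treated as ignore (its weight is zeroed) and pos/neg
# pixels reweighted by the opposite class frequency.
def _demo_bce2d():
    torch.manual_seed(0)
    logits = torch.randn(2, 1, 8, 8)
    edges = torch.randint(0, 2, (2, 1, 8, 8))
    edges[0, 0, 0, 0] = 255         # ignored: weight[target > 1] = 0
    return bce2d(logits, edges)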