nn.Relu error when using with GradScaler

I got the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: 
[torch.cuda.HalfTensor [20, 75, 52, 52]], which is output 0 of ReluBackward0, is at version 1; 
expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

If I don’t use nn,Relu in the CNNBlock class below, the model doesn’t get an error. Can anyone help me explain why I am getting the error.

More about the model I use:

  1. Utils:
@torch.jit.script    
def itersection_over_union(box1,box2):

    box1_x1 = box1[...,0:1] - box1[...,2:3] / 2
    box1_y1 = box1[...,1:2] - box1[...,3:4] / 2
    box1_x2 = box1[...,0:1] + box1[...,2:3] / 2
    box1_y2 = box1[...,1:2] + box1[...,3:4] / 2
    box2_x1 = box2[...,0:1] - box2[...,2:3] / 2
    box2_y1 = box2[...,1:2] - box2[...,3:4] / 2
    box2_x2 = box2[...,0:1] + box2[...,2:3] / 2
    box2_y2 = box2[...,1:2] + box2[...,3:4] / 2    
    x1 = torch.max(box1_x1,box2_x1)
    x2 = torch.min(box1_x2,box2_x2)
    y1 = torch.max(box1_y1,box2_y1)
    y2 = torch.min(box1_y2,box2_y2)
    intersection = (x2-x1).clamp(0) * (y2-y1).clamp(0)
    box1_area = abs((box1_x2-box1_x1)*(box1_y2-box1_y1)) 
    box2_area = abs((box2_x2-box2_x1)*(box2_y2-box2_y1))
    return intersection / (box1_area+box2_area-intersection+1e-6)
  1. Loss:
class Loss(nn.Module):
    def __init__(self):
        super().__init__()
        self.lambda_obj_coord = 1
        self.lambda_noobj_coord = 10
        self.lamda_class = 1
        self.lambda_box = 10
        
        self.mse_loss = nn.MSELoss()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()
        self.bce_logic_loss = nn.BCEWithLogitsLoss()
    def forward(self,predict, target, anchor):
        obj = target[..., 0] == 1
        noobj = target[...,0] == 0

        no_object_loss = self.bce_logic_loss(predict[...,0:1][noobj],target[...,0:1][noobj])
        
        anchor = anchor.reshape(1,3,1,1,2)
        box_pred = torch.cat([self.sigmoid(predict[...,1:3]),torch.exp(predict[...,3:5])*anchor],dim=-1)
        iou = itersection_over_union(box_pred[obj],target[...,1:5][obj])
        obj_loss = self.bce_logic_loss(self.sigmoid(predict[...,0:1][obj]),iou*predict[...,0:1][obj])
        predict[...,1:3] = self.sigmoid(predict[...,1:3].clone())
        target[...,3:5] = torch.log((1e-16)+(target[...,3:5].clone())/anchor)
        box_loss = self.mse_loss(predict[...,1:5].clone()[obj], target[...,1:5].clone()[obj])
        class_loss = self.cross_entropy(predict[...,5:].clone()[obj],(target[...,5].clone()[obj]).long())
        
        return (
            self.lambda_box * box_loss 
            + self.lambda_obj_coord * obj_loss
            + self.lambda_noobj_coord * no_object_loss
            + self.lamda_class * class_loss
        )
  1. Model:
block_layer = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],  # To this point is Darknet-53
    (512, 1, 1),
    (1024, 3, 1),
    "S",
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]
class CNNBlock(nn.Module):
    def __init__(self,input_channel,out_channel,action= True,**kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=input_channel,out_channels=out_channel,bias=False,**kwargs)
        self.batch_norm = nn.BatchNorm2d(out_channel,device="cuda:0")
        self.relu = nn.ReLU(inplace=False)
        self.action = action
    def __getattribute__(self,name):
            return super(CNNBlock, self).__getattribute__(name)
    def forward(self,x):
        if self.action:
            a = self.conv(x)
            b = self.batch_norm(a)
            out_put = self.relu(b)
        else:
            out_put = self.conv(x)
        return out_put
class ResidualBlock(nn.Module):
    def __init__(self,num_channel,num_repeat,residual=True):
        super().__init__()
        self.layer = nn.ModuleList()
        self.num_repeat = num_repeat
        for step in range(self.num_repeat):
            self.layer  = self.layer + [
                nn.Sequential(
                    CNNBlock(num_channel,num_channel//2,kernel_size=1),
                    CNNBlock(num_channel//2,num_channel,kernel_size=3,padding=1)
                )
            ]

        self.residual = residual
    def __getattribute__(self,name):
            return super(ResidualBlock, self).__getattribute__(name)
    def forward(self,x):
        for layer in self.layer:
            if self.residual:
                x = layer(x) + x
            else:
                x = layer(x)
        return x
class ScalePrediction(nn.Module):
    def __init__(self,input_channel,anchor_per_scale,num_class):
        super().__init__()
        self.modul = nn.Sequential(
            CNNBlock(input_channel,input_channel*2,kernel_size=3,padding=1),
            CNNBlock(2*input_channel,(5+num_class)*3,kernel_size=1)
        )
        self.anchor = anchor_per_scale
        self.num_classes = num_class
    def __getattribute__(self,name):
            return super(ScalePrediction, self).__getattribute__(name)
    def forward(self,x):
        output = self.modul(x)
        return output.reshape(x.shape[0], self.anchor, self.num_classes + 5, x.shape[2], x.shape[3]).permute(0, 1, 3, 4, 2)
class YoloV3(nn.Module):
    def _create_layer(self):
        layer = nn.ModuleList()
        config_list = self.configs
        input_channels = self.numinput_channel
        for config in config_list:
            if isinstance(config,tuple):
                output_channel, kernnel, stride = config
                layer.append(CNNBlock(input_channels,out_channel=output_channel,kernel_size=kernnel,stride=stride,padding=1 if kernnel==3 else 0))
                input_channels = output_channel
            elif isinstance(config,list):
                num_repeat = config[1]
                layer.append(ResidualBlock(input_channels,num_repeat=num_repeat))
            elif isinstance(config,str):
                if config == "S":
                    layer = layer+[
                        ResidualBlock(input_channels,num_repeat=1,residual=False),
                        CNNBlock(input_channels,input_channels//2,kernel_size=1,action=False),
                        ScalePrediction(input_channels //2 , anchor_per_scale=self.anchor_per_scale, num_class=self.num_class)
                    ]
                    input_channels = input_channels //2
                elif config == "U":
                    layer  = layer+[
                        nn.Upsample(scale_factor=2)
                    ]
                    input_channels = input_channels *3 
        return layer
    def __init__(self,input_channel, num_class, anchor_per_scale, config):
        super().__init__()
        self.configs = config
        self.numinput_channel = input_channel
        self.num_class = num_class
        self.anchor_per_scale = anchor_per_scale
        self.layers  = self._create_layer()
    def forward(self,x):
        output = []
        route = []
        for layer in self.layers:
            if hasattr(layer,"anchor"):
                output.append(layer(x))
            else:
              x = layer(x)
              if  hasattr(layer,"num_repeat"):
                if getattr(layer,"num_repeat") ==8:
                  route.append(x)
              elif hasattr(layer,"align_corners"):
                  x = torch.cat([x,route[-1]],dim=1)
                  route.pop()
        return output

  1. Training
S=[13,26,52]
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
] 
IMG_SIZE = 416
INPUT_CHANNEL = 3
NUM_CLASS = 20
ANCHOR_PER_SCALE = 3
LEARNING_RATE = 6.1e-5
WEIGHT_DECAY = 1e-6 

model = torch.jit.script(YoloV3(input_channel=INPUT_CHANNEL,num_class=NUM_CLASS,anchor_per_scale=ANCHOR_PER_SCALE,config=block_layer)).to(device)
optim = torch.optim.Adam(model.parameters(),weight_decay=WEIGHT_DECAY,lr=LEARNING_RATE)
loss = torch.jit.script(Loss())
scaler = torch.cuda.amp.GradScaler()
scaled_anchor = (torch.tensor(ANCHORS) * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1,3,2)).to(device)

Example running:

for epoch in range(1):
  for j in range(5):
    data = torch.rand([20, 3, 416, 416]).to(device)
    label1 = torch.rand([20, 3, 13, 13, 6]).to(device)
    label2 = torch.rand([20, 3, 26, 26, 6]).to(device)
    label3 = torch.rand([20, 3, 52, 52, 6]).to(device)
    with torch.cuda.amp.autocast():
      y_pred = model(data)
      los = (
          loss(y_pred[0],label1,scaled_anchor[0])
          + loss(y_pred[1],label2,scaled_anchor[1])
          + loss(y_pred[2],label3,scaled_anchor[2])
      )    
    scaler.scale(los).backward()
    scaler.step(optim)
    scaler.update()
    optim.zero_grad()

Or can take a look at my google colab file about this error: Google Colab

Thanks

You are manipulating the predict tensor inplace here:

predict[...,1:3] = self.sigmoid(predict[...,1:3].clone())

which is disallowed.
You could create a new clone and change it inplace instead:

new_predict = predict.clone()
new_predict[...,1:3] = self.sigmoid(predict[...,1:3].clone())

or use e.g. torch.cat to create the desired result tensor.

1 Like