Is there something wrong with the backward function?

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(3, 2)  # input_channels, classes
model.to(device=device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# note: ReduceLROnPlateau only has an effect if scheduler.step(metric) is called after each epoch
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max')
# AMP is disabled here, so the scaler and autocast below are effectively no-ops
grad_scaler = torch.cuda.amp.GradScaler(enabled=False)

for epoch in range(epochs):
    model.train()
    epoch_loss = 0  # reset once per epoch, not once per batch
    for batch in train_data_loader:
        image = batch['image']   # torch.as_tensor(img.copy()).float().contiguous()
        label = batch['label']   # torch.as_tensor(mask.copy()).long().contiguous()
        image = image.to(device=device, dtype=torch.float32)
        label = label.to(device=device, dtype=torch.long)
        with torch.cuda.amp.autocast(enabled=False):
            logits = model(image)
            loss = criterion(logits, label)

        # backprop
        optimizer.zero_grad(set_to_none=True)  # zero the gradients
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()
        epoch_loss += loss.item()

    print('epoch', epoch, ' loss:', epoch_loss / len(train_data_loader))

Error info:

/tmp/ipykernel_22/2557856581.py in <module>
     34         # backprop
     35         optimizer.zero_grad(set_to_none=True)  # zero the gradients
---> 36         grad_scaler.scale(loss).backward()
     37         grad_scaler.step(optimizer)
     38         grad_scaler.update()

/opt/conda/lib/python3.7/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    361                 create_graph=create_graph,
    362                 inputs=inputs)
--> 363         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    364 
    365     def register_hook(self, hook):

/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    176 
    177 def grad(

RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED

Hi,
Please try following this thread and see if it helps.
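In the meantime, a common way to get a readable message out of a generic cuDNN error like this is to run a single batch on the CPU, where CrossEntropyLoss reports out-of-range targets directly instead of failing inside cuDNN. A minimal sketch, reusing the criterion and train_data_loader from your post:

# run one batch entirely on the CPU to surface the underlying error
model_cpu = UNet(3, 2)  # same architecture, left on the CPU
batch = next(iter(train_data_loader))
image = batch['image'].to(dtype=torch.float32)
label = batch['label'].to(dtype=torch.long)
loss = criterion(model_cpu(image), label)  # raises e.g. "IndexError: Target k is out of bounds" if labels are invalid
loss.backward()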

Thanks, I found the error: the number of categories changed after data pre-processing, so the label values no longer matched the two classes the model (and CrossEntropyLoss) expected.
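For anyone who hits the same error: a quick sanity check on the labels before training catches this early. A minimal sketch, assuming num_classes matches the model's output channels (2 here):

num_classes = 2  # must equal the number of output channels of UNet
for batch in train_data_loader:
    label = batch['label']
    assert label.min() >= 0 and label.max() < num_classes, \
        f'labels span [{label.min()}, {label.max()}], expected [0, {num_classes - 1}]'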