TypeError: unsupported operand type(s) for *: 'NoneType' and 'float'

I have a relatively large architecture. Here is the training section where I use torch.cuda.amp for my model:

def train(epoch):
    # model in training mode
    modelstate.model.train()
    # initialization
    total_loss = 0
    total_batches = 0
    total_points = 0
    if torch.cuda.is_available():
        scaler = torch.cuda.amp.GradScaler()

    for i, (u, y) in enumerate(loader_train):
        u = u.to(device)  # torch.Size([B, D_in, T])
        y = y.to(device)
        
        # set the optimizer
        modelstate.optimizer.zero_grad()
        if torch.cuda.is_available():
            with torch.autocast(device_type='cuda', dtype=torch.float32), torch.backends.cudnn.flags(enabled=False):
                loss_ = modelstate.model(u, y)
            diff_params = [p for p in modelstate.model.m.parameters() if p.requires_grad]    
            scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss_),
                                                    inputs=diff_params,
                                                    create_graph=True,
                                                    retain_graph=True,
                                                    allow_unused=True #Whether to allow differentiation of unused parameters.
                                                    )
             
            
            #find NaN values in the architecture
            for submodule in modelstate.model.m.modules():
                submodule.register_forward_hook(nan_hook)
                
            inv_scale = 1./scaler.get_scale()
            
            #grad_params = [ p * inv_scale if p is not None and not torch.isnan(p).any() else torch.tensor(0, device=device, dtype=torch.float32) for p in scaled_grad_params ]
            grad_params = [p * inv_scale for p in scaled_grad_params]
            with torch.autocast(device_type='cuda', dtype=torch.float32):
                #grad_norm = torch.tensor(0, device=grad_params[0].device, dtype=grad_params[0].dtype)
                grad_norm = 0
                for grad in grad_params:
                    grad_norm += grad.pow(2).sum()
                # take the square root once, after summing all squared gradient terms
                grad_norm = grad_norm**0.5
                # Compute the L2 Norm as penalty and add that to loss
                loss_ = loss_ + grad_norm

Here is the error message:

/tmp/ipykernel_100588/1815332848.py in <listcomp>(.0)
   3006
   3007 #grad_params = [ p * inv_scale if p is not None and not torch.isnan(p).any() else torch.tensor(0, device=device, dtype=torch.float32) for p in scaled_grad_params ]
----> 3008 grad_params = [p * inv_scale for p in scaled_grad_params]
   3009 with torch.autocast(device_type='cuda', dtype=torch.float32):
   3010 #grad_norm = torch.tensor(0, device=grad_params[0].device, dtype=grad_params[0].dtype)

TypeError: unsupported operand type(s) for *: 'NoneType' and 'float'

Earlier, I had replaced the problematic line with grad_params = [ p * inv_scale if p is not None and not torch.isnan(p).any() else torch.tensor(0, device=device, dtype=torch.float32) for p in scaled_grad_params ], and I wasn't getting this error.

However, I wasn't happy with my model's results, and I read in a debugging blog that I should check the gradient of every parameter: if any of them come up empty, that is where I need to take a closer look. The blog didn't explain why this is a problem. How should I deal with this type of error in general?

The error is raised on that line because you are passing allow_unused=True to torch.autograd.grad, which returns None gradients for parameters that were not used to compute the output, as seen here:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.unused = nn.Linear(1, 1)
        
    def forward(self, x):
        x = self.fc1(x)
        return x
    
model = MyModel()
x = torch.randn(1, 1)
out = model(x)
loss = out.mean()

grads = torch.autograd.grad(outputs=loss,
                            inputs=model.parameters(),
                            create_graph=True,
                            retain_graph=True,
                            allow_unused=True
                            )
print(grads)
# (tensor([[0.2469]]), tensor([1.]), None, None)
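
To see which parameter each of those entries belongs to, you can pair the returned gradients with the parameter names (a small sketch reusing the toy model above):

# map each gradient back to its parameter name to spot the unused parameters
names = [name for name, p in model.named_parameters() if p.requires_grad]
for name, grad in zip(names, grads):
    if grad is None:
        print(f"{name}: None (not used in the forward pass)")
    else:
        print(f"{name}: grad of shape {tuple(grad.shape)}")
# fc1.weight: grad of shape (1, 1)
# fc1.bias: grad of shape (1,)
# unused.weight: None (not used in the forward pass)
# unused.bias: None (not used in the forward pass)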

You might thus need to check for None gradients before trying to unscale them.
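
A minimal sketch of that check, along the lines of the workaround you already tried (assuming scaled_grad_params, scaler, and device are defined as in your training loop), would be to unscale only the gradients that exist and either drop or zero-fill the None entries:

inv_scale = 1. / scaler.get_scale()

# option 1: drop the None entries so only real gradients enter the penalty
grad_params = [p * inv_scale for p in scaled_grad_params if p is not None]

# option 2: keep the list aligned with diff_params by zero-filling the unused entries
grad_params = [p * inv_scale if p is not None
               else torch.zeros((), device=device)
               for p in scaled_grad_params]

Independent of the variant you pick, a parameter that unexpectedly gets a None gradient usually means that part of the model never contributes to the loss, which is worth investigating on its own and could also explain the unsatisfying results.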