RuntimeError: Expected isFloatingType(grads[i].scalar_type()) to be true, but got false. (Could this error message be improved? ...)

I am getting this strange error as soon as my training loop hits loss.backward(), when I use the following functions:

import torch
from torch import nn

def pos_weight(pred_tensor, pos_tensor, neg_weight=1, pos_weight=1):
    # neg_weight for when pred position < target position
    # pos_weight for when pred position > target position
    gap = torch.argmax(pred_tensor, dim=1) - pos_tensor
    gap = gap.type(torch.float32)
    return torch.where(gap < 0, -neg_weight * gap, pos_weight * gap)

def loss_fn(start_logits, end_logits, start_positions, end_positions):
    # I verified all tensors have the same torch.float32 dtype; not sure where the error is :(
    loss_fct = nn.CrossEntropyLoss(reduction='none') # do reduction later
    
    start_loss = loss_fct(start_logits, start_positions) * pos_weight(start_logits, start_positions, 1, 1)
    end_loss = loss_fct(end_logits, end_positions) * pos_weight(end_logits, end_positions, 1, 1)
    
    start_loss = torch.mean(start_loss)
    end_loss = torch.mean(end_loss)
    total_loss = (start_loss + end_loss)
    return total_loss
# Dry Run Test on Dummy Values

start = torch.Tensor([[0.1, 0.1, 0.1, 0.8, 0.1]]).float()
start_target = torch.Tensor([1]).long()

end = torch.Tensor([[0.1, 0.1, 0.1, 0.8, 0.1]]).float()
end_target = torch.Tensor([3]).long()
loss_fn(start, end, start_target, end_target) # returns tensor(3.5881)

System config:

PyTorch version: 1.5.0
OS: Ubuntu

cat /usr/local/cuda/version.txt returns:

CUDA Version 9.2.148
CUDA Patch Version 9.2.148.1

nvidia-smi returns:

NVIDIA-SMI 410.79 Driver Version: 410.79 CUDA Version: 10.0

Complete traceback:

RuntimeError                              Traceback (most recent call last)
<ipython-input-20-fd990f79dfa6> in <module>
      1 torch.cuda.empty_cache()
----> 2 run(fold=0)

<ipython-input-18-114e1832c8bf> in run(fold)
     80 
     81     for epoch in range(config.EPOCHS):
---> 82         train_fn(train_data_loader, model, optimizer, device, scheduler=scheduler)
     83         jaccard = eval_fn(valid_data_loader, model, device)
     84         print(f"Jaccard Score = {jaccard}")

<ipython-input-13-aef75b9df068> in train_fn(data_loader, model, optimizer, device, scheduler)
     40         loss = loss_fn(outputs_start, outputs_end, targets_start, targets_end, class_preds, class_label)
     41 
---> 42         loss.backward()
     43         # model = model.float() found on SO, old discussion post here for this error;
     44         optimizer.step()

~/anaconda3/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199 
    200     def register_hook(self, hook):

~/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101 
    102

RuntimeError: Expected isFloatingType(grads[i].scalar_type()) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.) (validate_outputs at /opt/conda/conda-bld/pytorch_1587428111115/work/torch/csrc/autograd/engine.cpp:476)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x4e (0x7fc24cc2db5e in /root/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x2ae32c7 (0x7fc2713782c7 in /root/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #2: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x548 (0x7fc271379368 in /root/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x3d2 (0x7fc27137b2f2 in /root/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::autograd::Engine::thread_init(int) + 0x39 (0x7fc271373969 in /root/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x7fc2746ba548 in /root/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0xc819d (0x7fc2a568819d in /root/anaconda3/lib/python3.7/site-packages/zmq/backend/cython/../../../../.././libstdc++.so.6)
frame #7: <unknown function> + 0x76ba (0x7fc2a887b6ba in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #8: clone + 0x6d (0x7fc2a85b141d in /lib/x86_64-linux-gnu/libc.so.6)

Also, if I use this variant of loss_fn, no error occurs during the training loop:

def loss_fn(start_logits, end_logits, start_positions, end_positions):
    loss_fct = nn.CrossEntropyLoss()
    start_loss = loss_fct(start_logits, start_positions)
    end_loss   = loss_fct(end_logits, end_positions)
    total_loss = (start_loss + end_loss)
    return total_loss
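The working variant uses the default reduction='mean'. Without class weights or ignored targets, that equals taking torch.mean of the reduction='none' output, so the only real difference in the failing loss_fn is the per-sample multiplication by pos_weight. A quick check on random data (the shapes and seed are arbitrary choices for illustration):

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 5)               # 8 samples, 5 classes
targets = torch.randint(0, 5, (8,))      # valid class indices, none ignored

mean_loss = nn.CrossEntropyLoss()(logits, targets)                   # default reduction='mean'
per_sample = nn.CrossEntropyLoss(reduction='none')(logits, targets)  # one loss per sample

print(per_sample.shape)                              # torch.Size([8])
print(torch.allclose(mean_loss, per_sample.mean()))  # True
```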

Hi,

Do you get a nice Python stack trace or just the C++ one?

I've updated the post with the full traceback.

Do you have a custom autograd Function in your code?
It seems that one of them returned a gradient of integer type, while gradients should always be floating point types.
Enabling anomaly mode might help you pinpoint the faulty Function.
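The rule that gradients must be floating point is enforced elsewhere in autograd too: an integer tensor cannot even be marked as requiring gradients. A minimal sketch (the exact error text differs between PyTorch versions, so only the exception type is relied on):

```python
import torch

# Floating-point tensors can take part in autograd.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
print(x.dtype, x.requires_grad)  # torch.float32 True

# Integer tensors cannot require gradients; autograd raises a RuntimeError.
try:
    torch.tensor([1, 2, 3], dtype=torch.long, requires_grad=True)
    int_grad_allowed = True
except RuntimeError as err:
    int_grad_allowed = False
    print("RuntimeError:", err)
```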

No, I don't have any, Alban.

Can you share how to enable that? (I'll post back what I find!)

Thanks!

Oh yes, sure:
torch.autograd.set_detect_anomaly(True)
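A minimal sketch of both ways to turn it on, using a toy nn.Linear model as a stand-in for the real training code (the model and shapes are placeholders). When a backward Function fails under anomaly mode, the error is accompanied by the traceback of the forward op that created it, which is what helps locate the culprit:

```python
import torch
from torch import nn

# Global switch: backward passes now record forward tracebacks and check
# backward results for NaN. This slows training, so use it only while debugging.
torch.autograd.set_detect_anomaly(True)

model = nn.Linear(5, 2)       # placeholder model
inputs = torch.randn(4, 5)    # placeholder batch
loss = model(inputs).sum()
loss.backward()

torch.autograd.set_detect_anomaly(False)

# Scoped alternative: anomaly detection is active only inside the with-block.
with torch.autograd.detect_anomaly():
    loss2 = model(inputs).pow(2).mean()
    loss2.backward()
```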

OK. And do you manually provide grad_output to .backward() or autograd.grad()?
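For reference, "providing grad_output" means passing the upstream gradient yourself, through the gradient argument of .backward() or grad_outputs of torch.autograd.grad(). Presumably the question is there because such a tensor, if it had an integer dtype, could feed a non-float gradient into the engine. A minimal sketch of correct usage with floating-point grad_outputs:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2  # non-scalar output, so backward() needs an explicit gradient

# Explicit grad_output: same shape as y and a floating-point dtype.
y.backward(gradient=torch.ones_like(y))
print(x.grad)  # tensor([2., 2., 2.])

# Functional form: returns the gradient instead of accumulating into x.grad.
y2 = x ** 2
(dx,) = torch.autograd.grad(y2, x, grad_outputs=torch.ones_like(y2))
print(dx)  # tensor([2., 4., 6.])
```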

Well, I don't think I'm even aware of this. It's a very simple training loop, nothing complicated at all except the loss_fn, which I modified.

Well, I actually get the same traceback.

It is quite weird indeed.
Do you think you could put together a small code sample (~30 lines) that I could run?

I can send you the whole training pipeline in a private gist, because the code is part of an ongoing Kaggle competition (so it's not small, but it's easily readable).

Apologies for the inconvenience.

Thanks!

I was wondering if you had something small like the dry run example above.

But loss_fct is missing. Is that the part you cannot share?

loss_fct is nn.CrossEntropyLoss(reduction='none'), defined inside loss_fn above.
Sorry if I misunderstood your question.

The problem is that pos_weight is not actually differentiable, because the argmax op is not differentiable.
But due to a bug on our side, this is detected too late (only inside the backward pass, with the unhelpful message above). Here is a minimal repro:

import torch
from torch import nn

t = torch.rand(10, requires_grad=True)
bad = torch.argmax(t)

res = bad + 2

res.sum().backward()
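The root cause is visible from argmax's output alone: it is an integer index, piecewise constant in the input, so there is no meaningful gradient to propagate. The sketch below reflects what I would expect from recent PyTorch releases (not the 1.5.0 used in this thread), where the argmax result is cut off from the graph and the same backward call fails immediately with an ordinary "does not require grad" RuntimeError instead of the engine's dtype check:

```python
import torch

t = torch.rand(10, requires_grad=True)
idx = torch.argmax(t)

# argmax returns an integer index tensor, not a smooth function of t.
print(idx.dtype)          # torch.int64
print(idx.requires_grad)  # False in recent releases

# With no grad_fn to follow, backward has nothing to differentiate.
try:
    (idx + 2).sum().backward()
    backward_ok = True
except RuntimeError as err:
    backward_ok = False
    print("RuntimeError:", err)
```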

You can fix your code by adding a .detach() to the output of torch.where before returning it in pos_weight.
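Applied to the code from the question, the fix might look like this sketch. The weight only scales each sample's loss, so treating it as a constant is the intended behavior. Unlike the original dry run, the dummy logits here require grad and backward() is actually called, which is the step that failed in training:

```python
import torch
from torch import nn


def pos_weight(pred_tensor, pos_tensor, neg_weight=1, pos_weight=1):
    # neg_weight for when pred position < target position
    # pos_weight for when pred position > target position
    gap = torch.argmax(pred_tensor, dim=1) - pos_tensor
    gap = gap.type(torch.float32)
    # The weight is a constant w.r.t. the logits: detach it so autograd
    # never tries to differentiate through argmax.
    return torch.where(gap < 0, -neg_weight * gap, pos_weight * gap).detach()


def loss_fn(start_logits, end_logits, start_positions, end_positions):
    loss_fct = nn.CrossEntropyLoss(reduction='none')  # per-sample losses, reduced below
    start_loss = loss_fct(start_logits, start_positions) * pos_weight(start_logits, start_positions, 1, 1)
    end_loss = loss_fct(end_logits, end_positions) * pos_weight(end_logits, end_positions, 1, 1)
    return start_loss.mean() + end_loss.mean()


# Same dummy values as the dry run, now with gradients enabled.
start = torch.tensor([[0.1, 0.1, 0.1, 0.8, 0.1]], requires_grad=True)
start_target = torch.tensor([1])
end = torch.tensor([[0.1, 0.1, 0.1, 0.8, 0.1]], requires_grad=True)
end_target = torch.tensor([3])

loss = loss_fn(start, end, start_target, end_target)
loss.backward()   # no longer reaches argmax in the backward graph
print(loss.item())  # about 3.5881, matching the dry run
print(start.grad)
```

Note that a sample whose argmax already equals its target gets weight 0 (as end does here), so it contributes no gradient at all; that follows from the weighting scheme itself, not from the detach.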

Thanks a bunch!

Not opening an issue on GitHub then?

Well, it's still a problem, so I opened an issue here: https://github.com/pytorch/pytorch/issues/37680

Glad I was able to break/find a potential bug in PyTorch! Haha, just kidding.