JIT tracer throws type failures in TensorIterator when default tensor type is CUDA

Most of Pyro’s models and unit tests (e.g. those in test_jit.py) now work with the PyTorch JIT tracer.

However, when we run the same tests with the default tensor type set to torch.cuda.DoubleTensor, they fail with the cryptic error pasted below (a rough sketch of the configuration we use follows the backtrace). Note that the errors are triggered only by the combination of CUDA and the JIT: the tests pass as plain CUDA tests without the JIT, and as CPU tests with the JIT enabled.

  File "/home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/jit/__init__.py", line 1172, in forward
    return self._get_method('forward')(*args, **kwargs)
RuntimeError: TensorIterator expected type CPUDoubleType but got CUDADoubleType[2] (check_type_conversions at /home/npradhan/workspace/pyro_dev/pytorch/aten/src/ATen/native/TensorIterator.cpp:426)
frame #0: at::TensorIterator::Builder::build() + 0x21b (0x7fd85785f87b in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #1: at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&) + 0x80 (0x7fd85785e73a in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #2: at::native::mul_out(at::Tensor&, at::Tensor const&, at::Tensor const&) + 0x114 (0x7fd857762b71 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #3: at::native::mul(at::Tensor const&, at::Tensor const&) + 0x47 (0x7fd857762c3e in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #4: at::TypeDefault::mul(at::Tensor const&, at::Tensor const&) const + 0x52 (0x7fd857a9baf8 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #5: torch::autograd::VariableType::mul(at::Tensor const&, at::Tensor const&) const + 0x3e2 (0x7fd84443749e in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #6: <unknown function> + 0xcd94d0 (0x7fd8445e54d0 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #7: <unknown function> + 0xd17fdc (0x7fd844623fdc in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #8: <unknown function> + 0xd7edd7 (0x7fd84468add7 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #9: std::function<int (std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >&)>::operator()(std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >&) const + 0x49 (0x7fd85b7ff415 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0xe88ee5 (0x7fd844794ee5 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #11: <unknown function> + 0xe88fb6 (0x7fd844794fb6 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #12: torch::jit::ConstantPropagation(torch::jit::Node*, bool) + 0x163 (0x7fd844795b3d in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #13: torch::jit::ConstantPropagation(torch::jit::Block*, bool) + 0xd2 (0x7fd844795c48 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #14: torch::jit::ConstantPropagation(std::shared_ptr<torch::jit::Graph>&) + 0x2d (0x7fd844795c91 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #15: <unknown function> + 0xdff336 (0x7fd84470b336 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #16: <unknown function> + 0xdff1c3 (0x7fd84470b1c3 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #17: <unknown function> + 0xe06272 (0x7fd844712272 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #18: torch::jit::GraphExecutor::run(std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >&) + 0x2e (0x7fd84470b7b2 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/lib/libtorch.so.1)
frame #19: torch::jit::script::Method::run(std::vector<torch::jit::IValue, std::allocator<torch::jit::IValue> >&) + 0xf6 (0x7fd85b8e4744 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #20: <unknown function> + 0xd146ba (0x7fd85b8e66ba in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #21: <unknown function> + 0xd44bab (0x7fd85b916bab in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #22: <unknown function> + 0xd3c8f0 (0x7fd85b90e8f0 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #23: <unknown function> + 0xd35f52 (0x7fd85b907f52 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #24: <unknown function> + 0xd360e5 (0x7fd85b9080e5 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #25: <unknown function> + 0x92d0d6 (0x7fd85b4ff0d6 in /home/npradhan/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
<omitting python frames>
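
For reference, the kind of configuration that triggers this on our end looks roughly like the sketch below. The model body here is a hypothetical stand-in rather than a reduced repro; the real failures come from the Pyro models exercised in test_jit.py.

```python
import torch

# Rough sketch of the triggering configuration: default tensor type set to a
# CUDA type, then a model run through the JIT tracer. The model body is a
# hypothetical stand-in, NOT a confirmed minimal repro.
torch.set_default_tensor_type(torch.cuda.DoubleTensor)

def model(x):
    # arbitrary tensor arithmetic involving a freshly constructed (CUDA) tensor
    return x * torch.ones(2)

x = torch.ones(2)
traced = torch.jit.trace(model, (x,))
out = traced(x)  # on our end the failure surfaces when the traced function is executed
```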

I have found it very hard to debug this or reduce it to a more minimal example, but given that so many of our models (e.g. the VAE) and tests fail when JIT traced with CUDA tensors, I think either something more fundamental on our end is breaking the JIT tracer, or this points to a bug somewhere in torch/aten. Any help or tips for debugging this are appreciated. Is it possible to look at the trace graph to see why/where the JIT expects a CPU tensor in the first place?
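
For anyone looking into this, the only graph-inspection approach I know of is something like the sketch below (`traced` and `x` are hypothetical names for a traced function and its example input). I am not sure whether either view reflects the graph at the point where the ConstantPropagation pass in the backtrace decides a CPU tensor is expected.

```python
# Assuming `traced` is the object returned by torch.jit.trace and `x` is its example input.
print(traced.graph)         # the recorded trace, before executor optimizations
print(traced.graph_for(x))  # graph specialized for these inputs; may run the same failing pass
```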