Context manager for dtype and device

Does PyTorch provide a context manager to specify the dtype and device for all new tensors created inside that context? Roughly, I’m looking for something like:

```python
# `input` is an existing `torch.Tensor`
with default_device_and_dtype(input):  # <-- is there a context manager like this somewhere?
    x = torch.empty(1, 2, 3)
    y = torch.zeros(4, 5)

# expected: `x` and `y` are `torch.Tensor`s with the same dtype and device as `input`
```

I’m not aware of such a context manager, so would you mind creating a feature request on GitHub to discuss this proposal?

Meanwhile, you could copy this util function, which is used in the tests, and combine it with `torch.cuda.device(id)`.
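Below is a minimal sketch of the requested context manager, assuming PyTorch 2.0 or newer. `default_device_and_dtype` is the hypothetical name from the question, not a PyTorch API. `torch.cuda.device(id)` only selects which GPU `"cuda"` refers to and does not make factory functions allocate on the GPU, so this sketch uses a `torch.device` object as a context manager instead (supported since PyTorch 2.0). It pairs that with `torch.set_default_dtype`, which only accepts floating-point dtypes.

```python
import contextlib
import torch

@contextlib.contextmanager
def default_device_and_dtype(reference: torch.Tensor):
    """Hypothetical helper: new factory-function tensors match `reference`.

    Assumes PyTorch >= 2.0 and a floating-point `reference`, because
    torch.set_default_dtype only accepts floating-point dtypes.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(reference.dtype)
    try:
        # A torch.device used as a context manager redirects factory
        # functions such as torch.empty / torch.zeros to that device.
        with reference.device:
            yield
    finally:
        # Restore the previous default dtype even if the body raised.
        torch.set_default_dtype(previous_dtype)

input_tensor = torch.randn(2, 2, dtype=torch.float64)
with default_device_and_dtype(input_tensor):
    x = torch.empty(1, 2, 3)
    y = torch.zeros(4, 5)

print(x.dtype, x.device, y.dtype, y.device)
```

Because the default dtype is changed for the duration of the block, the `try`/`finally` matters: without it, an exception inside the block would leave the modified default in place.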

Thanks for the response. I found the `Tensor.new_*` methods (e.g. `new_empty`, `new_zeros`), which seem to provide similar functionality, though they are somewhat more verbose. I'm not sure a new feature is worth adding given that an alternative exists, although a context manager might shrink the API surface and make code more consistent…
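For comparison, here is a small sketch of the `new_*` alternative. Each method takes a size and inherits dtype and device from the tensor it is called on; the tensor names are illustrative only.

```python
import torch

# A float64 reference tensor (add device="cuda" to try it on a GPU).
reference = torch.randn(3, dtype=torch.float64)

# new_* methods take a size and inherit dtype and device from `reference`.
x = reference.new_empty((1, 2, 3))
y = reference.new_zeros((4, 5))
z = reference.new_full((2,), 7.0)

# dtype (and device) can still be overridden per call.
idx = reference.new_zeros((4,), dtype=torch.long)

print(x.dtype, y.dtype, z.dtype, idx.dtype)
```

The extra verbosity is that every call goes through a reference tensor, but in exchange there is no hidden global state to set and restore.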

Might be, but I personally don't like having too many context managers, as the code becomes quite unreadable after the fifth level of indentation. :wink:

Also have a look at the `torch.*_like` functions, e.g. `torch.zeros_like(tensor)`.
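A short sketch of the `*_like` functions. Unlike the `new_*` methods, they copy the shape of their input as well as its dtype and device, and each of these can be overridden with keyword arguments.

```python
import torch

reference = torch.randn(4, 5, dtype=torch.float64)

# *_like functions copy shape, dtype, and device from their input.
zeros = torch.zeros_like(reference)
sevens = torch.full_like(reference, 7.0)
noise = torch.randn_like(reference)

# Override the dtype while keeping shape and device.
float32_buf = torch.empty_like(reference, dtype=torch.float32)

print(zeros.shape, zeros.dtype, float32_buf.dtype)
```

Because the shape is inherited, `*_like` fits best when the new tensor mirrors an existing one; when it needs a different size, the `new_*` methods are the closer match.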