Does PyTorch provide a context manager to specify the dtype and device for all new tensors created inside that context? Roughly, I’m looking for something like:
```python
# `input` is an existing `torch.Tensor`
with default_device_and_dtype(input):  # <-- is there a context manager like this somewhere?
    x = torch.empty(1, 2, 3)
    y = torch.zeros(4, 5)
# expected: `x` and `y` are `torch.Tensor`s with the same dtype and device as `input`
```
I’m not aware of such a context manager, so would you mind creating a feature request on GitHub to discuss this proposal?
Meanwhile, you could copy this util function, which is used in tests, and use it in combination with
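On newer releases, something close to the requested helper can be assembled from public pieces. This is a minimal sketch under stated assumptions, not an existing PyTorch API: `default_device_and_dtype` is the hypothetical name from the question, it assumes PyTorch 2.0 or later (where a `torch.device` object can be used as a context manager), and it only accepts floating-point reference tensors because `torch.set_default_dtype` rejects other dtypes. Note that the default dtype is process-global, so this is not thread-safe.

```python
import contextlib

import torch


@contextlib.contextmanager
def default_device_and_dtype(reference: torch.Tensor):
    """Hypothetical helper: factory calls inside the block default to
    `reference`'s device and (floating-point) dtype.

    Assumes PyTorch >= 2.0, where `torch.device` supports `with`.
    """
    # torch.set_default_dtype only accepts floating-point dtypes.
    if not reference.dtype.is_floating_point:
        raise TypeError("only floating-point reference dtypes are supported")
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(reference.dtype)  # global state, not thread-safe
    try:
        # Factory functions such as torch.empty/torch.zeros pick up this device.
        with reference.device:
            yield
    finally:
        # Restore the old default dtype even if the body raised.
        torch.set_default_dtype(previous_dtype)


# Usage with a float64 CPU tensor standing in for `input`.
ref = torch.randn(2, dtype=torch.float64)
with default_device_and_dtype(ref):
    x = torch.empty(1, 2, 3)
    y = torch.zeros(4, 5)
```

Restoring the previous dtype in `finally` keeps the global default from leaking out of the block if the body raises.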
Thanks for the response. I found the `new_*` functions on `Tensor`, which seem to provide similar functionality (though they are maybe more verbose). I'm not sure a new feature is worth adding, given that there's an alternative. I suppose a context manager would shrink the API surface and make code more consistent…
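For comparison, a short sketch of the `new_*` route applied to the question's example. `Tensor.new_empty`, `new_zeros` and `new_full` are real methods that take a shape and inherit dtype and device from the tensor they are called on; the float64 `input` tensor is just an illustrative stand-in.

```python
import torch

# Stand-in for the existing tensor from the question (float64 on CPU).
input = torch.randn(2, dtype=torch.float64)

# Each new_* call takes an arbitrary shape but reuses input's dtype and device.
x = input.new_empty((1, 2, 3))
y = input.new_zeros((4, 5))
z = input.new_full((2, 2), 7.0)
```

The shape is passed as a tuple here; dtype or device can still be overridden per call with the `dtype=`/`device=` keywords.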
That might be, but I personally don't like too many context managers, as the code becomes quite hard to read after the fifth level of indentation.
Also have a look at the `torch.*_like` methods, e.g.
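A small sketch contrasting the `*_like` functions with passing dtype and device explicitly. `torch.zeros_like` and `torch.ones_like` copy shape, dtype and device from their argument by default, so they fit when the new tensor should have the same shape as `input`; for a different shape, the `dtype=`/`device=` keywords of the regular factory functions do the job. The float64 `input` is an illustrative stand-in.

```python
import torch

# Stand-in for the existing tensor from the question.
input = torch.randn(3, 4, dtype=torch.float64)

# *_like: same shape, dtype and device as `input` by default.
a = torch.zeros_like(input)
b = torch.ones_like(input)

# Different shape, same dtype and device: pass them explicitly.
c = torch.zeros(4, 5, dtype=input.dtype, device=input.device)
```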