Is it possible to pass XLA debug flags to torch-xla?
For example, with JAX, setting XLA_FLAGS=--xla_force_host_platform_device_count=8 makes XLA emulate a backend with 8 devices (cores), which is useful for debugging on CPU-only platforms: jax.devices() then returns a list of 8 devices.
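For reference, here is a minimal sketch of the JAX behaviour described above. The helper add_xla_flag is hypothetical (it is not part of JAX). It appends to any existing XLA_FLAGS instead of overwriting them. The flag is set before JAX is imported because XLA reads XLA_FLAGS once, when the backend is first initialized.

```python
import os

def add_xla_flag(flag: str) -> None:
    """Hypothetical helper: append `flag` to XLA_FLAGS, keeping flags already set."""
    existing = os.environ.get("XLA_FLAGS", "").split()
    if flag not in existing:
        existing.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(existing)

# Must happen before JAX initializes its backend, since XLA parses
# XLA_FLAGS only once at startup.
add_xla_flag("--xla_force_host_platform_device_count=8")

try:
    import jax
except ImportError:  # JAX not installed; the flag is still set in os.environ
    jax = None

if jax is not None:
    # Ask for the host (CPU) platform explicitly, so the result is the same
    # even on a machine that also has an accelerator.
    cpu_devices = jax.devices("cpu")
    print(len(cpu_devices), cpu_devices)
```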
Ideally, the equivalent torch-xla call, torch_xla.core.xla_model.get_xla_supported_devices(), would behave the same way and report 8 devices.
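The experiment in question would look roughly like the sketch below: set the same flag before torch_xla is imported, then ask torch-xla which devices it sees. Whether torch-xla honours this particular flag is exactly what is being asked, so no output is assumed. It also assumes a PJRT-based torch-xla build, where the PJRT_DEVICE environment variable selects the runtime device (CPU here).

```python
import os

# Assumption: a PJRT-based torch-xla build, where PJRT_DEVICE selects the backend.
os.environ.setdefault("PJRT_DEVICE", "CPU")

# Set the same flag used with JAX, before torch_xla is imported, keeping
# any XLA_FLAGS that were already present.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()

try:
    import torch_xla.core.xla_model as xm
except ImportError:  # torch-xla not installed
    xm = None

if xm is not None:
    # The open question: does this list now contain 8 devices?
    print(xm.get_xla_supported_devices())
```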