Hi,
Is it possible to pass XLA debug flags to torch-xla?
For example, with JAX, setting XLA_FLAGS=--xla_force_host_platform_device_count=8
makes the CPU backend emulate 8 devices (useful for debugging multi-device code on CPU-only machines), so that jax.devices()
returns a list of 8 devices.
In torch-xla, the equivalent would be for torch_xla.core.xla_model.get_xla_supported_devices()
to return 8 devices as well.
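For reference, the JAX behaviour described above can be reproduced with a short script. This is a minimal sketch, not a torch-xla answer: `add_xla_flag` is a hypothetical helper (not part of JAX or torch-xla) that appends a flag to any existing XLA_FLAGS instead of overwriting them. The JAX part assumes jax is installed and is skipped otherwise. The flag has to be set before the XLA backend is initialized, which is why it is set before importing jax.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Hypothetical helper: append `flag` to XLA_FLAGS, keeping any flags already set."""
    existing = os.environ.get("XLA_FLAGS", "").split()
    if flag not in existing:
        existing.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(existing)
    return os.environ["XLA_FLAGS"]


# XLA parses XLA_FLAGS when the backend starts, so set it before importing jax.
add_xla_flag("--xla_force_host_platform_device_count=8")

try:
    import jax

    # With the flag above, the CPU backend should expose 8 host devices.
    print(jax.devices("cpu"))
except ImportError:
    print("jax is not installed; XLA_FLAGS =", os.environ["XLA_FLAGS"])
```

Appending rather than overwriting matters when other debug flags (e.g. dump options) are already exported in the shell.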