OOMKilled (exit code 137) with PyTorch/XLA

I am running a small model with a small batch size on a TPU with 8 cores, and the job fails when it calls:

xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)

It always runs out of memory on Google Cloud, and I am not sure where the issue is. I have reduced the batch size as much as possible, but the error persists. Thanks!
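For context, exit code 137 follows the usual POSIX shell convention of 128 + signal number, i.e. 128 + 9, so the process was terminated with SIGKILL. That is the signal the Linux OOM killer sends, which suggests host (VM) RAM, not TPU memory, is being exhausted. Since `xmp.spawn` starts `nprocs` separate processes, each one likely holds its own copy of the model and data in host memory, which can explain why shrinking the batch size alone does not help. The helper below (`describe_exit_code` is a hypothetical name, not part of PyTorch/XLA) is only a minimal sketch that decodes such codes, assuming a Unix-like system:

```python
import signal

def describe_exit_code(code: int) -> str:
    """Interpret a shell-style exit code (hypothetical helper).

    By convention, codes above 128 mean the process was killed by
    signal number (code - 128); 137 -> signal 9 (SIGKILL).
    """
    if code > 128:
        signum = code - 128
        try:
            return f"killed by {signal.Signals(signum).name}"
        except ValueError:
            # Signal number not known on this platform
            return f"killed by signal {signum}"
    return f"exited with status {code}"

print(describe_exit_code(137))  # killed by SIGKILL
```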

Could you provide more details on your setup? It is hard to reproduce this as is. CC: @ailzhang on XLA.