PyTorch TPU support

Any news on PyTorch TPU support?

Last I heard was this: https://twitter.com/soumithchintala/status/1009112034242453506

Was wondering the same… https://twitter.com/plouismarie/status/1048196291132514305

There was an announcement at the PyTorch Developer Conference this week, where a ResNet-50 was successfully trained on a TPU.
For more info, check https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-across-google-cloud

That’s great to read; thank you very much for sharing.

we’re pleased to announce that engineers on Google’s TPU team are actively collaborating with core PyTorch developers to connect PyTorch to Cloud TPUs

There is already some Keras code that shows how to use a TPU in a very concrete and practical way:

import os

import numpy as np
import tensorflow as tf

# Wrap the Keras model so it runs on the Colab TPU
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    )
)
tpu_model.compile(
    optimizer=tf.train.AdamOptimizer(learning_rate=1e-3),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['sparse_categorical_accuracy']
)

# Infinite generator that yields random contiguous batches from the training set
def train_gen(batch_size):
    while True:
        offset = np.random.randint(0, x_train.shape[0] - batch_size)
        yield x_train[offset:offset + batch_size], y_train[offset:offset + batch_size]


tpu_model.fit_generator(
    train_gen(1024),
    epochs=10,
    steps_per_epoch=100,
    validation_data=(x_test, y_test),
)

I am hoping to see some PyTorch demo code that enables TPU usage.
I guess it is coming soon. Can’t wait!

Anyone heard anything new on Pytorch on TPU?

I believe it’s already usable but a little rough around the edges, although I haven’t tried it myself. https://github.com/pytorch/xla
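
To give a flavor of it, here is a minimal single-core sketch (assuming torch_xla is installed and a TPU runtime is attached; API details may have changed since):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # the TPU core as an ordinary torch device
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# one toy training step on random data
data = torch.randn(64, 10, device=device)
target = torch.randn(64, 2, device=device)
loss = nn.functional.mse_loss(model(data), target)
loss.backward()
xm.optimizer_step(optimizer, barrier=True)  # steps and forces XLA to execute the graph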

Any developments on that in the meantime? XLA is a nice temporary solution, but it would be good to know if we can expect an official solution soon.

Example: https://github.com/pytorch/xla/tree/master/contrib/colab

Thanks to the awesome work by the PyTorch and XLA team, we were able to get TPU support fully working out of the box with PyTorch Lightning.

Check out this nifty guide on going from PyTorch to PyTorch Lightning
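
For what it’s worth, the switch is essentially a Trainer flag. A hedged sketch (`MyLightningModule` is a placeholder for your own LightningModule, and the exact argument name depends on your Lightning version):

import pytorch_lightning as pl

model = MyLightningModule()  # placeholder: your own LightningModule
trainer = pl.Trainer(tpu_cores=8, max_epochs=10)  # `tpu_cores` selects the TPU backend
trainer.fit(model)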

Does PyTorch on TPU also support float16 precision?

Yes, though on TPU the native 16-bit format is bfloat16 rather than float16. PyTorch/XLA can run tensors in bfloat16 (for example via the XLA_USE_BF16 environment variable).
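
If I remember correctly, that environment variable has to be set before torch_xla is loaded, along these lines:

import os
os.environ['XLA_USE_BF16'] = '1'  # map torch float types to bfloat16 on TPU

import torch_xla.core.xla_model as xm  # import only after setting the flag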