tf.py_function equivalent in PyTorch

TensorFlow’s py_function wraps a Python function into a TensorFlow op and allows autograd through the graph. Is there an equivalent in PyTorch, where a NumPy array is passed into an arbitrary function and the function is wrapped as a PyTorch op?

Thanks.

Hello there,

There is an equivalent of this functionality in PyTorch.
https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
This tutorial shows how to define a custom function that autograd can track. You define the forward and backward passes as static methods of a class that inherits from torch.autograd.Function, then call ‘your_custom_function.apply’ wherever you would call a regular function.
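To make that concrete, here is a minimal sketch of the pattern. The class name MyReLU and the example input are only illustrative and not taken from the question; the torch.autograd.Function, ctx.save_for_backward and .apply calls are the standard PyTorch API.

```python
import torch


class MyReLU(torch.autograd.Function):
    """Illustrative custom op: ReLU with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Save the input so backward can tell where the gradient passes through.
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # The gradient flows only where the input was positive.
        return grad_output * (x > 0).to(grad_output.dtype)


x = torch.tensor([-1.0, 2.0, 3.0], requires_grad=True)
y = MyReLU.apply(x)   # call through .apply, not by instantiating the class
y.sum().backward()    # fills x.grad using the custom backward
```

Note that the function is used through MyReLU.apply(x); the class itself is never instantiated.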

Hope this helps,
All the best.

Yes, that works. These scripts are also relevant: https://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html
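For the NumPy case in the original question, a rough sketch along the same lines might look like this. NumpySin is a made-up name and np.sin just stands in for whatever NumPy function you actually need. Since autograd cannot trace through NumPy calls, the derivative has to be written by hand in backward.

```python
import numpy as np
import torch


class NumpySin(torch.autograd.Function):
    """Hypothetical example: wraps np.sin as a differentiable PyTorch op."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # Leave the autograd graph, compute with NumPy, and come back.
        result = np.sin(x.detach().cpu().numpy())
        return torch.as_tensor(result, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx sin(x) = cos(x), also computed with NumPy here.
        local_grad = np.cos(x.detach().cpu().numpy())
        local_grad = torch.as_tensor(local_grad, dtype=x.dtype, device=x.device)
        return grad_output * local_grad


x_in = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
y_out = NumpySin.apply(x_in)
y_out.sum().backward()  # x_in.grad now holds cos(x_in)
```

The round trip through .cpu().numpy() means the computation runs on the CPU even if the input lives on a GPU, which is the price of calling NumPy from inside the op.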

Thanks!
