TF py_function equivalent in PyTorch

TensorFlow’s py_function wraps a Python function into a TensorFlow op and allows autograd through the graph. Is there an equivalent in PyTorch, where a NumPy array is passed into an arbitrary function and the function is wrapped as a PyTorch op?



Hello there,

There is an equivalent of this functionality in PyTorch.
This article shows how to define a custom function with autograd tracking. You define the forward and backward passes as static methods of a custom class that inherits from torch.autograd.Function, then call it through YourCustomFunction.apply(...) as you would a regular function. Because autograd cannot trace operations performed in NumPy, the backward pass has to be written by hand.
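A minimal sketch of that pattern, assuming the NumPy function of interest is np.sin (chosen only because its derivative, cos, is easy to write by hand). The class name NumpySin and the alias numpy_sin are hypothetical; replace the forward body and the hand-written derivative with your own NumPy code.

```python
import numpy as np
import torch


class NumpySin(torch.autograd.Function):
    """Hypothetical example: wraps NumPy's sin as a differentiable PyTorch op."""

    @staticmethod
    def forward(ctx, x):
        # Keep the input tensor so backward can compute the local derivative.
        ctx.save_for_backward(x)
        # Leave PyTorch: detach from the graph and convert to a NumPy array (CPU).
        result = np.sin(x.detach().cpu().numpy())
        # Convert back to a tensor on the input's original device and dtype.
        return torch.from_numpy(result).to(device=x.device, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # NumPy code is invisible to autograd, so the derivative is written by hand:
        # d/dx sin(x) = cos(x)
        local_grad = np.cos(x.detach().cpu().numpy())
        local_grad = torch.from_numpy(local_grad).to(
            device=grad_output.device, dtype=grad_output.dtype
        )
        # Chain rule: upstream gradient times local derivative.
        return grad_output * local_grad


# Expose the op through .apply so it can be called like a regular function.
numpy_sin = NumpySin.apply

x = torch.linspace(0.0, 1.0, 5, dtype=torch.float64, requires_grad=True)
y = numpy_sin(x)
y.sum().backward()
print(torch.allclose(x.grad, torch.cos(x.detach())))  # True

# gradcheck compares the hand-written backward against finite differences
# (it expects double-precision inputs that require grad).
print(torch.autograd.gradcheck(numpy_sin, (x,)))  # True
```

torch.autograd.gradcheck is a convenient way to confirm that a hand-written backward matches the NumPy forward before using the op inside a larger model.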

Hope this helps,
All the best.


Yes, that works. These scripts are also relevant:

