Hi!
You may go for:
```python
import torch


class CubeFunction(torch.autograd.Function):
    """
    Dummy activation function x -> x ** 3
    """

    @staticmethod
    def forward(ctx, X):
        ctx.save_for_backward(X)
        return X ** 3

    @staticmethod
    def backward(ctx, M):
        X, = ctx.saved_tensors
        return CubeFunctionBackward.apply(X, M)


class CubeFunctionBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, X, M):
        ctx.save_for_backward(X, M)
        return M * 3 * X ** 2

    @staticmethod
    def backward(ctx, V):
        X, M = ctx.saved_tensors
        # Gradients w.r.t. X and M, respectively
        return V * 6 * X * M, V * 3 * X ** 2
```
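To illustrate why routing `backward` through a second `Function` enables double backward, here is a self-contained sketch (repeating the definitions above so it runs on its own) that differentiates twice and compares against the analytic second derivative `6 * x`:

```python
import torch


class CubeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X):
        ctx.save_for_backward(X)
        return X ** 3

    @staticmethod
    def backward(ctx, M):
        X, = ctx.saved_tensors
        # Delegating to another Function makes this step differentiable
        return CubeFunctionBackward.apply(X, M)


class CubeFunctionBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, M):
        ctx.save_for_backward(X, M)
        return M * 3 * X ** 2

    @staticmethod
    def backward(ctx, V):
        X, M = ctx.saved_tensors
        return V * 6 * X * M, V * 3 * X ** 2


x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = CubeFunction.apply(x).sum()
# First derivative: 3 * x ** 2 (create_graph=True keeps it differentiable)
g, = torch.autograd.grad(y, x, create_graph=True)
# Second derivative: 6 * x, obtained by differentiating g
h, = torch.autograd.grad(g.sum(), x)
ok = torch.allclose(h, 6 * x)
```

Without the second `Function`, the first `torch.autograd.grad` call would return a tensor with no graph attached, and the second call would raise an error.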
See this gist https://gist.github.com/arthurmensch/b02c45f3440c9d3de0ef2c0ae5ea1107 for a complete demo.
I could PR this to the documentation if any maintainer deems it interesting enough to expose.