Thanks for getting back to me so quickly. Wow — it looks like I'm on version 0.4.1, so I'll upgrade shortly. This is my implementation; perhaps I misunderstood the examples?
class TanhControl(T.autograd.Function):
    """Custom autograd Function implementing tanh with an explicit backward pass.

    Forward returns ``tanh(input)``; backward applies the chain rule,
    multiplying the incoming gradient by the local derivative
    ``1 - tanh(input)**2``.
    """

    @staticmethod
    def calcForward(input):
        """Return the element-wise hyperbolic tangent of *input*."""
        return T.tanh(input)

    @staticmethod
    def calcBackward(input):
        """Return the derivative of tanh at *input*: ``1 - tanh(input)**2``."""
        return 1 - T.tanh(input) ** 2

    @staticmethod
    def forward(ctx, input):
        # Save the raw input so backward can recompute the local derivative.
        ctx.save_for_backward(input)
        # NOTE: staticmethods must be qualified with the class name; a bare
        # calcForward(input) raises NameError at runtime.
        return TanhControl.calcForward(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Default to None so autograd sees "no gradient needed" instead of
        # hitting an UnboundLocalError when needs_input_grad[0] is False.
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Chain rule: incoming gradient times the local derivative.
            # (The original assigned calcBackward(input) alone, silently
            # discarding grad_output — gradients upstream were wrong.)
            grad_input = grad_output * TanhControl.calcBackward(input)
        return grad_input