I created a customized autograd function and wrap it inside a module. But I just keep getting errors when I try to merge my module with other standard modules using nn.Sequential(). Can anyone help? Thanks.
import torch
import torch.nn as nn
# customized autograd function
class BinaryTanh_func(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
tanh_func = nn.Tanh()
preactivated_threshold = tanh_func(input)
final_output = torch.ones(preactivated_threshold.size())
final_output[preactivated_threshold < 0] = 0
return final_output
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
tanh_func = nn.Tanh()
preactivated = tanh_func(input)
grad_input = (1-preactivated**2) * grad_input
return grad_input
# customized module
class BinaryTanh(nn.Module):
@staticmethod
def forward(self, input):
return BinaryTanh_func()(input)
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
torch.manual_seed(100)
# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
BinaryTanh(),
torch.nn.Linear(H, D_out),
)
# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-4
for t in range(10):
y_pred = model(x)
loss = loss_fn(y_pred, y)
print(t, loss.item())
model.zero_grad()
loss.backward()
with torch.no_grad():
for param in model.parameters():
param -= learning_rate * param.grad
Here’s the error message I got:
TypeError Traceback (most recent call last)
<ipython-input-29-becf74c80a68> in <module>()
73 for t in range(10):
74
---> 75 y_pred = model(x)
76
77 loss = loss_fn(y_pred, y)
2 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
TypeError: forward() missing 1 required positional argument: 'input'
Any help will be appreciated. Thanks.