Problems combining a custom neuron module with standard modules in nn.Sequential

I created a custom autograd function and wrapped it inside a module, but I keep getting errors when I try to combine my module with other standard modules using nn.Sequential(). Can anyone help? Thanks.

import torch
import torch.nn as nn

# customized autograd function
class BinaryTanh_func(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        tanh_func = nn.Tanh()
        preactivated_threshold = tanh_func(input)
        final_output = torch.ones(preactivated_threshold.size())
        final_output[preactivated_threshold < 0] = 0
        return final_output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        tanh_func = nn.Tanh()
        preactivated = tanh_func(input)
        grad_input = (1-preactivated**2) * grad_input
        return grad_input

# customized module
class BinaryTanh(nn.Module):
    @staticmethod
    def forward(self, input):
        return BinaryTanh_func()(input)


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

torch.manual_seed(100)
# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    BinaryTanh(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(10):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    model.zero_grad()
    loss.backward()

    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

Here’s the error message I got:

TypeError Traceback (most recent call last)

<ipython-input-29-becf74c80a68> in <module>()
     73 for t in range(10):
     74 
---> 75     y_pred = model(x)
     76 
     77     loss = loss_fn(y_pred, y)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

TypeError: forward() missing 1 required positional argument: 'input'

Any help will be appreciated. Thanks.

Try changing your BinaryTanh definition to:

# customized module
class BinaryTanh(nn.Module):
    def __init__(self):
        super().__init__()
        self.fn = BinaryTanh_func.apply
        
    def forward(self, input):
        return self.fn(input)
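Here is a minimal sketch of that fix in context, assuming a current PyTorch release. The module's forward is a regular method, and the custom Function is invoked through the class-level apply instead of being instantiated (as far as I know, newer releases no longer support calling an instance such as BinaryTanh_func()(input)). Two small changes are my own and are not required for the fix: torch.tanh replaces building an nn.Tanh() inside forward/backward, and torch.ones_like replaces torch.ones(size) so the output keeps the input's dtype and device.

```python
import torch
import torch.nn as nn


class BinaryTanh_func(torch.autograd.Function):
    """Forward: step function on tanh(input). Backward: derivative of tanh."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # 1 where tanh(input) >= 0, else 0; ones_like keeps dtype and device
        out = torch.ones_like(input)
        out[torch.tanh(input) < 0] = 0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Pass the gradient through as if the forward pass were plain tanh
        return grad_output * (1 - torch.tanh(input) ** 2)


class BinaryTanh(nn.Module):
    # Regular (non-static) forward, so nn.Module.__call__ can bind self
    def forward(self, input):
        # Static-method Functions are called via the class-level .apply
        return BinaryTanh_func.apply(input)


torch.manual_seed(100)
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = nn.Sequential(nn.Linear(D_in, H), BinaryTanh(), nn.Linear(H, D_out))
loss = nn.MSELoss(reduction="sum")(model(x), y)
loss.backward()

# The first Linear layer sits behind the custom Function, so a gradient
# here shows that BinaryTanh_func.backward was used.
first_grad = model[0].weight.grad
print(first_grad.shape)  # torch.Size([100, 1000])
```

Because apply is a class-level entry point, no Function instance (and no state) is ever created; everything the backward pass needs travels through ctx.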

Thanks a lot!! It works.
I also found that if I comment out the “@staticmethod” line above BinaryTanh.forward in my original code, it works too. Can you briefly explain why? I will dig deeper into this and google more about “staticmethod”. Thanks again.

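Static methods are not bound to an instance: Python passes them no implicit self, so they are independent of the “state” of a class instance.
nn.Module.__call__ runs self.forward(x). With @staticmethod on BinaryTanh.forward, x ends up in the self parameter and input is left empty, which is exactly the “forward() missing 1 required positional argument: 'input'” error. Removing the decorator makes forward a regular method, so self is bound automatically and x goes to input.
The autograd Function is the opposite case: its forward and backward are static, so instead of creating an instance and calling it, you call BinaryTanh_func.apply(input) directly.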
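The binding problem can be reproduced without PyTorch. In this sketch, Caller is a hypothetical stand-in for nn.Module: its __call__ simply delegates to self.forward(*args), roughly what the module.py frame in the traceback does. StaticForward and BoundForward are also hypothetical names used only for illustration.

```python
class Caller:
    # Hypothetical stand-in for nn.Module: __call__ delegates to self.forward
    def __call__(self, *args):
        return self.forward(*args)


class StaticForward(Caller):
    @staticmethod
    def forward(self, input):  # no instance is bound, so the argument lands in `self`
        return input * 2


class BoundForward(Caller):
    def forward(self, input):  # regular method: `self` is bound automatically
        return input * 2


try:
    StaticForward()(3)
except TypeError as err:
    # Message ends with: missing 1 required positional argument: 'input'
    print(err)

print(BoundForward()(3))  # 6
```

The exact prefix of the TypeError message differs between Python versions (newer ones include the class name), but the missing 'input' argument is the same failure shown in the traceback above.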


Got it. Thank you very much!