How does calling an instance of class 'Linear' work if there is no '__call__()' method defined in class 'Linear'?


import torch
from torch.nn import Linear
model = Linear(in_features=1, out_features=1)
x = torch.tensor([[1.0]])  # example input: a batch of one sample with one feature
y = model(x)  # calling an instance of class Linear

The ‘Linear’ class does inherit from the ‘Module’ class, which has ‘__call__()’ defined in it, but I could not figure out how it works from there. Kindly help.

As far as I know, in Python, to call an instance of a class (here model()), the class must have ‘__call__()’ defined in it. But how does this work in the case of the class ‘Linear’?

If a method is not redefined in the class itself, the one inherited from its parent class is used. So in this case, the __call__() method from the Module class is used.
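A minimal pure-Python sketch of that lookup, using hypothetical classes `Base` and `Child` as stand-ins for `Module` and `Linear` (this is not PyTorch's real code): `Child` defines only `forward()`, yet its instances are callable because Python finds `__call__` on the parent class through the method resolution order. A class whose whole hierarchy lacks `__call__` produces instances that are not callable.

```python
class Base:
    """Hypothetical stand-in for torch.nn.Module."""

    def __call__(self, *args, **kwargs):
        # The parent's __call__ delegates to whatever forward() the subclass defines
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Child(Base):
    """Hypothetical stand-in for torch.nn.Linear: no __call__ of its own."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def forward(self, x):
        return self.weight * x + self.bias


class NotCallable:
    """Neither this class nor its parent (object) defines __call__."""


model = Child(weight=2.0, bias=1.0)
print("__call__" in Child.__dict__)     # False: not defined on Child itself
print(Child.__call__ is Base.__call__)  # True: found on the parent via the MRO
print(model(3.0))                       # 7.0

try:
    NotCallable()()
except TypeError as err:
    print(err)                          # 'NotCallable' object is not callable
```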

OK. The __call__() method of its superclass (Module) includes the following lines:

    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)

Can you please tell me where I can find information about torch._C._get_tracing_state()?
I am attaching the whole __call__() function for your reference:

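# Excerpt: method of class Module in torch/nn/modules/module.py (PyTorch 1.x)
def __call__(self, *input, **kwargs):
    # Forward pre-hooks run first and may replace the inputs
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    # While the JIT is tracing, go through _slow_forward; otherwise call forward() directly
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    # Forward hooks see the inputs and output and may replace the output
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    # Backward hooks are attached to the grad_fn of the first tensor found in the output
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result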

That is a C++ function, most likely defined by the JIT, given its name.
If you search for it, you can find it defined here in the C++ code.
But it is only used during JIT tracing, so you can safely assume that it returns None outside of that use case.
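The real tracer state lives on the C++ side and is managed by `torch.jit.trace`, so the following is only a pure-Python sketch of the observable behavior described above: the state is None unless a trace is in progress, and `__call__` picks `_slow_forward` only in that case. The names `_tracing_state`, `get_tracing_state`, `tracing` and `TracedModule` are hypothetical. (Recent PyTorch versions also expose a public check, `torch.jit.is_tracing()`.)

```python
from contextlib import contextmanager

# Hypothetical stand-in for the C++-side tracer state: None means "not tracing"
_tracing_state = None


def get_tracing_state():
    """Mimics torch._C._get_tracing_state(): None unless a trace is active."""
    return _tracing_state


@contextmanager
def tracing():
    """Hypothetical helper that marks a trace as active for the duration of a block."""
    global _tracing_state
    previous = _tracing_state
    _tracing_state = object()  # any non-None value signals "tracing"
    try:
        yield
    finally:
        _tracing_state = previous


class TracedModule:
    def __call__(self, *args):
        # Same branch as in Module.__call__
        if get_tracing_state():
            return self._slow_forward(*args)
        return self.forward(*args)

    def _slow_forward(self, *args):
        # In PyTorch this path also does tracer bookkeeping (such as scope names);
        # here it just tags the result so the branch taken is visible
        return ("slow", self.forward(*args))

    def forward(self, x):
        return x + 1


m = TracedModule()
print(get_tracing_state())  # None
print(m(1))                 # 2
with tracing():
    print(m(1))             # ('slow', 2)
print(m(1))                 # 2
```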


Thank you @albanD for your support.