Is there a way to call an activation function from a string?
For example, something like this:
activation_string = "relu"
activation_function = nn.activation(activation_string)
u = activation_function(v)
It would be really practical to have something like this, for example to define the activation function in a config file, instead of inside the classes.
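A minimal sketch of the config-driven approach, assuming the config is JSON with a hypothetical "activation" key (and an optional hypothetical "activation_kwargs" key). An explicit dictionary maps lower-case names to torch.nn module classes, so only whitelisted activations can be selected. The names ACTIVATIONS and activation_from_config are made up for illustration.

```python
import json

import torch
from torch import nn

# Hypothetical registry: lower-case config names -> torch.nn module classes.
ACTIVATIONS = {
    "relu": nn.ReLU,
    "gelu": nn.GELU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
    "leakyrelu": nn.LeakyReLU,
}


def activation_from_config(config_text: str) -> nn.Module:
    """Build an activation module from a JSON config (hypothetical schema)."""
    config = json.loads(config_text)
    name = config["activation"].lower()
    kwargs = config.get("activation_kwargs", {})  # optional constructor arguments
    try:
        cls = ACTIVATIONS[name]
    except KeyError:
        raise ValueError(f"Unknown activation {name!r}") from None
    return cls(**kwargs)  # instantiate the module with the configured arguments


activation_function = activation_from_config('{"activation": "relu"}')
v = torch.tensor([-1.0, 0.5])
u = activation_function(v)  # ReLU zeroes the negative entry
```

The whitelist keeps the config from selecting arbitrary classes, at the cost of listing every supported activation by hand.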
Thank you for your answer.
I considered this option before asking the question, but I just wanted to know if there was any built-in implementation before reinventing the wheel.
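Short of a dedicated helper, Python's built-in getattr already works as a lookup by exact name, because activation modules are attributes of torch.nn and their functional counterparts are attributes of torch.nn.functional. A minimal sketch:

```python
import torch
import torch.nn.functional as F
from torch import nn

v = torch.tensor([-2.0, 3.0])

# Module form: the string must match the class name exactly (case-sensitive).
module_cls = getattr(nn, "ReLU")
u_module = module_cls()(v)  # instantiate, then apply

# Functional form: lower-case function names in torch.nn.functional.
func = getattr(F, "relu")
u_func = func(v)
```

Note that getattr(nn, name) returns any attribute, e.g. nn.Linear, so an unvalidated config string could select a non-activation layer; that is one argument for an explicit whitelist.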
Thanks
Hi, just in case there is still a need for this functionality, I wrote the following helper function, which parses the export variable (__all__) of the submodule torch.nn.modules.activation:
from torch.nn import Module
from typing import Callable
from torch.nn.modules import activation

Activation = Callable[..., Module]


def get_activation_fn(act: str) -> Activation:
    # get list from activation submodule as lower-case
    activations_lc = [str(a).lower() for a in activation.__all__]
    if (act := str(act).lower()) in activations_lc:
        # match actual name from lower-case list, return function/factory
        idx = activations_lc.index(act)
        act_name = activation.__all__[idx]
        act_func = getattr(activation, act_name)
        return act_func
    else:
        raise ValueError(f"Cannot find activation function for string <{act}>")
Note that this function intentionally returns a reference to a torch.nn.Module subclass (the class itself, not an instance), which still has to be instantiated. This allows using the output of this function as a factory, with additional arguments, e.g.:
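A self-contained sketch of that factory use. To run on its own it reimplements the helper compactly with a lower-case-to-name dictionary over activation.__all__ (same behaviour, different structure); the arguments negative_slope and alpha are the standard constructor arguments of nn.LeakyReLU and nn.ELU.

```python
from typing import Callable

import torch
from torch import nn
from torch.nn.modules import activation

# Case-insensitive index over the names exported by torch.nn.modules.activation.
_BY_LOWER = {name.lower(): name for name in activation.__all__}


def get_activation_fn(act: str) -> Callable[..., nn.Module]:
    name = _BY_LOWER.get(str(act).lower())
    if name is None:
        raise ValueError(f"Cannot find activation function for string <{act}>")
    return getattr(activation, name)  # the class, not an instance


# The returned class acts as a factory: instantiate it with any arguments.
leaky = get_activation_fn("leakyrelu")(negative_slope=0.1)
elu = get_activation_fn("ELU")(alpha=0.5)

v = torch.tensor([-1.0, 2.0])
u = leaky(v)  # negative input is scaled by negative_slope
```

Keep in mind that __all__ lists every module defined in that file, including ones that are not elementwise activations (e.g. MultiheadAttention), so the lookup accepts those names too.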