Call activation function from string

Hi,

Is there a way to call an activation function from a string?
For example, something like this:

activation_string = "relu"
activation_function = nn.activation(activation_string)
u = activation_function(v)

It would be really practical to have something like this, for example to define the activation function in a config file instead of inside the classes.

Thanks in advance,
Manu


You can create a dictionary that maps names to activation modules:

from torch import nn

activations = {
    'relu': nn.ReLU(),
    'sigmoid': nn.Sigmoid(),
    'tanh': nn.Tanh(),
}

and then call

activation_function = activations[activation_string]
u = activation_function(v)
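A variation on the dictionary idea, sketched under the assumption that you want to pass constructor arguments from a config file: store the classes rather than instances, so every lookup builds a fresh module. The names ACTIVATIONS, build_activation and the config keys are hypothetical, chosen only for illustration.

```python
import torch
from torch import nn

# Hypothetical registry: lowercase names -> activation *classes* (not instances),
# so each lookup builds a fresh module and constructor kwargs can be passed through.
ACTIVATIONS = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "leaky_relu": nn.LeakyReLU,
}


def build_activation(name: str, **kwargs) -> nn.Module:
    """Instantiate the activation registered under `name` (case-insensitive)."""
    try:
        cls = ACTIVATIONS[name.lower()]
    except KeyError:
        raise ValueError(
            f"Unknown activation {name!r}; choose from {sorted(ACTIVATIONS)}"
        ) from None
    return cls(**kwargs)


# Hypothetical config, e.g. as loaded from a YAML/JSON file.
config = {"activation": "leaky_relu", "activation_kwargs": {"negative_slope": 0.2}}
act = build_activation(config["activation"], **config["activation_kwargs"])
u = act(torch.tensor([-1.0, 2.0]))
print(u)
```

Storing classes also avoids sharing one module instance across several layers, which is harmless for stateless activations but not for ones with parameters such as nn.PReLU.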

Thank you for your answer.
I considered this option before asking, but I wanted to know whether there was a built-in implementation before reinventing the wheel.
Thanks :wink:


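You can dynamically look up the activation class by name with getattr. Note that the lookup is case-sensitive and the string must match the class name in torch.nn (e.g. "ReLU", not "relu"):

activation_string = "ReLU"
activation_function = getattr(nn, activation_string)()
u = activation_function(v)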
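A slightly more defensive sketch of the getattr approach, assuming you want a clear error for unknown names. The helper name activation_from_name is hypothetical. It only checks that the attribute is an nn.Module subclass, so a non-activation name such as "Linear" would still pass. The last lines show the functional alternative, where torch.nn.functional uses lowercase names like relu.

```python
import inspect

import torch
import torch.nn.functional as F
from torch import nn


def activation_from_name(name: str, **kwargs) -> nn.Module:
    """Hypothetical helper: look up an nn.Module subclass by exact class name."""
    cls = getattr(nn, name, None)
    # Reject names that are missing or are not Module classes
    # (e.g. "functional" or "init" exist in torch.nn but are submodules).
    if not (inspect.isclass(cls) and issubclass(cls, nn.Module)):
        raise ValueError(f"torch.nn has no Module class named {name!r}")
    return cls(**kwargs)


v = torch.tensor([-1.0, 0.5])
u = activation_from_name("ReLU")(v)  # module style: class name, instantiated first
u_f = getattr(F, "relu")(v)          # functional style: lowercase function name
print(u, u_f)
```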

Hi, in case there is still a need for this functionality: I wrote the following helper function, which parses the export list (__all__) of the submodule torch.nn.modules.activation:

from torch.nn import Module
from typing import Callable
from torch.nn.modules import activation

Activation = Callable[..., Module]


def get_activation_fn(act: str) -> Activation:
    # get list from activation submodule as lower-case
    activations_lc = [str(a).lower() for a in activation.__all__]
    if (act := str(act).lower()) in activations_lc:
        # match actual name from lower-case list, return function/factory
        idx = activations_lc.index(act)
        act_name = activation.__all__[idx]
        act_func = getattr(activation, act_name)
        return act_func
    else:
        raise ValueError(f"Cannot find activation function for string <{act}>")

Note that this function intentionally returns the torch.nn.Module subclass itself rather than an instance, so it still has to be instantiated. This lets you use the result as a factory and pass additional constructor arguments, e.g.:

act1 = get_activation_fn('relu')
print(act1())            # ReLU()

act2 = get_activation_fn('gelu')
print(act2())            # GELU(approximate='none')

act3 = get_activation_fn('elu')
print(act3(alpha=1.2))   # ELU(alpha=1.2)
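As a sketch of how such a factory could be used with a config file, here is a variant that builds the lowercase lookup table once as a dict instead of searching a list on every call, then assembles a small model. It assumes, like the helper above, that torch.nn.modules.activation defines __all__ (true in recent PyTorch releases). The names get_activation_cls and cfg are hypothetical.

```python
import torch
from torch import nn
from torch.nn.modules import activation

# Build the lowercase-name -> class table once, from the submodule's export list.
_ACT_BY_LC_NAME = {name.lower(): getattr(activation, name) for name in activation.__all__}


def get_activation_cls(act: str) -> type:
    """Hypothetical variant of get_activation_fn using a dict lookup."""
    try:
        return _ACT_BY_LC_NAME[act.lower()]
    except KeyError:
        raise ValueError(f"Cannot find activation function for string <{act}>") from None


# Hypothetical config: layer sizes plus an activation name and its kwargs.
cfg = {"sizes": [4, 8, 2], "act": "elu", "act_kwargs": {"alpha": 1.2}}

layers = []
for n_in, n_out in zip(cfg["sizes"][:-1], cfg["sizes"][1:]):
    layers.append(nn.Linear(n_in, n_out))
    # Each layer gets its own fresh activation instance from the factory.
    layers.append(get_activation_cls(cfg["act"])(**cfg["act_kwargs"]))
model = nn.Sequential(*layers)
print(model)
```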

Hope it helps!