How to make `functional_call` run methods other than `forward`

Suppose I have the following net:

class DiagNormalPolicy(nn.Module):
    """Gaussian policy: an MLP predicts the mean, and a learned log-std is shared by all states.

    Args:
        input_size: Number of features in each state fed to the mean MLP.
        output_size: Number of action dimensions (size of the mean and of ``sigma``).
        hiddens: Hidden-layer widths of the mean MLP; defaults to ``[100, 100]``.
            An empty sequence yields a single linear layer.
        activation: ``'relu'``, ``'tanh'``, or a zero-argument callable returning
            an activation module (e.g. ``nn.ELU``).
        device: Device that incoming states are moved to in :meth:`density`.

    Raises:
        ValueError: If ``activation`` is a string other than ``'relu'`` or ``'tanh'``.
    """

    def __init__(self, input_size, output_size, hiddens=None, activation='relu', device='cpu'):
        super().__init__()
        self.device = device
        if hiddens is None:
            hiddens = [100, 100]
        if isinstance(activation, str):
            activations = {'relu': nn.ReLU, 'tanh': nn.Tanh}
            try:
                activation = activations[activation]
            except KeyError:
                # Fail early with a clear message instead of a later
                # "'str' object is not callable" when building the layers.
                raise ValueError(
                    f"Unsupported activation {activation!r}; expected one of "
                    f"{sorted(activations)} or a callable returning an nn.Module"
                ) from None
        # Build Linear -> act -> ... -> Linear (no activation after the output
        # layer). The layer order, and therefore the state_dict keys and the
        # order of linear_init calls, match a hand-unrolled stack.
        sizes = [input_size, *hiddens, output_size]
        layers = []
        for idx, (in_features, out_features) in enumerate(zip(sizes[:-1], sizes[1:])):
            if idx > 0:
                layers.append(activation())
            layers.append(linear_init(nn.Linear(in_features, out_features)))
        self.mean = nn.Sequential(*layers)
        # Per-action-dimension log standard deviation, initialised to
        # log(1) = 0, i.e. an initial std of 1.
        self.sigma = nn.Parameter(torch.zeros(output_size))

    def density(self, state):
        """Return the ``Normal`` action distribution for ``state``.

        ``state`` is moved to ``self.device`` first. The log-std is clamped from
        below at ``log(EPSILON)``, so the scale never falls under ``EPSILON``
        (assumes EPSILON is a small positive module-level constant defined elsewhere).
        """
        state = state.to(self.device, non_blocking=True)
        loc = self.mean(state)
        scale = torch.exp(torch.clamp(self.sigma, min=math.log(EPSILON)))
        return Normal(loc=loc, scale=scale)

    def log_prob(self, state, action):
        """Return per-sample log-probabilities of ``action`` under the policy at ``state``.

        The per-dimension log-densities are averaged over dim 1 with ``keepdim=True``.
        NOTE(review): this is the mean, not the sum, of per-dimension log-probs.
        Assuming ``action`` is (batch, action_dim), the result is the joint
        log-density divided by action_dim. Confirm callers expect this scaling.
        """
        density = self.density(state)
        return density.log_prob(action).mean(dim=1, keepdim=True)

    def forward(self, state):
        """Draw an action by calling ``sample()`` on the distribution from :meth:`density`."""
        density = self.density(state)
        action = density.sample()
        return action

As you can see, it has three methods: `density`, `log_prob`, and `forward`.

After creating an instance, I can call `functional_call`, which runs the `forward` method:

from torch.nn.utils.stateless import functional_call

policy = DiagNormalPolicy(17, 6)
# torch.rand is itself the sampling function; `torch.rand.rand` raises AttributeError.
data = torch.rand(1, 17)  # one state with 17 features, matching input_size=17
# Runs policy.forward(data) using the tensors supplied in the given state dict.
val = functional_call(policy, policy.state_dict(), data)

That works, but is it possible to run the other two methods with `functional_call`? It would be something like:

# Attempt: pass a bound method instead of the module itself; this fails with the AttributeError quoted below.
density = functional_call(policy.density, policy.state_dict(), data)

When I try that, I get the error `AttributeError: 'method' object has no attribute '_attr_to_path'`. I get the same error when trying to run `log_prob`. Or is my only option to write a single method that returns all three objects that the individual methods return?