Gradient function of an nn.Module

Hello all, suppose I have a scalar function f(t, x) implemented as an nn.Module:

import torch.nn as nn

class f(nn.Module):
    def __init__(self):
        super(f, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 1),
        )

    def forward(self, t, x):
        return self.net(x)  # t is accepted for the (t, x) interface but unused here

I would like to build an nn.Module that takes the input (t, x) and outputs the gradient of f with respect to x, roughly like below:

class dfdx(nn.Module):
    def __init__(self):
        super(dfdx, self).__init__()

    def forward(self, t, x):
        return df_dx

In this case, x = [x1, x2] is a tensor with two elements, so df_dx should be the two-element tensor [df/dx1, df/dx2].
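To make the goal concrete, here is a minimal sketch of what such a wrapper could look like using torch.autograd.grad (the names dfdx and func are just illustrative, and f is the module defined above):

```python
import torch
import torch.nn as nn

class f(nn.Module):
    def __init__(self):
        super(f, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 1),
        )

    def forward(self, t, x):
        return self.net(x)

class dfdx(nn.Module):
    def __init__(self, func):
        super(dfdx, self).__init__()
        self.func = func  # the scalar-valued f module above

    def forward(self, t, x):
        # Make sure x tracks gradients before evaluating f.
        x = x.requires_grad_(True)
        # Reduce the (1,)-shaped output to a true scalar before differentiating.
        y = self.func(t, x).sum()
        # create_graph=True keeps the graph so the gradient itself is differentiable.
        (grad_x,) = torch.autograd.grad(y, x, create_graph=True)
        return grad_x

func = f()
g = dfdx(func)
t = torch.tensor(0.0)
x = torch.randn(2)
print(g(t, x).shape)  # torch.Size([2])
```

Since dfdx stores func as a submodule, the gradient output stays connected to f's parameters, which matters if the gradient appears in a loss.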

I found a similar question and tried the following:

import torch
import torch.nn as nn

class f(nn.Module):

    def __init__(self):
        super(f, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 1),
        )

    def forward(self, t, x):
        return self.net(x)

    def compute_u_x(self, t, x):
        self.u_x = torch.autograd.functional.jacobian(self, x, create_graph=True)
        self.u_x = torch.squeeze(self.u_x)
        return self.u_x

and hope that

func = f()
func.compute_u_x(t, x)

would achieve the goal (though I am not sure whether func.compute_u_x counts as an nn.Module), but I got the following error:

TypeError: forward() missing 1 required positional argument: 'x'

I understand why I get this error but have no idea how to fix it. Is there any way to solve the problem? Any help is highly appreciated!
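Edit: one workaround I am considering is binding t with a lambda, so that jacobian only sees a function of x alone. A sketch of what I mean (untested beyond checking shapes):

```python
import torch
import torch.nn as nn

class f(nn.Module):
    def __init__(self):
        super(f, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 1),
        )

    def forward(self, t, x):
        return self.net(x)

    def compute_u_x(self, t, x):
        # Bind t so that jacobian differentiates a function of x alone.
        self.u_x = torch.autograd.functional.jacobian(
            lambda x: self(t, x), x, create_graph=True
        )
        self.u_x = torch.squeeze(self.u_x)  # (1, 2) -> (2,)
        return self.u_x

func = f()
t = torch.tensor(0.0)
x = torch.randn(2)
print(func.compute_u_x(t, x).shape)  # torch.Size([2])
```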