Running the backward pass inside a custom context manager

I’d like the backward pass of a module to run inside a custom context manager, something like this:

import torch
import torch.nn as nn

class Model(nn.Module):

    def __init__(self, params):
        super(Model, self).__init__()
        self.layer = nn.Linear(params["input"], params["output"])
        self.context_name = params["context"]

    def forward(self, x):
        return self.layer(x)

    def backward(self, x):  # modified: this is the behavior I'm after
        with new_context(self.context_name):  # new_context is my own context manager
            ...  # <call the regular backward pass here>
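A method named backward on an nn.Module is an ordinary method and is not called by autograd; user-defined backward logic is normally plugged in through a torch.autograd.Function subclass. Below is a minimal sketch of one way to get the intended effect, assuming new_context is an ordinary Python context manager (stubbed here as a hypothetical stack) and that every parameter of the wrapped layer requires grad. The class name BackwardInContext is hypothetical. It uses the checkpoint-style recompute trick: forward runs the layer normally, and backward recomputes the layer and differentiates it inside the context.

```python
import contextlib

import torch
import torch.nn as nn

# Hypothetical stand-in for the poster's new_context: a stack recording
# which contexts are currently active.
ACTIVE = []


@contextlib.contextmanager
def new_context(name):
    ACTIVE.append(name)
    try:
        yield
    finally:
        ACTIVE.pop()


class BackwardInContext(torch.autograd.Function):
    """Runs module(x) in forward; in backward, recomputes module(x) and
    differentiates it inside new_context(context_name)."""

    @staticmethod
    def forward(ctx, x, module, context_name, *params):
        # params are passed only so autograd links the layer's weights to
        # this node; module(x) already uses them directly.
        ctx.module = module
        ctx.context_name = context_name
        ctx.save_for_backward(x)
        # Grad mode is off inside forward, so no inner graph is built here.
        return module(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        params = list(ctx.module.parameters())  # assumed to all require grad
        with new_context(ctx.context_name), torch.enable_grad():
            x = x.detach().requires_grad_()
            out = ctx.module(x)  # recompute with a fresh graph
            grads = torch.autograd.grad(out, [x, *params], grad_out)
        x_grad = grads[0] if ctx.needs_input_grad[0] else None
        # One entry per forward input: x, module, context_name, *params.
        return (x_grad, None, None, *grads[1:])


class Model(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.layer = nn.Linear(params["input"], params["output"])
        self.context_name = params["context"]

    def forward(self, x):
        return BackwardInContext.apply(
            x, self.layer, self.context_name, *self.layer.parameters()
        )


params = {"input": 3, "output": 1, "context": "my_ctx"}
model = Model(params)
output = model(torch.ones((1, params["input"])))
output.backward()  # the layer's gradients are computed inside new_context("my_ctx")
```

Only the wrapped layer's gradient computation runs inside the context; the rest of the graph is differentiated outside it. The trade-off is one extra forward of the wrapped layer per backward, the same cost torch.utils.checkpoint pays for its recomputation.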

Is there a built-in mechanism for this? I realize I can wrap the backward() call itself in the context manager, e.g.:

model = Model(params)
output = model(torch.ones((1, params["input"])))
with new_context(params["context"]):
    output.backward()

but due to the structure of the code I’m working with, this would require duplicating a good amount of code, which I want to avoid if possible. Thank you!
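If the duplication comes from repeating the with block at every call site, a smaller change is to keep the context name on the model and put the wrap in a single helper method. A minimal sketch, with new_context stubbed as a hypothetical stack and backward_in_context a hypothetical name (it is not an nn.Module API):

```python
import contextlib

import torch
import torch.nn as nn

# Hypothetical stand-in for the poster's new_context: a stack recording
# which contexts are currently active.
ACTIVE = []


@contextlib.contextmanager
def new_context(name):
    ACTIVE.append(name)
    try:
        yield
    finally:
        ACTIVE.pop()


class Model(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.layer = nn.Linear(params["input"], params["output"])
        self.context_name = params["context"]

    def forward(self, x):
        return self.layer(x)

    def backward_in_context(self, loss):
        # Hypothetical helper: the whole backward pass for `loss` runs
        # inside this model's context, so call sites need no with-block.
        with new_context(self.context_name):
            loss.backward()


params = {"input": 3, "output": 1, "context": "my_ctx"}
model = Model(params)
output = model(torch.ones((1, params["input"])))
model.backward_in_context(output)
```

Unlike a per-layer autograd.Function, this keeps the entire backward pass, including any other modules in the graph, inside the context.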

edit 1: forgot to put a context manager around the backward call