I'd like to do something like this with the backward pass:
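```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, params):
        super(Model, self).__init__()
        self.layer = nn.Linear(params["input"], params["output"])
        self.context_name = params["context"]

    def forward(self, x):
        return self.layer(x)

    # modified: what I'd like to happen (autograd never calls nn.Module.backward)
    def backward(self, x):
        with new_context(self.context_name):  # new_context is my own context manager
            ...  # <call the regular backward pass>
```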
Is there a mechanism for doing this? I realize I could wrap the backward call itself in the context, e.g.:
```python
model = Model(params)
output = model(torch.ones((1, params["input"])))
with new_context(params["context"]):
    output.sum().backward()  # backward() with no arguments needs a scalar output
```
but given how the code I'm working with is structured, that would mean duplicating a fair amount of code, which I'd like to avoid if possible. Thank you!
edit 1: forgot to put a context manager around the backward call
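One possible workaround, sketched under assumptions and not a built-in PyTorch feature: wrap the layer in a custom `torch.autograd.Function` whose `backward` recomputes the layer inside `new_context` and differentiates it there with `torch.autograd.grad`. This is the same recompute-in-backward idea that `torch.utils.checkpoint` uses. The `new_context` here is a hypothetical stub that only records which names are active, and `RunBackwardInContext`, `ACTIVE` and `SEEN_IN_BACKWARD` are illustrative names. The sketch assumes every parameter of the wrapped layer requires grad, and because the layer's forward runs twice it only suits deterministic layers (no dropout, for example).

```python
import contextlib

import torch
import torch.nn as nn

# Hypothetical stand-in for the poster's `new_context`: it keeps a stack of
# active context names so we can observe when the context is entered.
ACTIVE = []
SEEN_IN_BACKWARD = []  # snapshot of ACTIVE taken inside each backward call


@contextlib.contextmanager
def new_context(name):
    ACTIVE.append(name)
    try:
        yield
    finally:
        ACTIVE.pop()


class RunBackwardInContext(torch.autograd.Function):
    """Forward: run `layer` as usual. Backward: recompute `layer` inside
    `new_context(name)` and differentiate the recomputed graph there."""

    @staticmethod
    def forward(ctx, name, layer, x, *params):
        # `params` are passed only so autograd routes their gradients here.
        # Function.forward already runs with grad tracking disabled.
        ctx.name = name
        ctx.layer = layer
        ctx.save_for_backward(x)
        return layer(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        params = list(ctx.layer.parameters())  # assumes all require grad
        with new_context(ctx.name), torch.enable_grad():
            SEEN_IN_BACKWARD.append(list(ACTIVE))
            x_ = x.detach().requires_grad_(ctx.needs_input_grad[2])
            out = ctx.layer(x_)  # recompute to build a local graph
            wrt = ([x_] if x_.requires_grad else []) + params
            grads = torch.autograd.grad(out, wrt, grad_out)
        grad_x = grads[0] if x_.requires_grad else None
        grad_params = grads[1:] if x_.requires_grad else grads
        # One gradient per forward input: name, layer, x, *params.
        return (None, None, grad_x, *grad_params)


class Model(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.layer = nn.Linear(params["input"], params["output"])
        self.context_name = params["context"]

    def forward(self, x):
        return RunBackwardInContext.apply(
            self.context_name, self.layer, x, *self.layer.parameters()
        )
```

With this, the calling code can stay a plain `output.sum().backward()` with no context around it; the context is entered only while the wrapped layer's gradients are being computed, on whichever thread autograd uses for that step. The trade-off is the extra forward computation per backward pass.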