Modify Loaded TorchScript Model

I have a TorchScript model whose source code I have lost, so I can't rebuild the architecture from scratch. I want to retrain this model. A classic training loop actually works fine, but the model was exported in eval mode, which causes the BatchNorm layers not to track activation statistics (running mean and variance) during training, even though their weights and biases are still being optimized with respect to the loss.
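
One way to confirm the symptom is to compare a BatchNorm2d `running_mean` buffer before and after a single forward pass in train mode. This is a minimal sketch: `bn_stats_update` is a hypothetical helper, and because `scripted_model.pt` isn't available here, a tiny scripted Conv2d + BatchNorm2d net saved in eval mode stands in for it.

```python
import io
import torch

def bn_stats_update(model, example_input):
    """Hypothetical helper: switch to train mode, run one forward pass and
    report whether any BatchNorm2d running_mean buffer changed."""
    bns = [m for _, m in model.named_modules() if m.original_name == "BatchNorm2d"]
    before = [bn.running_mean.clone() for bn in bns]
    model.train()
    model(example_input)
    return any(not torch.equal(b, bn.running_mean) for b, bn in zip(before, bns))

# Stand-in for scripted_model.pt (assumption): a small net scripted in eval mode.
torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)).eval()
buf = io.BytesIO()
torch.jit.save(torch.jit.script(net), buf)
buf.seek(0)
model = torch.jit.load(buf)

result = bn_stats_update(model, torch.randn(4, 3, 16, 16))
print(result)
```

For a model produced with `torch.jit.script`, as in this stand-in, `model.train()` updates the training flag that the compiled BatchNorm code reads at runtime, so the helper should report a change. If the real model still reports no change after `model.train()`, that suggests the eval-mode behaviour is fixed inside the compiled graph (for example, if it was produced with `torch.jit.trace`) rather than just being a stale flag.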

I tried instantiating new BatchNorm layers in place of the scripted ones, as below, but still no luck.

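```python
import torch

model = torch.jit.load("scripted_model.pt")

def replace_bn(m):
    # original_name is the class name the submodule was scripted from
    if m.original_name == "BatchNorm2d":
        # Rebinding the local name `m` does not modify the parent module,
        # so the scripted BatchNorm2d is left in place.
        m = torch.nn.BatchNorm2d(m.weight.shape[0])

model.apply(replace_bn)
```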
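
Before replacing anything, it may help to list which submodules of the loaded model report themselves as BatchNorm2d and whether `model.train()` actually reaches them. This is a sketch under assumptions: `find_batchnorms` is a hypothetical helper, and a tiny scripted net saved in eval mode stands in for `scripted_model.pt`.

```python
import io
import torch

def find_batchnorms(model):
    """Hypothetical helper: list (qualified name, training flag) for every
    submodule of a loaded ScriptModule whose original class was BatchNorm2d."""
    return [
        (name, m.training)
        for name, m in model.named_modules()
        if m.original_name == "BatchNorm2d"
    ]

# Stand-in for scripted_model.pt (assumption): a small net scripted in eval mode.
net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)).eval()
buf = io.BytesIO()
torch.jit.save(torch.jit.script(net), buf)
buf.seek(0)
model = torch.jit.load(buf)

model.train()  # propagates the training flag to every scripted submodule
found = find_batchnorms(model)
print(found)
```

The qualified names printed here (for example `"1"` for the second entry of the Sequential) are the paths you would need if you tried to address a specific BatchNorm layer in the real model.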

How can I get fully functioning training, with the BatchNorm layers updating their running statistics as well?