I solved it using the function below:
def _causal_pad_amounts(conv):
    """Return ``(left, right, top, bottom)`` zero padding that replaces ``conv.padding``.

    All width padding is moved to the left edge (``right`` is always 0) so
    the convolution never sees columns to the right of the current one.
    Height padding is kept on both sides.
    """
    padding = conv.padding
    if isinstance(padding, str):
        # Conv2d stores string padding modes verbatim, so indexing them
        # would yield characters rather than pad sizes.
        if padding == "valid":
            return (0, 0, 0, 0)
        if padding == "same":
            # "same" pads a total of dilation * (kernel - 1) per dimension.
            total_h, total_w = (
                d * (k - 1) for d, k in zip(conv.dilation, conv.kernel_size)
            )
            top = total_h // 2
            return (total_w, 0, top, total_h - top)
        raise ValueError(f"unsupported Conv2d padding: {padding!r}")
    # Symmetric integer padding: the width pad of both sides goes to the left.
    return (2 * padding[1], 0, padding[0], padding[0])


def adjust_model(model):
    """Make every ``nn.Conv2d`` inside ``model`` causal along the width axis.

    The module tree is walked recursively. Each ``nn.Conv2d`` child is
    replaced, in place, by ``nn.Sequential(nn.ZeroPad2d(...), conv)``: the
    convolution's own padding is set to ``(0, 0)`` and an explicit zero-pad
    layer in front of it supplies the padding instead, with all of the width
    padding on the left side. The original convolution object is reused, so
    its weights are kept; it becomes element ``1`` of the new ``Sequential``.

    Args:
        model: The ``nn.Module`` whose children are rewritten in place.

    Returns:
        None. ``model`` is modified in place.

    Raises:
        ValueError: If a convolution has a string padding other than
            ``"same"`` or ``"valid"``.

    NOTE(review): only zero padding is considered. A ``Conv2d`` with a
    ``padding_mode`` other than ``'zeros'`` appears to pad from a cached
    internal attribute rather than ``padding`` -- confirm before using this
    on such layers.
    """
    # Snapshot the children: entries are replaced while we walk them.
    for child_name, child in list(model.named_children()):
        if isinstance(child, nn.Conv2d):
            causal_padding = _causal_pad_amounts(child)
            child.padding = (0, 0)
            setattr(
                model,
                child_name,
                nn.Sequential(nn.ZeroPad2d(causal_padding), child),
            )
        else:
            adjust_model(child)