I am using an FX graph to visualize the operation steps. For the "x = self.bn1(x)" operation, the FX graph shows that this single call is decomposed into multiple aten operations (snippet below). How can I find where these intermediate operations for BatchNorm are generated? I tried to debug the code but could not trace it.
# File: miniconda3/envs/pytorch/lib/python3.11/site-packages/torchvision/models/resnet.py:269, code: x = self.bn1(x)
add: "i64[]" = torch.ops.aten.add.Tensor(primals_164, 1)
var_mean = torch.ops.aten.var_mean.correction(convolution, [0, 2, 3], correction = 0, keepdim = True)
getitem: "f32[1, 64, 1, 1]" = var_mean[0]
getitem_1: "f32[1, 64, 1, 1]" = var_mean[1]; var_mean = None
add_1: "f32[1, 64, 1, 1]" = torch.ops.aten.add.Tensor(getitem, 1e-05)
rsqrt: "f32[1, 64, 1, 1]" = torch.ops.aten.rsqrt.default(add_1); add_1 = None
sub: "f32[1, 64, 112, 112]" = torch.ops.aten.sub.Tensor(convolution, getitem_1)
mul: "f32[1, 64, 112, 112]" = torch.ops.aten.mul.Tensor(sub, rsqrt); sub = None
squeeze: "f32[64]" = torch.ops.aten.squeeze.dims(getitem_1, [0, 2, 3]); getitem_1 = None
squeeze_1: "f32[64]" = torch.ops.aten.squeeze.dims(rsqrt, [0, 2, 3]); rsqrt = None
mul_1: "f32[64]" = torch.ops.aten.mul.Tensor(squeeze, 0.1)
mul_2: "f32[64]" = torch.ops.aten.mul.Tensor(primals_162, 0.9)
add_2: "f32[64]" = torch.ops.aten.add.Tensor(mul_1, mul_2); mul_1 = mul_2 = None
squeeze_2: "f32[64]" = torch.ops.aten.squeeze.dims(getitem, [0, 2, 3]); getitem = None
mul_3: "f32[64]" = torch.ops.aten.mul.Tensor(squeeze_2, 1.0000797257434426); squeeze_2 = None
mul_4: "f32[64]" = torch.ops.aten.mul.Tensor(mul_3, 0.1); mul_3 = None
mul_5: "f32[64]" = torch.ops.aten.mul.Tensor(primals_163, 0.9)
add_3: "f32[64]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
unsqueeze: "f32[64, 1]" = torch.ops.aten.unsqueeze.default(primals_2, -1)
unsqueeze_1: "f32[64, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, -1); unsqueeze = None
mul_6: "f32[1, 64, 112, 112]" = torch.ops.aten.mul.Tensor(mul, unsqueeze_1); mul = unsqueeze_1 = None
unsqueeze_2: "f32[64, 1]" = torch.ops.aten.unsqueeze.default(primals_3, -1); primals_3 = None
unsqueeze_3: "f32[64, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, -1); unsqueeze_2 = None
add_4: "f32[1, 64, 112, 112]" = torch.ops.aten.add.Tensor(mul_6, unsqueeze_3); mul_6 = unsqueeze_3 = None
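For reference, the graph above is the standard training-mode BatchNorm arithmetic: biased batch statistics via var_mean, normalization with rsqrt(var + eps), a momentum-0.1 running-stat update (the constant 1.0000797257434426 is n/(n-1) for n = 1*112*112 = 12544 elements per channel, converting biased to unbiased variance), and finally the affine gamma/beta transform. A minimal NumPy sketch replicating that arithmetic, with illustrative names (x, weight, bias, running_mean, running_var stand in for the primals_* placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, H, W = 1, 64, 112, 112
x = rng.standard_normal((N, C, H, W)).astype(np.float32)  # convolution output
weight = np.ones(C, dtype=np.float32)         # gamma (primals_2)
bias = np.zeros(C, dtype=np.float32)          # beta  (primals_3)
running_mean = np.zeros(C, dtype=np.float32)  # primals_162
running_var = np.ones(C, dtype=np.float32)    # primals_163
eps, momentum = 1e-5, 0.1

# var_mean.correction(..., correction=0): biased per-channel batch statistics
mean = x.mean(axis=(0, 2, 3), keepdims=True)  # getitem_1
var = x.var(axis=(0, 2, 3), keepdims=True)    # getitem (ddof=0, i.e. biased)

# rsqrt(var + eps), then normalize: (x - mean) * inv_std
inv_std = 1.0 / np.sqrt(var + eps)            # add_1, rsqrt
x_hat = (x - mean) * inv_std                  # sub, mul

# Running-stat update: blend new stats in with momentum 0.1; the variance is
# first rescaled by n/(n-1) to make it an unbiased estimate (mul_3).
n = N * H * W
new_running_mean = momentum * mean.squeeze() + (1 - momentum) * running_mean
new_running_var = momentum * (var.squeeze() * n / (n - 1)) + (1 - momentum) * running_var

# Affine transform: y = x_hat * gamma + beta (the unsqueezes in the graph
# just reshape gamma/beta to [64, 1, 1] so they broadcast over H and W).
y = x_hat * weight[:, None, None] + bias[:, None, None]  # mul_6, add_4
```

Each line maps onto one or two nodes of the graph; the comments name the corresponding node.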
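As for where the decomposition comes from: one way to reproduce such a graph yourself (a sketch, assuming a recent PyTorch where make_fx and torch._decomp are available) is to trace with make_fx and a decomposition table; the table is what expands a composite op like batch_norm into the var_mean/rsqrt/mul primitives shown above:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import core_aten_decompositions

def fn(x, weight, bias):
    # training=True with running stats set to None keeps this purely
    # functional, so the batch statistics (var_mean) are computed in-graph.
    return torch.nn.functional.batch_norm(
        x, None, None, weight, bias, training=True, eps=1e-5)

gm = make_fx(fn, decomposition_table=core_aten_decompositions())(
    torch.randn(1, 64, 112, 112), torch.ones(64), torch.zeros(64))
print(gm.code)  # the decomposed aten primitives, like the snippet above
```

The decomposition bodies themselves live in torch/_decomp/decompositions.py, so breakpoints set there should hit during tracing.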