I ran your code and it prints 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 (one number per forward call), so the counter does increment.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv2d(nn.Conv2d):
    """2-D convolution that counts how many times ``forward`` has run.

    Behaves like ``nn.Conv2d`` for the arguments it accepts, and additionally
    keeps a plain Python integer, ``count``, that is printed and then
    incremented on every forward call.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        """Create the layer and start the call counter at zero.

        All arguments are forwarded positionally, unchanged, to
        ``nn.Conv2d.__init__``.
        """
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias)
        # Ordinary int attribute on the instance; it holds the number of
        # forward calls made so far.
        self.count = 0

    def forward(self, x):
        """Print the current call count, increment it, and convolve ``x``.

        Args:
            x: Input passed straight to ``F.conv2d``.

        Returns:
            The result of ``F.conv2d`` using this layer's weight, bias,
            stride, padding, dilation and groups.
        """
        # Prints the number of *previous* calls on this instance:
        # 0 on the first call, then 1, 2, ...
        print(self.count)
        self.count += 1
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
# Demo: run the counting layer ten times on fresh random input. Each call
# prints the layer's current count before incrementing it, so the output
# is 0 through 9.
net = Conv2d(3, 3, 3)
for _ in range(10):
    x = torch.rand(3, 3, 3, 3)
    net(x)