I am new to PyTorch.
class RN50_End(torch.nn.Module):
    """Final stage of a network: pooling followed by the classifier layer.

    The module reuses (does not copy) the ``avgpool`` and ``fc`` sub-modules
    of the model passed to the constructor, so their parameters stay shared
    with that model.
    """

    def __init__(self, rn50: torch.nn.Module) -> None:
        """Attach the ``avgpool`` and ``fc`` sub-modules of ``rn50``.

        Args:
            rn50: Any object exposing ``avgpool`` and ``fc`` attributes
                (presumably a torchvision ResNet-50 -- confirm with caller).
        """
        # Must be the dunder ``__init__``: a method merely named ``init`` is
        # never invoked on construction, so ``RN50_End(rn50)`` would fall
        # through to ``nn.Module.__init__`` and the layers would be missing.
        super().__init__()
        self.avgpool = rn50.avgpool
        self.fc = rn50.fc

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``avgpool``, flatten everything after dim 0, then ``fc``.

        Args:
            x: Input tensor accepted by ``self.avgpool``.

        Returns:
            The output of ``self.fc`` on the pooled, flattened input.
        """
        x = self.avgpool(x)
        # Keep dim 0 (the batch axis) and merge all remaining dims so the
        # result is 2-D before it reaches ``fc``.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
# Random float32 tensor of shape (112, 2048, 7, 7). It is not used by the
# statements below; it may be consumed elsewhere in the file.
end_x = torch.rand((112, 2048, 7, 7), dtype=torch.float32).contiguous()

# Wrap the avgpool/fc tail of the existing ``rn50`` model.
r50end = RN50_End(rn50)

# Run the tail module over every batch produced by ``cal_loader``; each item
# is unpacked as an (images, target) pair and only ``images`` is fed forward.
# NOTE(review): the tail starts at avgpool, so it presumably expects feature
# maps shaped like ``end_x`` rather than raw images -- confirm what
# ``cal_loader`` actually yields.
for i, (images, target) in enumerate(cal_loader):
    output = r50end(images)