I want to extract intermediate VGG19 features to compute a perceptual loss, but I get an out-of-memory (OOM) error. Below is my code:
import torch
from torchvision.models import vgg19

class Vgg19(torch.nn.Module):
    def __init__(self):
        super(Vgg19, self).__init__()
        features = list(vgg19(pretrained=True).features)
        self.features = torch.nn.ModuleList(features).eval()

    def forward(self, x):
        # Collect the outputs of relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1
        results = []
        for ii, model in enumerate(self.features):
            x = model(x)
            if ii in {1, 6, 11, 20, 29}:
                results.append(x)
        return results

vgg = Vgg19().to(torch.device("cuda"))
features_list = vgg(x)  # x: a batch of input images
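For comparison, here is a frozen, truncated variant of the extractor. It is a sketch under assumptions, not a confirmed fix for the OOM: the class name `FeatureExtractor`, the helper `perceptual_loss` and the use of MSE between feature maps are hypothetical choices. It accepts any `nn.Sequential` (for example `vgg19(pretrained=True).features`), drops every layer after the deepest requested index, freezes the weights so no gradient buffers are allocated for them, and computes the target's features without autograd.

```python
import torch
import torch.nn as nn


class FeatureExtractor(nn.Module):
    """Hypothetical frozen feature extractor for a perceptual loss (sketch)."""

    def __init__(self, features: nn.Sequential, layer_ids=(1, 6, 11, 20, 29)):
        super().__init__()
        self.layer_ids = set(layer_ids)
        last = max(layer_ids)
        # Layers after the deepest requested index never feed the loss,
        # so keep only features[0 .. last].
        self.features = nn.Sequential(*list(features)[: last + 1]).eval()
        # The loss network is fixed: no .grad buffers for its weights.
        for p in self.features.parameters():
            p.requires_grad_(False)

    def forward(self, x):
        results = []
        for ii, layer in enumerate(self.features):
            x = layer(x)
            if ii in self.layer_ids:
                results.append(x)
        return results


def perceptual_loss(extractor, output, target):
    # The target needs no gradient; no_grad ensures no graph is recorded for it.
    with torch.no_grad():
        target_feats = extractor(target)
    output_feats = extractor(output)
    # Assumption: plain MSE per layer, summed with equal weights.
    return sum(
        nn.functional.mse_loss(o, t) for o, t in zip(output_feats, target_feats)
    )
```

Usage would be along the lines of `extractor = FeatureExtractor(vgg19(pretrained=True).features).to(device)` followed by `loss = perceptual_loss(extractor, generated, real)`, where `generated` and `real` stand for the network output and the ground-truth images.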
When I extract the intermediate features with my `Vgg19` module, I get an OOM error. Why?
I train this model on 4 NVIDIA V100 (16 GB) GPUs using DistributedDataParallel.
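One thing worth checking under DDP: `torch.device("cuda")` refers to the current CUDA device, which is `cuda:0` unless `torch.cuda.set_device` has been called, so if every process runs `Vgg19().to(torch.device("cuda"))`, all four ranks may put VGG19 and its activations on the first card. A minimal sketch, assuming the job is launched with `torchrun` (which sets the `LOCAL_RANK` environment variable); `ddp_device` is a hypothetical helper name:

```python
import os

import torch


def ddp_device() -> torch.device:
    """Return the device this DDP process should use (hypothetical helper).

    Assumes a torchrun launch, which sets LOCAL_RANK for each process.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        # Make this rank's GPU the current device, so plain "cuda"
        # no longer resolves to cuda:0 in every process.
        torch.cuda.set_device(local_rank)
        return torch.device("cuda", local_rank)
    # Fallback for CPU-only environments.
    return torch.device("cpu")
```

With this helper, the model would be created as `vgg = Vgg19().to(ddp_device())`. Since the loss network has no trainable parameters, it does not need to be wrapped in `DistributedDataParallel` itself.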