A single GPU cannot hold all of the parameters. I have 4 GPUs, each with 12 GB of memory. How should I write the code to solve this problem?
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.models as models
# Rebuild the fine-tuned VGG16: the stock torchvision architecture with its
# classifier head replaced by a smaller 1024-1024-30 MLP, so that the layer
# layout matches the checkpoint loaded below.
original_model = models.vgg16(pretrained=False)
original_model.classifier = nn.Sequential(
    nn.Linear(512 * 7 * 7, 1024),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(1024, 30),
)
# map_location='cpu' keeps the deserialized tensors on the host. Without it
# torch.load restores every tensor to the device it was saved from, which
# allocates GPU memory just for the temporary checkpoint copy and fails on a
# machine where that device is not available.
original_model.load_state_dict(
    torch.load('/home/zj/ST/vgg16_finetune.pth', map_location='cpu')
)
class ST_LSTM(nn.Module):
    """Recurrent spatial-transformer classifier on top of frozen VGG16 features.

    The convolutional feature map is fed to a single-layer LSTM for
    ``NUM_STEPS`` steps. After each step the LSTM output is classified by
    ``classifier2``; after every step but the last, ``mm`` turns the same
    output into a 2x3 affine matrix that is used to resample the feature map
    (see ``stn``) to produce the LSTM input of the next step.

    The feature extractor is taken from the module-level ``original_model``,
    so that object must exist before this class is instantiated.

    Args:
        num_classes (int): width of the score vector produced at each step.
    """

    # Flattened size of the feature map fed to the LSTM.
    # Assumes 224x224 inputs, giving a 512 x 14 x 14 map -- TODO confirm.
    FEATURE_DIM = 14 * 14 * 512
    HIDDEN_DIM = 1024
    # Number of recurrent steps, i.e. number of score tensors returned.
    NUM_STEPS = 3

    def __init__(self, num_classes):
        super(ST_LSTM, self).__init__()

        # Every layer of the VGG16 feature stack except the last one
        # (in torchvision's VGG16 that is the final max-pool), so the
        # spatial resolution stays at 14x14 instead of 7x7.
        self.features = nn.Sequential(
            *list(original_model.features.children())[:-1]
        )
        # Freeze the backbone. The layers are shared with original_model,
        # not copied, so its parameters are frozen as well.
        for param in self.features.parameters():
            param.requires_grad = False

        # Per-step classification head.
        self.classifier2 = nn.Sequential(
            nn.Linear(self.HIDDEN_DIM, 1024),
            nn.ReLU(True),
            nn.Dropout(p=0.8),
            # Was hard-coded to 30, which silently ignored num_classes.
            nn.Linear(1024, num_classes),
        )
        # Localisation head: predicts the 6 entries of a 2x3 affine matrix.
        self.mm = nn.Sequential(
            nn.Linear(self.HIDDEN_DIM, 512),
            nn.ReLU(True),
            nn.Dropout(p=0.8),
            nn.Linear(512, 6),
        )
        # NOTE: the input-to-hidden weight alone has
        # 4 * HIDDEN_DIM * FEATURE_DIM (about 411M) entries, which dominates
        # the memory footprint of this model.
        self.lstm = nn.LSTM(self.FEATURE_DIM, self.HIDDEN_DIM,
                            batch_first=True)

    def stn(self, x, theta):
        """Resample ``x`` with the affine transform ``theta``.

        Same sampling step as the Spatial Transformer Networks PyTorch
        tutorial.

        Args:
            x: feature map to resample.
            theta: batch of 2x3 affine matrices, one per element of ``x``.

        Returns:
            The resampled feature map, same size as ``x``.
        """
        grid = F.affine_grid(theta, x.size())
        return F.grid_sample(x, grid)

    def forward(self, x):
        """Run ``NUM_STEPS`` attend-and-classify steps.

        Args:
            x: batch of input images.

        Returns:
            A tuple of ``NUM_STEPS`` score tensors, one per recurrent step,
            each produced by ``classifier2``.
        """
        feature_map = self.features(x)
        # Pin the batch dimension explicitly: with view(-1, ...) a feature
        # map of unexpected size could be silently folded into a different
        # batch size instead of raising a shape error at the LSTM.
        batch_size = feature_map.size(0)

        # Re-compact the LSTM weights (needed after the module has been
        # replicated, e.g. by nn.DataParallel); once per forward is enough.
        self.lstm.flatten_parameters()

        lstm_input = feature_map
        state = None  # None lets the LSTM start from a zero state
        scores = []
        for step in range(self.NUM_STEPS):
            sequence = lstm_input.view(batch_size, 1, -1)
            output, state = self.lstm(sequence, state)
            last_output = output[:, -1, :]
            scores.append(self.classifier2(last_output))

            # The final step has no successor, so no transform is predicted.
            if step < self.NUM_STEPS - 1:
                theta = self.mm(last_output).view(-1, 2, 3)
                # Always resample the original feature map, not the
                # previously transformed one.
                lstm_input = self.stn(feature_map, theta)

        return tuple(scores)
def st_lstm(num_classes):
    """Build an ``ST_LSTM`` model.

    Args:
        num_classes: forwarded unchanged to the ``ST_LSTM`` constructor.

    Returns:
        The newly constructed ``ST_LSTM`` instance.
    """
    return ST_LSTM(num_classes=num_classes)