Hello.
I am writing code to train a neural network. Training on a small dataset works fine, but with a large dataset training fails with "RuntimeError: CUDA error: out of memory".
In my understanding, GPU memory usage should not depend on the dataset size, since PyTorch loads only the samples for the current iteration, selected by index, rather than the whole dataset. So why is this happening?
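To illustrate what I mean, here is a minimal sketch of the kind of lazy loading I have in mind (the LazyNpyDataset name and the one-.npy-file-per-sample layout are made up for this example): the DataLoader calls __getitem__ only for the indices in the current batch, so memory per iteration should stay constant no matter how many files are on disk.

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class LazyNpyDataset(Dataset):
    """Hypothetical dataset: one .npy file per sample, read on demand."""
    def __init__(self, paths):
        self.paths = paths  # only the file paths live in memory

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # a file is read from disk only when its index is requested
        return torch.from_numpy(np.load(self.paths[idx]))

# Each iteration materializes only batch_size samples, regardless of
# how many paths the dataset holds:
# loader = DataLoader(LazyNpyDataset(paths), batch_size=4, num_workers=2)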
Here is my actual dataloader code:
import os  # used by os.listdir below (may also come in via the star import)
import re  # used by natural_keys below
import torch.utils.data
from preprocessing import *
from torchvision import transforms
from config import config
import torch.nn.functional as F
from scipy import ndimage
import numpy as np
from os.path import join
from PIL import Image
from plyfile import PlyData

pre = Prepro()
class PRDataset(torch.utils.data.Dataset):
    def __init__(self, tem_dir, ref_dir, visual=False, transform=None):
        self.ref_dir = ref_dir
        self.tem_dir = tem_dir
        self.transform = transform
        self.ref_names = os.listdir(ref_dir)
        self.ref_names.sort(key=self.natural_keys)
        self.tem_names = os.listdir(tem_dir)
        # pair every reference with the same (first) template file
        self.tem_names = [self.tem_names[0] for i in range(len(self.ref_names))]
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        self.trans_in = transforms.Compose([
            transforms.ToTensor(),
            self.normalize])
        self.trans_out = transforms.Compose([
            transforms.ToTensor()])
        self.visual = visual

    def atoi(self, text):
        return int(text) if text.isdigit() else text

    def natural_keys(self, text):
        '''
        alist.sort(key=natural_keys) sorts in human order
        http://nedbatchelder.com/blog/200712/human_sorting.html
        (See Toothy's implementation in the comments)
        '''
        return [self.atoi(c) for c in re.split(r'(\d+)', text)]

    def read_ply(self, file_path):
        plydata = PlyData.read(file_path).elements[0].data
        return np.array(plydata.tolist())[:, 0:3]  # keep only x, y, z

    def __len__(self):
        return len(self.ref_names)

    def __getitem__(self, idx):
        tem_name = '{}'.format(self.tem_names[idx])
        tem_fullname = join(self.tem_dir, tem_name)
        ref_name = '{}'.format(self.ref_names[idx])
        ref_fullname = join(self.ref_dir, ref_name)
        tem = self.read_ply(tem_fullname)
        ref = self.read_ply(ref_fullname)
        if self.visual:
            return [tem, ref, ref_name]
        else:
            return [tem, ref]
And here is the training code:
import os
import time
import torch.utils.data
import torch.nn as nn  # explicit imports, in case the star imports
from torch.autograd import Variable  # below do not re-export them
from model import *
from save_ply import *
from dataloader import *
from datasets import *
import utils
from loss_functions import Loss_Funcs
from config import config
from tensorboardX import SummaryWriter

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
use_gpu = torch.cuda.is_available()
prepro = Prepro()
save = Save()
lossf = Loss_Funcs()
path_dic = {
    "t_T": '../../dataset/surface/train/template/',
    "t_V": '../../dataset/surface/val/template/',
    "r_T": '../../dataset/surface/train/reference/',
    "r_V": '../../dataset/surface/val/reference/'
}
"""
Data distribution
"""
train = PRDataset(path_dic["t_T"],path_dic["r_T"], transform=True)
val = PRDataset(path_dic["t_V"],path_dic["r_V"],transform=True)
train_loader = torch.utils.data.DataLoader(train, batch_size=config.batch, shuffle=True,num_workers=4)
valid_loader = torch.utils.data.DataLoader(val, batch_size=config.batch, shuffle=True,num_workers=4)
dloaders={"train":train_loader,"valid":valid_loader}
train_num = train.__len__()
val_num = val.__len__()
print
print '<loaded data>'
print 'all data :',train_num+val_num
print 'train data :',train_num
print 'validation data :',val_num
def G(G_models, tmp, ref):
    n_pt = tmp.shape[1]
    out1_1 = G_models[0](torch.transpose(tmp, 1, 2))
    out1_2 = G_models[1](out1_1)
    out1_2 = torch.max(out1_2, 2, keepdim=True)[0]
    out2_1 = G_models[2](torch.transpose(ref, 1, 2))
    out2_2 = G_models[3](out2_1)
    out2_2 = torch.max(out2_2, 2, keepdim=True)[0]
    out1_2 = out1_2.squeeze().repeat(n_pt, 1)
    out2_2 = out2_2.squeeze().repeat(n_pt, 1)
    out1_1 = torch.transpose(out1_1, 1, 2).squeeze()
    out_con = torch.transpose(torch.cat((out1_1, out1_2, out2_2), 1), 0, 1)
    out3_1 = torch.transpose(G_models[4](out_con.view(1, -1, n_pt)), 1, 2)
    return tmp + out3_1

def D(D_models, input):
    out_d1 = D_models[0](torch.transpose(input, 1, 2))
    out_d2 = D_models[1](out_d1)
    out_d2 = torch.transpose(out_d2, 0, 1)
    out_d2 = torch.max(out_d2, 2, keepdim=True)[0]
    d = D_models[2](out_d2.view(1, -1)).squeeze()
    return out_d2, d
"""
Training
"""
def train_model(criterionGAN,G_models,D_models, dataloaders,num_epoch,G_opt,D_opt):
print
print '<training start>'
writer_grad = SummaryWriter('./log/epoch_'+str(num_epoch)+'/grad')
for epoch in range(num_epoch):
since = time.time()
running_train_Gloss = 0.0
running_train_p_loss = 0.0
running_train_Dloss = 0.0
running_train_G_totalloss = 0.0
running_valid_G_totalloss=0.0
running_valid_Gloss = 0.0
running_valid_p_loss = 0.0
running_valid_Dloss = 0.0
#for phase in ["train", "valid"]:
for phase in ["train"]:
if phase == "train":
for model in G_models: model.train(True)
for model in D_models: model.train(True)
else:
for model in G_models: model.train(False)
for model in D_models: model.train(False)
for i, (tmp,ref) in enumerate(dataloaders[phase]):
if use_gpu:
tmp = Variable(tmp.cuda()).float()
ref = Variable(ref.cuda()).float()
one = Variable(torch.ones(tmp.shape[0]), requires_grad=False).cuda()
zero = Variable(torch.zeros(tmp.shape[0]), requires_grad=False).cuda()
else:
tmp = Variable(tmp).float()
ref = Variable(ref).float()
one = Variable(torch.ones(tmp.shape[0]), requires_grad=False)
zero = Variable(torch.zeros(tmp.shape[0]), requires_grad=False)
############################
# (1) Update D network
###########################
for opt in D_opt: opt.zero_grad()
#for opt in G_opt: opt.zero_grad()
G_out = G(G_models,tmp,ref)
_,d_fake = D(D_models,G_out)
D_fake_loss = criterionGAN(d_fake,zero)
_,d_real = D(D_models,ref)
D_real_loss = criterionGAN(d_real,one)
D_adv_loss = (D_fake_loss + D_real_loss) * 0.5
if phase == "train":
D_adv_loss.backward(retain_graph=True)
#for p in D_models[1].parameters():
#print p.grad
for opt in D_opt: opt.step()
############################
# Update G network
###########################
#for opt in D_opt: opt.zero_grad()
for opt in G_opt: opt.zero_grad()
feature_fake, d_fake2= D(D_models,G_out)
feature_real, _ = D(D_models,ref)
G_adv_loss = criterionGAN(d_fake2, one)
p_loss = config.p_weight*torch.norm(feature_fake - feature_real, p=2)
G_loss = G_adv_loss + p_loss
if phase == "train":
G_loss.backward(retain_graph=True)
for opt in G_opt: opt.step()
if phase=="train":
# accumulate the loss
running_train_Gloss += G_adv_loss.item()
running_train_p_loss += p_loss.item()
running_train_G_totalloss += G_adv_loss.item() + p_loss.item()
running_train_Dloss += D_adv_loss
else:
running_valid_Gloss += G_adv_loss.item()
running_valid_p_loss += p_loss.item()
running_valid_G_totalloss += G_adv_loss.item() + p_loss.item()
running_valid_Dloss += D_adv_loss
if epoch!= 0:
epoch_time=time.time()-since
print "Time:{:f} Epoch [{}/{}] Train G_Total: {:.4f} (G_p: {:.4f} G_adv: {:.4f}) D : {:.4f} |Val G_Total: {:.4f} (G_p: {:.4f} G_adv: {:.4f}) D : {:.4f}".format(
epoch_time,
epoch,
num_epoch,
float(running_train_G_totalloss)/train_num,#8,
float(running_train_p_loss) / train_num, # 8,
float(running_train_Gloss) / train_num, # 8,
float(running_train_Dloss) / train_num, # 8,
float(running_valid_G_totalloss)/val_num,#8,
float(running_valid_p_loss) / val_num, # 8,
float(running_train_Gloss) / val_num, # 8,
float(running_valid_Dloss) / val_num, # 8,
)
if epoch%10==0:
# save trained model
for i,model in enumerate(G_models): torch.save(model.state_dict(), "trained_model/TMP_net"+str(i+1)+"_" + str(epoch) + ".pkl")
for i,model in enumerate(D_models): torch.save(model.state_dict(), "trained_model/TMP_net"+str(i+6)+"_" + str(epoch) + ".pkl")
writer_grad.close()
print
print('<Finished Training>')
return
if config.use_gpu:
    TMP_net1 = TMP_net1().cuda()
    TMP_net2 = TMP_net2().cuda()
    TMP_net3 = TMP_net3().cuda()
    TMP_net4 = TMP_net4().cuda()
    TMP_net5 = TMP_net5().cuda()
    TMP_net6 = TMP_net6().cuda()
    TMP_net7 = TMP_net7().cuda()
    TMP_net8 = TMP_net8().cuda()
else:
    TMP_net1 = TMP_net1()
    TMP_net2 = TMP_net2()
    TMP_net3 = TMP_net3()
    TMP_net4 = TMP_net4()
    TMP_net5 = TMP_net5()
    TMP_net6 = TMP_net6()
    TMP_net7 = TMP_net7()
    TMP_net8 = TMP_net8()
lr = 0.0001
TMP_net1_optimizer = torch.optim.Adam(TMP_net1.parameters(), lr=lr)
TMP_net2_optimizer = torch.optim.Adam(TMP_net2.parameters(), lr=lr)
TMP_net3_optimizer = torch.optim.Adam(TMP_net3.parameters(), lr=lr)
TMP_net4_optimizer = torch.optim.Adam(TMP_net4.parameters(), lr=lr)
TMP_net5_optimizer = torch.optim.Adam(TMP_net5.parameters(), lr=lr)
TMP_net6_optimizer = torch.optim.Adam(TMP_net6.parameters(), lr=lr)
TMP_net7_optimizer = torch.optim.Adam(TMP_net7.parameters(), lr=lr)
TMP_net8_optimizer = torch.optim.Adam(TMP_net8.parameters(), lr=lr)
G_models = [TMP_net1,TMP_net2,TMP_net3,TMP_net4,TMP_net5]
D_models = [TMP_net6,TMP_net7,TMP_net8]
G_opt = [TMP_net1_optimizer,TMP_net2_optimizer,TMP_net3_optimizer,TMP_net4_optimizer,TMP_net5_optimizer]
D_opt = [TMP_net6_optimizer,TMP_net7_optimizer,TMP_net8_optimizer]
utils.count_model_params(G_models)
utils.count_model_params(D_models)
criterionGAN = nn.BCELoss()
criterionMSE = nn.MSELoss()
start_time=time.time()
model = train_model(criterionGAN, G_models, D_models, dloaders, num_epoch=config.num_epoch, G_opt=G_opt, D_opt=D_opt)