I use `torch.utils.data.DataLoader` to iterate through the training set, and the program freezes after it loads one batch of data. Here is my code for the Dataset class:
class DeconvDataSet(Dataset):
    """Paired ground-truth / training samples stored as numbered ``.npy`` files.

    Sample ``idx`` is read from ``gt<start + idx>.npy`` in ``gt_dir`` and
    ``tr<start + idx>.npy`` in ``tr_dir``. The dataset therefore exposes the
    contiguous file range ``[start, start + length)``; the directories may
    contain more files than that.

    Args:
        gt_dir: Directory holding the ground-truth ``gt<N>.npy`` files.
        tr_dir: Directory holding the training-input ``tr<N>.npy`` files.
        start: File number that corresponds to sample 0.
        length: Number of samples exposed by the dataset.
    """

    def __init__(self, gt_dir, tr_dir, start, length):
        self.gt_dir = gt_dir  # directory for ground truth
        self.tr_dir = tr_dir  # directory for training inputs
        self.start = start  # file number of the first sample used
        self.length = length  # number of samples in the dataset

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{'gt': Tensor, 'tr': Tensor}``.

        The tensors are created with ``torch.from_numpy``, so they keep the
        dtype stored in the ``.npy`` files.

        Raises:
            IndexError: if ``idx`` is outside ``[0, length)``. Without this
                check, an out-of-range index silently read a file outside the
                configured window (or failed with FileNotFoundError), and
                plain ``for`` iteration over the dataset could not stop
                cleanly.
        """
        if not 0 <= idx < self.length:
            raise IndexError(
                'index %d out of range for dataset of length %d'
                % (idx, self.length))
        file_no = idx + self.start
        gtFile_path = os.path.join(self.gt_dir, 'gt%d.npy' % file_no)
        trFile_path = os.path.join(self.tr_dir, 'tr%d.npy' % file_no)
        gt = np.load(gtFile_path)  # each sample is a 3D numpy array
        tr = np.load(trFile_path)
        # Debug output showing which sample was loaded and its shape.
        print('load%d' % idx)
        print(tr.shape)
        return {'gt': torch.from_numpy(gt), 'tr': torch.from_numpy(tr)}
This is where I use the DataLoader:
# DataLoader worker processes (num_workers > 0) hand finished batches to the
# main process through shared memory (/dev/shm). Docker limits /dev/shm to
# 64 MB by default. One batch here is 8 samples x 2 arrays x (30, 200, 100)
# elements, which is about 77 MB if the .npy files are float64 (dtype not
# confirmed; TODO check). That would not fit, and it matches the observed hang
# right after the first batch is loaded. With num_workers=0 the data is loaded
# in the main process and shared memory is not used. To keep background
# workers, start the container with a larger --shm-size (or --ipc=host) and
# raise NUM_WORKERS.
NUM_WORKERS = 0
dataLoader = DataLoader(dataSet, batch_size=8, shuffle=True,
                        num_workers=NUM_WORKERS)
print('start training...')
for epoch in range(2):
    print('start epoch %d' % epoch)
    for i_batch, sample in enumerate(dataLoader):
        print('read the data')
        # Named `inputs` rather than `input` so the builtin is not shadowed.
        # unsqueeze(1) inserts a singleton channel dimension after the batch.
        inputs = sample['tr'].type(torch.FloatTensor).unsqueeze(1)
        target = sample['gt'].type(torch.FloatTensor).unsqueeze(1)
        if torch.cuda.is_available():
            inputs, target = inputs.cuda(), target.cuda()
        inputs, target = Variable(inputs), Variable(target)
        # feed data into the net
        optimizer.zero_grad()
        print('put the data into net')
        output = net(inputs)
        loss = criterion(output, target)
        # Scaled by 1000, so the value printed below is 1000x the criterion
        # output, not the raw loss.
        loss = loss * 1000
        print('back propagate')
        loss.backward()
        optimizer.step()
        print('iter %d, mse %.3f' % (i_batch, loss))
Finally, I get output like this:
loading dataset
start training…
start epoch 0
load656
(30, 200, 100)
load156
(30, 200, 100)
load800
(30, 200, 100)
load847
(30, 200, 100)
load807
(30, 200, 100)
load299
(30, 200, 100)
load415
(30, 200, 100)
load33
(30, 200, 100)
and then it hangs here.
By the way, I run this inside a Docker container. When I run `ps aux`, the process shows very high virtual memory usage:
USER PID %CPU %MEM VSZ RSS TTY STAT START TIME COMMAND
zhoutk 39451 1.7 0.9 75496188 1241632 pts/0 Sl+ 17:33 0:08 python deconv.py