Do select and narrow return a view or a copy?

Hi,

I am extracting matrices (patches) from a tensor using select and narrow. I would like to know whether those operations return a view into the same underlying storage or a copy.

This question is related to a more complex process where:
1/ all patch positions are generated when my DataLoader is instantiated
2/ each patch is extracted from the tensor only when __getitem__ is called

Generating patch positions is not really heavy on memory, but extraction turns out to be so expensive that memory usage grows linearly and eventually blows past 32 GB of RAM.

I would like to understand what I am doing wrong. To give you some idea, here is the logic behind the code:

import torch.utils.data as data


class PatchExtractor(data.Dataset):
    def __init__(self, root, patch_size, transform=None,
                 target_transform=None, loader=None):
        self.root = root
        self.patch_size = patch_size
        self.transform = transform
        self.target_transform = target_transform
        # extract all patch positions (make_dataset and Loader are project-specific helpers)
        self.dataset = make_dataset(root, patch_size)
        if loader is None:
            self.loader = Loader()
        else:
            self.loader = loader

    def __getitem__(self, index):
        path, args, target = self.dataset[index]

        # img is a tensor returned by a succession of narrow and select
        img = self.loader.load(path, args, self.patch_size)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.dataset)

I don't think the mistake is in the DataLoader. self.loader.load returns precisely this kind of result:

self.images[path].select(0, position[0]) \
                 .narrow(0, y - border_width, patch_size) \
                 .narrow(1, z - border_width, patch_size)

In my opinion, it seems that every time a patch is loaded and transferred to CUDA, it is not freed after use. I don't store those values; I use the same train loop as in the ImageNet example. The fact that a large chunk of my memory is freed right after an epoch ends (i.e. when testing starts) leads me to that observation.

def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(async=True)
        input_25 = torch.autograd.Variable(input[0]).cuda()
        input_51 = torch.autograd.Variable(input[1]).cuda()
        input_75 = torch.autograd.Variable(input[2]).cuda()

        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(patch25=input_25, patch51=input_51, patch75=input_75)
        loss = criterion(output, target_var)

        # debug loss value
        # print('raw loss is {loss.data[0]:.5f}\t'.format(loss=loss))

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input[0].size(0))
        top1.update(prec1[0], input[0].size(0))
        top5.update(prec5[0], input[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))

Should I del my variables right after backprop?

Thank you for the feedback


select, narrow and indexing operations (except when using a LongTensor index) return views onto the same memory. However, the transforms operate out-of-place, so after them you're going to get a copy. Likewise, cat-ing views needs to allocate a new tensor for the output. Of course the outputs of __getitem__ should go out of scope a moment later, so the memory pressure should stay constant.
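
A quick way to check this yourself (a minimal sketch, not code from this thread) is to compare storage pointers and write through the result:

import torch

base = torch.zeros(4, 8, 8)

# select/narrow return views: same underlying storage, no copy
patch = base.select(0, 1).narrow(0, 2, 3).narrow(1, 2, 3)
print(patch.storage().data_ptr() == base.storage().data_ptr())  # True

# writing through the view is visible in the original tensor
patch.fill_(1)
print(base[1, 2:5, 2:5].sum())  # 9.0

# indexing with a LongTensor, by contrast, returns a copy
idx = torch.LongTensor([0, 2])
rows = base[idx]
print(rows.storage().data_ptr() == base.storage().data_ptr())  # False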

Is it the CPU memory that's blowing up or the GPU? Are you caching anything inside the model? Did you try adding gc.collect() calls inside the loop (not recommended, but if that helps it means you have reference cycles somewhere)? del-ing the Variables shouldn't be necessary.


As a side note, it's faster to transfer the whole input to the GPU at once and create Variables from the slices afterwards. And unless target is in pinned memory, async=True is a no-op.
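
To illustrate the idea (a sketch with made-up sizes, not your actual code): do a single host-to-device copy of the whole batch tensor, then slice it on the GPU side.

import torch

# hypothetical batch tensor built on the CPU by the DataLoader
batch_cpu = torch.randn(4096, 1, 25, 25)

# one host-to-device copy for the whole batch...
batch_gpu = batch_cpu.cuda()

# ...then slice on the GPU; slices are views, so wrapping them in Variables
# does not copy the data again
first_half = torch.autograd.Variable(batch_gpu[:2048])
second_half = torch.autograd.Variable(batch_gpu[2048:])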


Also, if you're using DataLoader with many workers, you might want to use a lower number. Each worker probably has its own copy of the dataset in memory, and the patches it extracts accumulate in the queue, because extraction is a fast operation.


It's the CPU memory that is blowing up; the GPU memory stays stable.

I am not caching anything inside the model, only torch.nn ops. I will try gc.collect() and let you know. Thanks for the advice. I am not sure I understood the pinned memory feature; can it be related to my problem somehow?

To give you an idea, in test mode I use around 9 GB, while in training mode it's around 20 GB. The thing is that memory usage in test mode stays stable between 8 and 9 GB, but during training the RAM is slowly eaten, batch after batch, constantly growing until the next test step, where it drops back to 8-9 GB.

Does del-ing the Variables or gc.collect() help?

I'm back, @apaszke. No, del and gc.collect() did not help.

However, reducing the number of workers from 8 to 2 did cut the memory footprint in half, at the cost of increased data loading time. Still, even with fewer workers, memory usage keeps growing (more slowly) from batch to batch until it reaches a stable point. I guess you have implemented some kind of maximum size for the queue, and that limit is being hit. It is good to see the usage become constant, but also discouraging, because it shows a problem I have no idea how to solve. Reducing the batch size should also help, right, since the batch is prepared on the CPU side? I am currently at 4096 with 3 inputs of 75x75, 51x51 and 25x25.
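
For reference, a rough back-of-envelope for one CPU-side batch, assuming single-channel float32 patches:

# numbers from above: batch of 4096, patches of 75x75, 51x51 and 25x25
batch_size = 4096
bytes_per_float = 4
per_sample = (75 * 75 + 51 * 51 + 25 * 25) * bytes_per_float  # ~35 KB per sample
print(batch_size * per_sample / 1024.0 ** 2, 'MB per batch')  # roughly 138 MB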

Here is my current train loop:

for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(async=True)
        input_25 = torch.autograd.Variable(input[0]).cuda()
        input_51 = torch.autograd.Variable(input[1]).cuda()
        input_75 = torch.autograd.Variable(input[2]).cuda()

        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(patch25=input_25, patch51=input_51, patch75=input_75)
        loss = criterion(output, target_var)

        # debug loss value
        # print('raw loss is {loss.data[0]:.5f}\t'.format(loss=loss))

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input[0].size(0))
        top1.update(prec1[0], input[0].size(0))
        top5.update(prec5[0], input[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))

        del input_25
        del input_51
        del input_75
        del input
        gc.collect()

I have monitored that instantiating the DataLoader weighs around 4.2 GB in CPU memory; this is where all patch positions are extracted and stored. It is just a list of 15 million tuples.

Does it mean it will cost around 4.2 GB x 4 if I have 4 processes building the batches?

Yes, exactly. We have an upper bound on the queue size, so that if the workers are fast they won't fill up the memory, and it scales linearly with the number of workers. If you find that 1 or 2 workers saturate the queue, there's no point in using more of them. Also, as I said, remember that if you're using an in-memory dataset, each worker is likely to keep its own copy, so the memory usage will be quite high. You could try to load the images lazily or use some kind of in-memory database that all workers contact for the data.
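
To illustrate the lazy-loading idea (LazyPatchDataset and load_volume below are hypothetical names, not your code): each worker keeps only the cheap position index and opens a volume the first time a patch from it is requested.

import torch
import torch.utils.data as data


def load_volume(path):
    # hypothetical stand-in for the real file loader
    return torch.zeros(3, 256, 256)


class LazyPatchDataset(data.Dataset):
    """Illustrative only: keeps the cheap position index per worker and loads
    each volume lazily, the first time a patch from it is requested."""

    def __init__(self, entries, patch_size):
        # entries: list of (path, (plane, y, z), target) tuples
        self.entries = entries
        self.patch_size = patch_size
        self.cache = {}  # per-worker cache of opened volumes (keep it bounded in practice)

    def __getitem__(self, index):
        path, (plane, y, z), target = self.entries[index]
        if path not in self.cache:
            self.cache[path] = load_volume(path)
        vol = self.cache[path]
        patch = vol.select(0, plane) \
                   .narrow(0, y, self.patch_size) \
                   .narrow(1, z, self.patch_size)
        return patch, target

    def __len__(self):
        return len(self.entries)

# example usage: ds = LazyPatchDataset([('vol0.bin', (0, 10, 10), 1)], 75)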


Is there no way for each worker to have its own unique subset of the data, so that each one does not need a full copy of the original?
I mean splitting the dataset between the workers?

Also, what is the reason validation mode is so much more efficient on CPU memory than training? On the CPU side it is basically the same to me, I mean the data processing is the same, but validation only uses around 10 GB of RAM.

The only differences are:

# instead of model.train()
model.eval()

# using volatile
input_25 = torch.autograd.Variable(input[0], volatile=True).cuda()

No, there’s no easy way to do it. DataLoader isn’t meant to load from multiple splits, but from a single dataset. One possible solution would be to call .share_memory_() on the tensor that holds the images. This way, when you fork the workers, they should inherit it (unless something I don’t remember about stops them).
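
A minimal sketch of that idea (the tensor below is a placeholder, not your actual data):

import torch

# hypothetical: one big CPU tensor holding the image data
images = torch.zeros(100, 512, 512)
images.share_memory_()  # in-place: the underlying storage moves into shared memory

# a Dataset that closes over `images` and is used by a DataLoader with
# num_workers > 0 will have its workers inherit this shared storage on fork,
# instead of each worker ending up with its own private copy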

Can you show me how you instantiate both training and validation data loaders?

Yes, sure. I have focused my work on reducing the DataLoader footprint. I have found that using a list of ints was really, really expensive: Python ints are 28 bytes each. By using a numpy array I can reduce my RAM footprint from 1 GB to 80 MB. I will see if it scales well on a server.

    train_loader = torch.utils.data.DataLoader(
        medfolder.MedFolder(traindir, file_extensions, patch_size, label_map, file_map,
                  transform=Compose([
                       transforms.CenterSquareCrops([25, 51, 75]),
                       transforms.Unsqueeze(0)
                    ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        medfolder.MedFolder(valdir, file_extensions, patch_size, label_map, file_map,
                  transform=Compose([
                       transforms.CenterSquareCrops([25, 51, 75]),
                       transforms.Unsqueeze(0)
                    ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

The only difference is that the validation dataset is not shuffled. I will give you some updates on this tomorrow. I think I am nearing the end of this problem.

Really amazed by the community support and your work!! :slight_smile:
Thanks !

Ugh, nice find! That should probably do it. Note that PyTorch tensors also keep their data packed, so you can use them instead of numpy indices too.
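
To illustrate, a rough comparison of index-storage costs (a sketch; exact numbers depend on the Python build):

import sys
import numpy as np
import torch

n = 10 ** 6  # scale up to ~15 million for the numbers discussed above

positions_list = list(range(n))                       # one ~28-byte int object per entry,
cost_list = sys.getsizeof(positions_list) + n * 28    # plus an 8-byte pointer inside the list

positions_np = np.arange(n, dtype=np.int32)           # packed: 4 bytes per entry
positions_t = torch.IntTensor(n)                      # a torch tensor is packed the same way

print(cost_list // 2 ** 20, 'MB as a list of Python ints (approx.)')
print(positions_np.nbytes // 2 ** 20, 'MB as an int32 numpy array')
print(positions_t.numel() * 4 // 2 ** 20, 'MB as a torch.IntTensor')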

Just be careful with pinned memory when the system is already under high memory pressure. Pinned memory is never swapped out, so it can freeze or crash the system if there's too much of it.

Thanks! :slight_smile:


I found that the computation becomes slower and slower when training on a huge volume of data with CUDA. The batch size is 100 x 3 x 224 x 224. At the very beginning it takes 0.4 seconds to train each chunk; little by little it grows to nearly 3 seconds. In contrast, if I select a small group of data from the original dataset for training, the speed stays the same all along. Therefore, I wonder how to train a model on a large dataset. The main code is simplified as follows for readability.

# model inputs; criterion, optimizer, m (e.g. a softmax), image and Class are defined elsewhere
x = torch.autograd.Variable(torch.zeros(batchSize, 3, 2 * imgW, 2 * imgW)).cuda(0)
yt = torch.autograd.Variable(torch.LongTensor(batchSize).zero_()).cuda(0)
model = models.resnet18(pretrained=True).cuda(0)
for epoch in range(0, epochs):
    correct = 0

    for t in range(0, trainSize, batchSize):
        idx = 0
        # the batch of input x and target yt are assigned from local patches
        # of a large image here
        for i in range(t, maxSize):
            x[idx, :, :, :] = image[:, i - imgW:i + imgW, i - imgW:i + imgW]
            yt[idx] = torch.from_numpy(Class[i])
            idx = idx + 1
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(m(output), yt)
        pred = output.data.max(1)[1]
        correct += pred.eq(yt.data).sum()
        loss.backward()
        optimizer.step()

Thank you for your instruction beforehand.

You should never reuse Variables indefinitely like you do with x and yt. They keep track of all the operations you perform on them, including assignments! Every iteration of the for i in range(t, maxSize): loop makes the history of x and yt longer by one assignment, so the graphs get huge very quickly. A fix is to build the batch from plain tensors and create the Variables inside the loop, like this:


model = models.resnet18(pretrained=True).cuda(0)
for epoch in range(0, epochs):
    correct = 0

    for t in range(0, trainSize, batchSize):
        x_data = []
        y_data = []
        # collect the local patches of the large image as plain tensors
        for i in range(t, maxSize):
            x_data.append(image[:, i - imgW:i + imgW, i - imgW:i + imgW])
            y_data.append(torch.from_numpy(Class[i]))
        # fresh Variables are created for every batch, so their history stays short
        x = torch.autograd.Variable(torch.stack(x_data, 0).cuda(0))
        yt = torch.autograd.Variable(torch.cat(y_data, 0).cuda(0))
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(m(output), yt)
        pred = output.data.max(1)[1]
        correct += pred.eq(yt.data).sum()
        loss.backward()
        optimizer.step()

Got it! Thank you very much!

I have read the related examples and noticed that the Variable is defined inside the loop. I did not know it was necessary to do it like this.

Thank you, Adam, for your instruction again!