Sliced tensor is copied as a whole when iterating over it for multiprocessing

Just wanted to drop a note in case someone stumbles over the same problem.

The scenario: applying a costly function to each element along the batch dimension via multiprocessing.
If I iterate over the tensor and hand the slices to the worker processes, the whole underlying tensor is copied for every element, not just the single slice. This leads to unnecessary and excessive RAM consumption. Interestingly, the behaviour with numpy is different.
Using .clone() solves the problem.

Please find the example below:

import numpy as np
import pathos.pools
import time
import torch

def slow_fn(x):
    # stand-in for an expensive per-element operation
    time.sleep(0.1)
    return

imgs = torch.rand(300, 720, 1280)

# (a) pythonic, but RAM explosion
imgs_ = [x for x in imgs]

# (b) still RAM explosion
imgs_ = torch.unbind(imgs)

# (c) okay with pytorch
imgs_ = [x.clone() for x in imgs]

# (d) okay with numpy (which is striking)
imgs_ = [x for x in imgs.numpy()]


# imgs_ holds whichever of the variants (a)-(d) was assigned last
pool = pathos.pools.ProcessPool(8)
_ = pool.map(slow_fn, imgs_)
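
For reference, here is a quick way to convince yourself what actually gets shipped to the workers: compare the pickled size of a slice with that of its clone and of the corresponding numpy slice. This is just a minimal sanity check, assuming that pickle size is a reasonable proxy for what the worker processes receive (pathos serializes with dill, which behaves the same way here); the shape is scaled down so the check runs quickly:

import pickle

import torch

imgs = torch.rand(10, 720, 1280)  # scaled down for a quick check

view_bytes  = len(pickle.dumps(imgs[0]))          # view: carries the full 10x720x1280 storage
clone_bytes = len(pickle.dumps(imgs[0].clone()))  # clone: carries only its own 720x1280 data
numpy_bytes = len(pickle.dumps(imgs.numpy()[0]))  # numpy slice: also carries only 720x1280

print(view_bytes, clone_bytes, numpy_bytes)
# view_bytes comes out roughly 10x larger than clone_bytes / numpy_bytes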