Memory pollution during concurrent data transfers with multiple CUDA streams

Hello community.

I’m trying to load and offload a model’s weights between pinned host memory and the GPU (host → GPU and GPU → host) using multiple Python threads and CUDA streams.

However, I get different inference results depending on whether the weights (individual layers, really) are migrated concurrently or not.

Also, I presume the problem lies in the difference between DeviceCachingAllocator’s free() and CUDAHostAllocator’s free():

CUDAHostAllocator inserts an event into cuda_events_ when the block is not ready yet, but DeviceCachingAllocator does not insert an event even if the block is not ready.
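
To make the concern concrete, here is a rough sketch (not my actual code; sizes and names are arbitrary) of the interleaving I have in mind:

import torch

down = torch.cuda.Stream()

# block comes from the DeviceCachingAllocator and is associated with the current (default) stream
gpu_t = torch.randn(64, 1024, 1024, device="cuda")

with torch.cuda.stream(down):
    # async D2H copy enqueued on `down`
    cpu_t = gpu_t.to("cpu", non_blocking=True)

# the GPU block goes back to the allocator's pool; no event for the copy on `down` is recorded
del gpu_t

# this allocation may get the very same block on the default stream
# while the copy on `down` is still reading from it
new_t = torch.randn(64, 1024, 1024, device="cuda")

# without this, cpu_t can end up polluted
down.synchronize()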

Are these behaviors expected?

Here is some code that simulates the problem. I know it’s a bit long, but it reproduces the issue easily. I’m using PyTorch 2.0.1.

from queue import Queue
import threading
import torch

N_MEMORY = 10
MEMORY_LIMIT = 3 * 1024**3  # 3GB
LAYER_SIZE = 512 * 1024**2  # 512MB


class Object(object):
    pass


class FakeLoader(threading.Thread):
    def __init__(self, streams):
        threading.Thread.__init__(self)
        self.streams = streams
        self.memory_loaded = 0
        self.loaded_layer = Queue()

    def can_load(self):
        return self.memory_loaded + LAYER_SIZE < MEMORY_LIMIT

    def load(self, id):
        global layers

        print("BEGIN LOAD", id)
        with torch.cuda.stream(self.streams["upstream"]):
            layers[id].tensor = layers[id].tensor.to("cuda:0", non_blocking=True)

        layers[id].end_load_event = torch.cuda.Event(enable_timing=True)
        layers[id].end_load_event.record(self.streams["upstream"])

        self.memory_loaded += LAYER_SIZE
        self.loaded_layer.put(id)

    def unload(self, id):
        global layers

        print("BEGIN UNLOAD", id)
        with torch.cuda.stream(self.streams["downstream"]):
            layers[id].tensor = layers[id].tensor.to("cpu", non_blocking=True)
        # !!! UNCOMMENT the next line if you want correct results
        # self.streams["downstream"].synchronize()

        self.memory_loaded -= LAYER_SIZE

    def run(self):
        global layers

        for id in range(N_MEMORY):
            layers[id].begin_load_event = torch.cuda.Event(enable_timing=True)
            layers[id].begin_load_event.record(self.streams["upstream"])

            while True:
                if self.can_load():
                    self.load(id)

                    break
                else:
                    self.unload(self.loaded_layer.get())
            layers[id].load_lock.set()


if __name__ == "__main__":
    global layers
    layers = {}
    orginal_memorys_cpu = {}
    orginal_memorys_gpu = {}
    for id in range(N_MEMORY):
        layers[id] = Object()
        layers[id].tensor = torch.randint(
            32000, [1, 64 * 1024**2], dtype=torch.long
        )  # 512MB
        layers[id].tensor = layers[id].tensor.pin_memory()

        orginal_memorys_cpu[id] = layers[id].tensor.detach().clone().pin_memory()
        orginal_memorys_gpu[id] = layers[id].tensor.detach().clone().to("cuda:0")

    streams = {}
    streams["upstream"] = torch.cuda.Stream()
    streams["downstream"] = torch.cuda.Stream()

    # init()
    for layer in layers.values():
        layer.load_lock = threading.Event()

    loader = FakeLoader(streams)
    loader.start()
    loader.join()

    # check_result()
    for id in range(N_MEMORY):
        if layers[id].tensor.device == torch.device("cpu"):
            print("Layer", id, "is in cpu")

            if torch.allclose(
                layers[id].tensor, orginal_memorys_cpu[id], equal_nan=True
            ):
                print(f"Layer {id} is same")
            else:
                print(f"Layer {id} is different")
        elif layers[id].tensor.device == torch.device("cuda:0"):
            print("Layer", id, "is in cuda")

            if torch.allclose(
                layers[id].tensor, orginal_memorys_gpu[id], equal_nan=True
            ):
                print(f"Layer {id} is same")
            else:
                print(f"Layer {id} is different")
        else:
            print("!!! SOMETHING WRONG Device is", layers[id].tensor.device)

Thanks.

Yes, this is expected since you need to synchronize the stream before being able to use the CPUTensor after moving it with non_blocking=True.
This example illustrates it:

import torch
import time
from statistics import mean

def get_cycles_per_ms() -> float:
    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
    """

    def measure() -> float:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.cuda._sleep(1000000)
        end.record()
        end.synchronize()
        cycles_per_ms = 1000000 / start.elapsed_time(end)
        return cycles_per_ms

    # Get 10 values and remove the 2 max and 2 min and return the avg.
    # This is to avoid system disturbance that skew the results, e.g.
    # the very first cuda call likely does a bunch of init, which takes
    # much longer than subsequent calls.
    #
    # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
    # and seems to return stable values.
    num = 10
    vals = []
    for _ in range(num):
        vals.append(measure())
    vals = sorted(vals)
    return mean(vals[2 : num - 2])


if __name__ == '__main__':
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    cycles = get_cycles_per_ms()
    
    stream = torch.cuda.current_stream()

    x = torch.rand(32, 256, 220, 220, device="cuda")
    torch.cuda._sleep(int(100 * cycles))
    t = x.to(torch.device("cpu"), non_blocking=True)
    print(stream.query()) # False - work not done yet
    #stream.synchronize() # wait for stream to finish the work
    print(t.abs().sum())
    
    time.sleep(2.)
    print(stream.query()) # True - work done
    print(t.abs().sum())

You should see two different outputs since the required stream.synchronize() is commented out.
Uncomment it and the two outputs will match again.
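
If you only want to block the host until this particular copy is done (rather than calling stream.synchronize() at some later point), you could also record a CUDA event right after the copy and synchronize on that event. A rough sketch, reusing get_cycles_per_ms from above:

import torch

stream = torch.cuda.current_stream()

x = torch.rand(32, 256, 220, 220, device="cuda")
torch.cuda._sleep(int(100 * get_cycles_per_ms()))
t = x.to(torch.device("cpu"), non_blocking=True)

done = torch.cuda.Event()
done.record(stream)   # marks the point right after the copy on the current stream
done.synchronize()    # host waits until all work queued before the event (sleep + copy) has finished
print(t.abs().sum())  # now safe to read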


Thanks for your quick and kind response!

I’ve got a question: why are a CPUTensor and a CUDATensor treated differently here?
For example, in the snippet you provided, if I change "cuda" to "cpu" (in torch.rand()) and "cpu" to "cuda" (in Tensor.to()), then the result is True and True.
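
Concretely, the variant I tried looks roughly like this (same shapes as your snippet, reusing get_cycles_per_ms):

import time
import torch

cycles = get_cycles_per_ms()
stream = torch.cuda.current_stream()

x = torch.rand(32, 256, 220, 220)                  # "cuda" -> "cpu": the source is now a CPU tensor
torch.cuda._sleep(int(100 * cycles))
t = x.to(torch.device("cuda"), non_blocking=True)  # "cpu" -> "cuda": H2D instead of D2H
print(stream.query())
print(t.abs().sum())

time.sleep(2.)
print(stream.query())
print(t.abs().sum())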

Is this behavior for the CPUTensor meant to reduce latency, or is there another reason?

CUDA operations will be executed in the surrounding CUDAStream, which is the single default stream in my example. If you move a CPUTensor to the GPU asynchronously via non_blocking=True, the operation will be non-blocking w.r.t. the host; however, it will still be enqueued on the surrounding stream and is thus executed in order. If you want async execution w.r.t. the GPU, you would need to use different streams.
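
For example, a rough sketch (not any internal code path) of an H2D copy that can actually overlap with work on the default stream could look like this:

import torch

copy_stream = torch.cuda.Stream()
x_cpu = torch.rand(32, 256, 220, 220).pin_memory()  # pinned, so the H2D copy can be truly async

with torch.cuda.stream(copy_stream):
    x_gpu = x_cpu.to("cuda", non_blocking=True)  # copy is enqueued on copy_stream

# ... kernels launched here on the default stream can overlap with the copy ...

torch.cuda.current_stream().wait_stream(copy_stream)  # order default-stream work after the copy
x_gpu.record_stream(torch.cuda.current_stream())      # tell the caching allocator the tensor is also used on the default stream
print(x_gpu.abs().sum())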


Thanks! One last question: why are the CUDAHostAllocator and the CUDACachingAllocator implemented in different ways?

The CUDAHostAllocator is responsible for allocating pinned (page-locked) memory on the host via cudaHostAlloc, while the CUDACachingAllocator is responsible for allocating and freeing memory on the GPU.
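
At the Python level this maps roughly to (a sketch, not the internal call chain):

import torch

host_pinned = torch.empty(1024, pin_memory=True)  # pinned host memory, served by the CUDAHostAllocator (cudaHostAlloc under the hood)
device_tensor = torch.empty(1024, device="cuda")  # device memory, served by the CUDACachingAllocator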


My actual question is why at::native::copy_kernel_cuda() doesn’t call CUDACachingAllocator.recordStream(), but does call CUDAHostAllocator.record_event().
If CUDAHostAllocator.record_event() is called there for synchronization purposes, I would expect CUDACachingAllocator.recordStream() to be called as well.
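
In user code, what I expected is something equivalent to calling Tensor.record_stream() on the source tensor myself, e.g.:

import torch

down = torch.cuda.Stream()
gpu_t = torch.randn(64, 1024, 1024, device="cuda")

with torch.cuda.stream(down):
    cpu_t = gpu_t.to("cpu", non_blocking=True)
gpu_t.record_stream(down)  # keeps the GPU block from being reused until the copy queued on `down` completes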

I think I might misunderstand something…