PyTorch memory explodes on my DL rig compared to my local development machine

I am currently developing an application on my Mac, which has the following specs:

macOS High Sierra, 10.13.6
Python version 3.6.5
Pytorch 0.4.1
8 GB RAM
no CUDA-GPU

I am then uploading this application to my machine learning rig, with the following specs:

Ubuntu 16.04.5 LTS
Python version 3.5
Pytorch 0.4.1
32 GB RAM
GTX 1060 with 6GB of memory
NVIDIA-SMI (Driver version) 390.87
CUDA Version 9.0.176

CUDNN_MAJOR 7
CUDNN_MINOR 2

When I run my application on my Mac, the peak memory usage is about 2 GB.
However, when I run it on my Ubuntu machine, the peak memory usage is about 14 GB.
In both cases this is the memory of the entire Python application.

Specifically, I have the following code snippet, which is part of the training method:

        for train_idx in range(0, data_size, ARG.batch_size):

            print("Memory usage L1: ", memory_usage_resource())

            X_cur = X[train_idx:train_idx + ARG.batch_size, :]
            Y_cur = Y[train_idx:train_idx + ARG.batch_size, :]

            print("Memory usage L2: ", memory_usage_resource())

            Y_hat = self.model.forward(X_cur)

            print("Memory usage L3: ", memory_usage_resource())

            Y_hat = Y_hat.transpose(1, -1).contiguous()
            Y_hat = Y_hat.transpose(2, -1).contiguous()

            Y_cur = Y_cur.transpose(1, -1).contiguous()
            Y_cur = Y_cur.transpose(2, -1).contiguous()
            Y_cur = Y_cur.squeeze()

            print("Memory usage L4: ", memory_usage_resource())

            del X_cur
            gc.collect()
            torch.cuda.empty_cache()

            loss = self.criterion(Y_hat, Y_cur)

            del Y_hat
            del Y_cur
            gc.collect()
            torch.cuda.empty_cache()

            print("Memory usage L5: ", memory_usage_resource())

            loss.backward()

            print("Memory usage L6: ", memory_usage_resource())

            # Clip gradients here
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), ARG.shared_grad_clip)
            self.optimizer.step()

            print("Memory usage L7: ", memory_usage_resource())

            losses[train_idx // ARG.batch_size] = loss / ARG.batch_size

            tx_counter[0] += 1
            tx_writer.add_scalar('loss/train_loss', loss / ARG.batch_size, tx_counter[0])

            del loss
            gc.collect()
            torch.cuda.empty_cache()

            print("Memory usage L8: ", memory_usage_resource())

            sys.stdout.write("-")
            sys.stdout.flush()

This code produces the following output.

On my Mac, this is the output (all values are in MB):

Total memory used (MB):  237.5859375
Memory usage P1:  237.5859375
Training size is:  torch.Size([400, 11, 1])  from  torch.Size([51643, 11, 1])
Memory usage P1:  237.5859375
Memory usage L1:  237.59375
Memory usage L2:  237.59765625
Memory usage L3:  487.765625
Memory usage L4:  494.48046875
Memory usage L5:  494.48046875
Memory usage L6:  569.96875
Memory usage L7:  569.96875
Memory usage L8:  569.96875
-Memory usage L1:  569.96875
Memory usage L2:  569.96875
Memory usage L3:  569.96875
Memory usage L4:  569.96875
Memory usage L5:  569.96875
Memory usage L6:  627.88671875
Memory usage L7:  627.88671875
Memory usage L8:  627.88671875
-Memory usage L1:  627.88671875
Memory usage L2:  627.88671875
Memory usage L3:  627.88671875
Memory usage L4:  627.88671875
Memory usage L5:  627.88671875
Memory usage L6:  642.37109375
Memory usage L7:  642.37109375
Memory usage L8:  642.37109375
-Memory usage L1:  642.37109375
Memory usage L2:  642.37109375
Memory usage L3:  642.37109375
Memory usage L4:  642.37109375
Memory usage L5:  642.37109375
Memory usage L6:  646.1640625
Memory usage L7:  646.1640625
Memory usage L8:  646.1640625
-Memory usage L1:  646.1640625
Memory usage L2:  646.1640625
Memory usage L3:  646.1640625
Memory usage L4:  646.1640625
Memory usage L5:  646.1640625
Memory usage L6:  662.26953125
Memory usage L7:  662.26953125
Memory usage L8:  662.26953125
-Memory usage L1:  662.26953125
Memory usage L2:  662.26953125
Memory usage L3:  662.26953125
Memory usage L4:  662.26953125
Memory usage L5:  662.26953125
Memory usage L6:  667.3984375
Memory usage L7:  667.3984375
Memory usage L8:  667.3984375
-Memory usage L1:  667.3984375
Memory usage L2:  667.3984375
Memory usage L3:  667.3984375
Memory usage L4:  667.3984375
Memory usage L5:  667.3984375
Memory usage L6:  677.87109375
Memory usage L7:  677.87109375
Memory usage L8:  677.87109375
-Memory usage L1:  677.87109375
Memory usage L2:  677.87109375
Memory usage L3:  677.87109375
Memory usage L4:  677.87109375
Memory usage L5:  677.87109375
Memory usage L6:  690.046875
Memory usage L7:  690.046875
Memory usage L8:  690.046875
-Memory usage L1:  690.046875
Memory usage L2:  690.046875
Memory usage L3:  690.046875
Memory usage L4:  690.046875
Memory usage L5:  690.046875
Memory usage L6:  700.3984375
Memory usage L7:  700.3984375
Memory usage L8:  700.3984375
-Memory usage L1:  700.3984375
Memory usage L2:  700.3984375
Memory usage L3:  700.3984375
Memory usage L4:  700.3984375
Memory usage L5:  700.3984375
Memory usage L6:  711.39453125
Memory usage L7:  711.39453125
Memory usage L8:  711.39453125
-Memory usage L1:  711.39453125
Memory usage L2:  711.39453125
Memory usage L3:  711.39453125
Memory usage L4:  711.39453125
Memory usage L5:  711.39453125
Memory usage L6:  723.578125
Memory usage L7:  723.578125
Memory usage L8:  723.578125
-Memory usage L1:  723.578125
Memory usage L2:  723.578125
Memory usage L3:  723.578125
Memory usage L4:  723.578125
Memory usage L5:  723.578125
Memory usage L6:  732.82421875
Memory usage L7:  732.82421875
Memory usage L8:  732.82421875
-Memory usage L1:  732.82421875
Memory usage L2:  732.82421875
Memory usage L3:  732.82421875
Memory usage L4:  732.82421875
Memory usage L5:  732.82421875
Memory usage L6:  745.484375
Memory usage L7:  745.484375
Memory usage L8:  745.484375
-Memory usage L1:  745.484375
Memory usage L2:  745.484375
Memory usage L3:  745.484375
Memory usage L4:  745.484375
Memory usage L5:  745.484375
Memory usage L6:  755.12890625
Memory usage L7:  755.12890625
Memory usage L8:  755.12890625
-Memory usage L1:  755.12890625
Memory usage L2:  755.12890625
Memory usage L3:  755.12890625
Memory usage L4:  755.12890625
Memory usage L5:  755.12890625
Memory usage L6:  768.80078125
Memory usage L7:  768.80078125
Memory usage L8:  768.80078125
-Memory usage L1:  768.80078125
Memory usage L2:  768.80078125
Memory usage L3:  768.80078125
Memory usage L4:  768.80078125
Memory usage L5:  768.80078125
Memory usage L6:  777.39453125
Memory usage L7:  777.39453125
Memory usage L8:  777.39453125
-Memory usage L1:  777.39453125
Memory usage L2:  777.39453125
Memory usage L3:  777.39453125
Memory usage L4:  777.39453125
Memory usage L5:  777.39453125
Memory usage L6:  788.62109375
Memory usage L7:  788.62109375
Memory usage L8:  788.62109375
-Memory usage L1:  788.62109375
Memory usage L2:  788.62109375
Memory usage L3:  788.62109375
Memory usage L4:  788.62109375
Memory usage L5:  788.62109375
Memory usage L6:  788.62109375
Memory usage L7:  788.62109375
Memory usage L8:  788.62109375

On my ML rig, however, the same code produces the following output:

Total memory used (MB):  264.73828125
Memory usage P1:  264.73828125
Training size is:  torch.Size([400, 11, 1])  from  torch.Size([51643, 11, 1])
Memory usage P1:  264.73828125
Memory usage L1:  264.73828125
Memory usage L2:  264.73828125
Memory usage L3:  547.86328125
Memory usage L4:  547.86328125
Memory usage L5:  553.84765625
Memory usage L6:  726.59375
Memory usage L7:  726.59375
Memory usage L8:  726.59375
-Memory usage L1:  726.59375
Memory usage L2:  726.59375
Memory usage L3:  899.08203125
Memory usage L4:  899.08203125
Memory usage L5:  899.08203125
Memory usage L6:  989.79296875
Memory usage L7:  989.79296875
Memory usage L8:  989.79296875
-Memory usage L1:  989.79296875
Memory usage L2:  989.79296875
Memory usage L3:  1157.67578125
Memory usage L4:  1157.67578125
Memory usage L5:  1157.67578125
Memory usage L6:  1176.8984375
Memory usage L7:  1176.8984375
Memory usage L8:  1176.8984375
-Memory usage L1:  1176.8984375
Memory usage L2:  1176.8984375
Memory usage L3:  1300.86328125
Memory usage L4:  1300.86328125
Memory usage L5:  1300.86328125
Memory usage L6:  1383.0078125
Memory usage L7:  1383.0078125
Memory usage L8:  1383.0078125
-Memory usage L1:  1383.0078125
Memory usage L2:  1383.0078125
Memory usage L3:  1492.046875
Memory usage L4:  1492.046875
Memory usage L5:  1492.046875
Memory usage L6:  1569.03515625
Memory usage L7:  1569.03515625
Memory usage L8:  1569.03515625
-Memory usage L1:  1569.03515625
Memory usage L2:  1569.03515625
Memory usage L3:  1682.51953125
Memory usage L4:  1682.51953125
Memory usage L5:  1682.51953125
Memory usage L6:  1778.3515625
Memory usage L7:  1778.3515625
Memory usage L8:  1778.3515625
-Memory usage L1:  1778.3515625
Memory usage L2:  1778.3515625
Memory usage L3:  1892.90234375
Memory usage L4:  1892.90234375
Memory usage L5:  1892.90234375
Memory usage L6:  1989.08984375
Memory usage L7:  1989.08984375
Memory usage L8:  1989.08984375
-Memory usage L1:  1989.08984375
Memory usage L2:  1989.08984375
Memory usage L3:  2102.5390625
Memory usage L4:  2102.5390625
Memory usage L5:  2102.5390625
Memory usage L6:  2180.15625
Memory usage L7:  2180.15625
Memory usage L8:  2180.15625
-Memory usage L1:  2180.15625
Memory usage L2:  2180.15625
Memory usage L3:  2293.1171875
Memory usage L4:  2293.1171875
Memory usage L5:  2293.1171875
Memory usage L6:  2380.05078125
Memory usage L7:  2380.05078125
Memory usage L8:  2380.05078125
-Memory usage L1:  2380.05078125
Memory usage L2:  2380.05078125
Memory usage L3:  2503.3203125
Memory usage L4:  2503.3203125
Memory usage L5:  2503.3203125
Memory usage L6:  2581.19140625
Memory usage L7:  2581.19140625
Memory usage L8:  2581.19140625
-Memory usage L1:  2581.19140625
Memory usage L2:  2581.19140625
Memory usage L3:  2694.5
Memory usage L4:  2694.5
Memory usage L5:  2694.5
Memory usage L6:  2781.43359375
Memory usage L7:  2781.43359375
Memory usage L8:  2781.43359375
-Memory usage L1:  2781.43359375
Memory usage L2:  2781.43359375
Memory usage L3:  2904.4765625
Memory usage L4:  2904.4765625
Memory usage L5:  2904.4765625
Memory usage L6:  3006.25
Memory usage L7:  3006.25
Memory usage L8:  3006.25
-Memory usage L1:  3006.25
Memory usage L2:  3006.25
Memory usage L3:  3113.9609375
Memory usage L4:  3113.9609375
Memory usage L5:  3113.9609375
Memory usage L6:  3200.89453125
Memory usage L7:  3200.89453125
Memory usage L8:  3200.89453125
-Memory usage L1:  3200.89453125
Memory usage L2:  3200.89453125
Memory usage L3:  3324.42578125
Memory usage L4:  3324.42578125
Memory usage L5:  3324.42578125
Memory usage L6:  3426.45703125
Memory usage L7:  3426.45703125
Memory usage L8:  3426.45703125
-Memory usage L1:  3426.45703125
Memory usage L2:  3426.45703125
Memory usage L3:  3534.2265625
Memory usage L4:  3534.2265625
Memory usage L5:  3534.2265625
Memory usage L6:  3601.83203125
Memory usage L7:  3601.83203125
Memory usage L8:  3601.83203125
-Memory usage L1:  3601.83203125
Memory usage L2:  3601.83203125
Memory usage L3:  3725.296875
Memory usage L4:  3725.296875
Memory usage L5:  3725.296875
Memory usage L6:  3803.16796875
Memory usage L7:  3803.16796875
Memory usage L8:  3803.16796875
-Memory usage L1:  3803.16796875
Memory usage L2:  3803.16796875
Memory usage L3:  3916.1640625
Memory usage L4:  3916.1640625
Memory usage L5:  3916.1640625
Memory usage L6:  4003.09765625
Memory usage L7:  4003.09765625
Memory usage L8:  4003.09765625
-Memory usage L1:  4003.09765625
Memory usage L2:  4003.09765625
Memory usage L3:  4126.5390625
Memory usage L4:  4126.5390625
Memory usage L5:  4126.5390625
Memory usage L6:  4228.5703125
Memory usage L7:  4228.5703125
Memory usage L8:  4228.5703125
-Memory usage L1:  4228.5703125
Memory usage L2:  4228.5703125
Memory usage L3:  4336.078125
Memory usage L4:  4336.078125
Memory usage L5:  4336.078125
Memory usage L6:  4423.01171875
Memory usage L7:  4423.01171875
Memory usage L8:  4423.01171875
-Memory usage L1:  4423.01171875
Memory usage L2:  4423.01171875
Memory usage L3:  4546.75
Memory usage L4:  4546.75
Memory usage L5:  4546.75
Memory usage L6:  4648.5234375
Memory usage L7:  4648.5234375
Memory usage L8:  4648.5234375
-Memory usage L1:  4648.5234375
Memory usage L2:  4648.5234375
Memory usage L3:  4756.15234375
Memory usage L4:  4756.15234375
Memory usage L5:  4756.15234375
Memory usage L6:  4824.015625
Memory usage L7:  4824.015625
Memory usage L8:  4824.015625
-Memory usage L1:  4824.015625
Memory usage L2:  4824.015625
Memory usage L3:  4947.2421875
Memory usage L4:  4947.2421875
Memory usage L5:  4947.2421875
Memory usage L6:  5049.015625
Memory usage L7:  5049.015625
Memory usage L8:  5049.015625
-Memory usage L1:  5049.015625
Memory usage L2:  5049.015625
Memory usage L3:  5157.078125
Memory usage L4:  5157.078125
Memory usage L5:  5157.078125
Memory usage L6:  5224.94140625
Memory usage L7:  5224.94140625
Memory usage L8:  5224.94140625
-Memory usage L1:  5224.94140625
Memory usage L2:  5224.94140625
Memory usage L3:  5348.328125
Memory usage L4:  5348.328125
Memory usage L5:  5348.328125
Memory usage L6:  5450.359375
Memory usage L7:  5450.359375
Memory usage L8:  5450.359375
-Memory usage L1:  5450.359375
Memory usage L2:  5450.359375
Memory usage L3:  5558.1484375
Memory usage L4:  5558.1484375
Memory usage L5:  5558.1484375
Memory usage L6:  5645.08203125
Memory usage L7:  5645.08203125
Memory usage L8:  5645.08203125

(after this, memory does not increase any more)

I was thinking this was a memory leak, but the fact that memory eventually plateaus at some maximum suggests it is not. So I assume I am allocating or freeing memory incorrectly somehow.
Does anyone have any idea why the memory usage on my Ubuntu machine is about 5-7 times higher than on my local machine?
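For reference, the memory_usage_resource() helper used in the snippet above is not shown; a minimal sketch of such a helper (assuming it reads the process's peak resident set size via the stdlib resource module) would be:

```python
import resource
import sys

def memory_usage_resource():
    """Return this process's peak resident set size (RSS) in MB.

    Note: ru_maxrss is reported in kilobytes on Linux but in bytes on
    macOS, hence the platform-dependent divisor.
    """
    denom = 1024.0
    if sys.platform == "darwin":
        denom *= 1024.0  # macOS reports bytes, not kilobytes
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / denom
```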

I seem to remember that some OSes' allocators don't return memory to the system, in order to speed up later allocations when there is plenty of free memory.
I wouldn't worry about it as long as it doesn't cause the system to run out of memory.

I get the problem both on the CPU and when using CUDA tensors. The problem is that my system (or CUDA) memory does actually run out.

Could you replace the indexing of your input tensors by:

X_cur = X[train_idx:train_idx + ARG.batch_size, :].detach()
Y_cur = Y[train_idx:train_idx + ARG.batch_size, :].detach()

Also, all the del, gc.collect(), and torch.cuda.empty_cache() calls are unnecessary and will only slow down your program.

I tried that, but it doesn't seem to help. Apart from that, do I have to call .to("cuda") before or after the detach operation?

What detach() does is detach the tensor from its previous history. In particular, you don't want any history attached to your dataset, because no gradients should be flowing back that way.
Whether you convert to CUDA before or after detaching does not change much.
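To illustrate (a minimal standalone sketch, not tied to the code above): detach() returns a tensor that shares the same data but carries no autograd history, so nothing upstream of it is kept alive by the graph.

```python
import torch

x = torch.randn(4, requires_grad=True)
y = x * 2          # y is part of the autograd graph
z = y.detach()     # z shares y's data but has no history

print(y.requires_grad)  # True
print(z.requires_grad)  # False
print(z.grad_fn)        # None: no graph is retained through z
```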

If you could put together a small code sample that reproduces your problem, that would be very helpful.