How to reduce the memory requirement for a GPU pytorch training process? (finally solved by using multiple GPUs)

(Peter Xiao Guo) #1


I’m new to PyTorch 0.4 and have implemented an encoder-decoder model for image segmentation.

During training on my lab server, which has only 2 GPU cards, I get the following “out of memory” error:


My input is a 320×320 image, and even with batch_size = 1 it cannot finish a single epoch. Is there some command to use multiple GPU cards?

Any suggestion is appreciated!

Thank you very much!


(Konpat Ta Preechakul) #2

This is weird. As I understand it, each iteration should be “separate”: your task is clearly not recurrent, and your batch size is neither increasing nor decreasing. So why can you finish a few iterations before getting an out-of-memory error?

There must be some kind of memory leak. Some code would surely help!

(Peter Xiao Guo) #3

Hi Konpat, thank you very much!

I resized my input batch to (32, 3, 64, 64), and both that and 128×128 work well.

When I increase the image size to 224×224 or 256×256, it fails: it runs for ~20 epochs and then reports out of memory.

I have heard the term “memory leak”, but could you please give me some suggestions? I’m not sure how to find or check for one in PyTorch. Thank you very much.


(Konpat Ta Preechakul) #4

Actually, I need to see the code. Anyway, how do you compute “train_loss”? Is it some kind of average? If you compute the average by summing up “loss”, something like:

```
loss = ((y_hat - y) ** 2).mean()
sum_loss += loss
avg_loss = sum_loss / itr
print('train_loss', float(avg_loss))
```

then the memory leak comes from sum_loss: because loss is a tensor that is part of the autograd graph, sum_loss keeps the graph of every iteration alive, starting from the first one. Using .detach() should help in this case, but frankly you could just use float(loss) altogether.
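A minimal sketch of the difference, assuming a toy nn.Linear model and random data (both hypothetical, not taken from the thread): accumulating the loss tensor keeps its autograd history (grad_fn chain) alive across iterations, while accumulating loss.item() keeps only a plain Python float.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)  # hypothetical toy model standing in for the real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

leaky_sum = 0    # accumulates tensors: keeps the history of every iteration alive
safe_sum = 0.0   # accumulates Python floats: no autograd history retained

for itr in range(1, 6):
    x, y = torch.randn(4, 10), torch.randn(4, 1)
    y_hat = model(x)
    loss = ((y_hat - y) ** 2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    leaky_sum = leaky_sum + loss  # BAD: result is a tensor with a grad_fn
    safe_sum += loss.item()       # GOOD: .item() returns a plain Python float

    avg_loss = safe_sum / itr
    print('train_loss', avg_loss)

print(type(leaky_sum), leaky_sum.grad_fn is not None)
print(type(safe_sum))
```

Both sums hold the same value; only the first one drags the autograd history along with it.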

(Peter Xiao Guo) #5

Thank you Konpat,

You are a genius!!!

I do indeed use sum_loss += loss!

Should I use sum_loss += float(loss) to solve this problem?

Thank you!

(Naman Jain) #6

Yeah it’ll work, go for it.
And you got the reason for it right (the graph was storing activations from previous iterations, which is not what you wanted)!

(Peter Xiao Guo) #7

Hi Naman,

I tried Konpat’s suggestion and still get the “out of memory” error at epoch 20… any ideas on how to fix it? I’m using PyTorch 0.4 now.

Thank you very much!


(Peter Xiao Guo) #8

Hi Konpat,

Thank you very much for your advice!

I tried your suggestion and wrapped every loss.item() in float() during iteration, but I still get the same “out of memory” error at epoch 20.

I’m using PyTorch 0.4 now, and it seems there is no .detach() I can use… I’m puzzled by this; any suggestion is appreciated.

Thank you very much!


(ptrblck) #9

You don’t need to use float(loss.item()): .item() already returns a standard Python float, so it won’t be tracking the computation graph.

What does your GPU memory usage look like in the first epochs? Are you already close to the limit?
Could you watch the memory with nvidia-smi -l and check whether usage increases during the first epochs?
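As a complement to nvidia-smi, PyTorch can report its own allocator statistics from inside the training script. A minimal sketch; the helper name gpu_memory_report and the commented train_one_epoch call are hypothetical. Keep in mind that nvidia-smi shows all memory held by the process, including blocks PyTorch’s caching allocator keeps reserved for reuse, so its numbers are usually higher than memory_allocated.

```python
import torch

def gpu_memory_report():
    """Return {device_index: (allocated_MiB, peak_allocated_MiB)} for every visible GPU.

    Returns an empty dict when CUDA is not available, so it is safe to call anywhere.
    """
    report = {}
    if not torch.cuda.is_available():
        return report
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 2**20   # tensors currently allocated
        peak = torch.cuda.max_memory_allocated(i) / 2**20    # high-water mark so far
        report[i] = (round(allocated, 1), round(peak, 1))
    return report

# Usage inside a training loop (train_one_epoch is hypothetical):
# for epoch in range(num_epochs):
#     train_one_epoch(model, train_loader)
#     print(epoch, gpu_memory_report())
print(gpu_memory_report())
```

If the allocated value keeps growing from epoch to epoch at a fixed batch size, something is still holding references to old tensors.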

(Konpat Ta Preechakul) #10

I have no idea now :sweat_smile:. One last thing to look at is the train loader, I think.

(Peter Xiao Guo) #11

Hi ptrblck,

Thank you very much! I tried your command and got the following output:

I notice that right before the epoch where it reports out of memory (the blue line), the memory usage on the GeForce GTX card suddenly almost doubles, and then it fails:

Is there any way in PyTorch to also use the other GPU card to cover this?

Or do I need to reduce the number of conv layers or the number of channels in my model?

Thank you very much!


(Peter Xiao Guo) #12

Hi Konpat,

Thank you very much anyway! I followed ptrblck’s advice to check the usage with nvidia-smi and found that during the 20th epoch, in one of the up-sampling layers, where I do the skip connection that concatenates an encoder feature map with a decoder feature map (as in U-Net), the GPU memory required suddenly doubles, and therefore it fails:

I may try to find a way to use the other GPU card in parallel during training, or reduce the number of channels or conv layers in my model.
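A quick way to see why a U-Net-style concatenation is memory hungry: torch.cat does not work in place; it allocates a new output tensor as large as all its inputs combined, and while the encoder feature map is still referenced for the skip connection, inputs and output occupy memory at the same time. A minimal CPU sketch, assuming 64-channel float32 feature maps at the full 320×320 resolution (the channel count is hypothetical): each input takes 64 × 320 × 320 × 4 bytes = 25 MiB, and the concatenated output another 50 MiB. Since this cost is the same in every iteration, it would not by itself explain a failure that only appears after ~20 epochs.

```python
import torch

def nbytes(t):
    # numel() * element_size() works on old and new PyTorch versions alike
    return t.numel() * t.element_size()

# Hypothetical skip connection: encoder and decoder feature maps with the same
# spatial size are concatenated along the channel dimension (dim=1).
enc = torch.randn(1, 64, 320, 320)     # float32 encoder feature map
dec = torch.randn(1, 64, 320, 320)     # float32 upsampled decoder feature map
merged = torch.cat([enc, dec], dim=1)  # new tensor of shape (1, 128, 320, 320)

print(nbytes(enc) / 2**20, "MiB per input")                    # 25.0
print(nbytes(merged) / 2**20, "MiB for the concatenated map")  # 50.0
```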

Thank you very much!


(ptrblck) #13

What doesn’t really make sense is that the operation apparently needs more memory in the 20th epoch, whereas before that it seems to work.
Model sharding would be an approach to split your model onto both GPUs without sacrificing the model capacity.
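A minimal sketch of model sharding, assuming a hypothetical toy encoder-decoder (not the model from this thread): the encoder lives on the first GPU, the decoder on the second, and the activations are moved between devices in forward(). It falls back to CPU for both halves when fewer than two GPUs are visible, so it also runs without GPUs.

```python
import torch
import torch.nn as nn

# Pick two devices; fall back to CPU so the sketch also runs without two GPUs.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")

class ShardedEncoderDecoder(nn.Module):
    """Hypothetical toy model: encoder on dev0, decoder on dev1."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
        ).to(dev0)
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, padding=1),
        ).to(dev1)

    def forward(self, x):
        x = self.encoder(x.to(dev0))
        # move the activations to the second device before running the decoder
        return self.decoder(x.to(dev1))

model = ShardedEncoderDecoder()
out = model(torch.randn(2, 3, 64, 64))
print(out.shape, out.device)
```

The output (and therefore the loss) lives on the second device, so the targets need to be moved there too. In a real U-Net, every skip-connection feature map crossing the split would also need a .to(dev1) before torch.cat.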

Since I still think there might be an issue with your code, could you please post it?
You can add code by wrapping it in three backticks.

(Konpat Ta Preechakul) #14

Honestly, if your train loader is not provided directly by the library, I really want to see it.

(Peter Xiao Guo) #15

Hi ptrblck,

My implementation of the DataLoader may look naive and messy to you.

1: To get the data and labels as NumPy arrays, I use packages to read the .dcm files and the corresponding .nrrd files.

2: I convert the NumPy array (optionally cropped) from its default dtype uint16 to int16 so that it can be used in PyTorch:

Below is one of the repetitive parts, which reads the .dcm files to get the image arrays and stacks the original (320, 320) NumPy arrays into an array of shape (number of total image files, 320, 320):

```
import os

import numpy as np
import pydicom

img_1 = [pydicom.dcmread(train_Prostate3T_img_path + '/' + ID + '/' + dcm_number).pixel_array
         for ID in Prostate3T_patient_ID
         for dcm_number in os.listdir(train_Prostate3T_img_path + '/' + ID)]
Prostate3T_img = img_1[0][np.newaxis, ...]
count = len(Prostate3T_img)
for i in range(1, len(img_1)):
    try:
        Prostate3T_img = np.vstack((Prostate3T_img, img_1[i][np.newaxis, ...]))
    except ValueError:  # np.vstack raises this when a slice is not 320x320
        #print("mis-matched dimension at", i, "-th sample.")
        #print("wrong shape:", np.shape(img_1[i])) ## 18 wrong 256 x 256 shapes
        #print("prostate samples already counted:", count)
        continue
# len() instead of np.shape(): img_1 mixes shapes, which newer NumPy rejects in np.shape
print("original total samples should be counted:", len(img_1))
print("Prostate3T image stacked shape: ", np.shape(Prostate3T_img))
print("===" * 3)
```

Then I convert each sample into a 3-channel (3, 320, 320) array using cv2.cvtColor:

```
import cv2

# cvtColor turns each (320, 320) slice into (320, 320, 3); .T reverses all axes,
# giving (3, 320, 320) but also swapping height and width
img_to_3_channels_Prostate3T = np.array([cv2.cvtColor(Prostate3T_img[i], cv2.COLOR_GRAY2RGB).T for i in range(len(Prostate3T_img))])
```

3: Convert them into torch tensors and use utils.TensorDataset to wrap them in a PyTorch dataset.
(I convert the tensor type with .type('torch.FloatTensor') because I found that this is the only way L1 loss runs without an error; my custom loss function is still under development.)

```
img_to_3_channels_Prostate3T_int16 = np.array(img_to_3_channels_Prostate3T, dtype=np.int16)
img_Prostate3T_tensor = torch.from_numpy(img_to_3_channels_Prostate3T_int16)
img_Prostate3T_tensor = img_Prostate3T_tensor.type('torch.FloatTensor')
img_Prostate3T_tensor =
Prostate3T_dataset = utils.TensorDataset(img_Prostate3T_tensor, label_Prostate3T_tensor)
```
4: Then use utils.DataLoader to build a DataLoader with a SubsetRandomSampler.

4.1 SubsetRandomSampler is used in this way:

```
train_sampler = SubsetRandomSampler(np.arange(n_training_samples, dtype=np.int64))
```

4.2 The training set loader is used in this way:

```
train_loader =, batch_size=batch_size, sampler=train_sampler)
```
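For comparison, a minimal sketch of steps 2 to 4 on synthetic data; the random uint16 slices stand in for the pydicom pixel arrays, and every name and size here is hypothetical. Two small changes from the code above: mis-shaped slices are filtered out first and stacked with a single np.stack call (a repeated np.vstack copies the whole array on every iteration), and the uint16 data are converted straight to float32, which avoids the wrap-around that uint16 to int16 causes for values above 32767.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler

rng = np.random.RandomState(0)
# Hypothetical stand-ins for the DICOM pixel arrays: mostly 320x320, a few 256x256.
slices = [rng.randint(0, 4096, size=(320, 320)).astype(np.uint16) for _ in range(10)]
slices += [rng.randint(0, 4096, size=(256, 256)).astype(np.uint16) for _ in range(2)]
labels = [rng.randint(0, 2, size=s.shape).astype(np.uint8) for s in slices]

# Keep only slices with the expected shape, then stack once: (N, 320, 320).
keep = [i for i, s in enumerate(slices) if s.shape == (320, 320)]
images = np.stack([slices[i] for i in keep])
masks = np.stack([labels[i] for i in keep])

# Gray -> 3 channels by repeating along a new channel axis: (N, 3, 320, 320).
images = np.repeat(images[:, np.newaxis], 3, axis=1)

# Convert straight to float32: no uint16 -> int16 overflow.
img_tensor = torch.from_numpy(images.astype(np.float32))
label_tensor = torch.from_numpy(masks.astype(np.float32)).unsqueeze(1)  # (N, 1, 320, 320)

dataset = TensorDataset(img_tensor, label_tensor)
n_training_samples = 8  # hypothetical split: first 8 samples for training
train_sampler = SubsetRandomSampler(list(range(n_training_samples)))
train_loader = DataLoader(dataset, batch_size=4, sampler=train_sampler)

x, y = next(iter(train_loader))
print(x.shape, y.shape)
```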

Any advice is appreciated!

Thank you very much!


(Peter Xiao Guo) #16

Hi Konpat,

As I replied to ptrblck, I posted my code and an explanation of it above for your reference.

I just followed an online tutorial: first I get a (320, 320) NumPy array, add a dimension to each sample to get (1, 320, 320), convert it into 3 channels (3, 320, 320), and then concatenate all samples into (# of total samples, 3, 320, 320).

Then I convert the (# of total samples, 3, 320, 320) array from dtype uint16 to int16 so that PyTorch can handle it, and convert it from a NumPy array to a torch tensor.

After that, I use **utils.TensorDataset** and SubsetRandomSampler to build the loader:

```
train_loader =, batch_size=batch_size, sampler=train_sampler)
```

Any suggestion is appreciated!

Thank you very much!


(ptrblck) #17

The data loading part looks OK. You could change some small details, e.g. use lazy loading, but this is unrelated to your OOM error.
We would probably have to see the other parts of your training procedure.
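“Lazy loading” here means reading each sample from disk inside Dataset.__getitem__ instead of building one big array up front. A minimal sketch, assuming the slices have been saved as individual .npy files (a hypothetical layout; for the original data the loader would read with pydicom instead of np.load). The class name LazySliceDataset is hypothetical too.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class LazySliceDataset(Dataset):
    """Hypothetical dataset that loads one (image, mask) pair per __getitem__ call."""

    def __init__(self, image_paths, mask_paths):
        assert len(image_paths) == len(mask_paths)
        self.image_paths = list(image_paths)
        self.mask_paths = list(mask_paths)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = np.load(self.image_paths[idx]).astype(np.float32)   # (H, W)
        mask = np.load(self.mask_paths[idx]).astype(np.float32)   # (H, W)
        img = torch.from_numpy(img).unsqueeze(0).repeat(3, 1, 1)  # (3, H, W)
        mask = torch.from_numpy(mask).unsqueeze(0)                # (1, H, W)
        return img, mask

# Usage with a few temporary files standing in for the real slices.
tmpdir = tempfile.mkdtemp()
image_paths, mask_paths = [], []
for i in range(4):
    ip = os.path.join(tmpdir, "img_%d.npy" % i)
    mp = os.path.join(tmpdir, "mask_%d.npy" % i)
    np.save(ip, np.random.randint(0, 4096, size=(64, 64)).astype(np.uint16))
    np.save(mp, np.random.randint(0, 2, size=(64, 64)).astype(np.uint8))
    image_paths.append(ip)
    mask_paths.append(mp)

loader = DataLoader(LazySliceDataset(image_paths, mask_paths), batch_size=2, shuffle=True)
images, masks = next(iter(loader))
print(images.shape, masks.shape)
```

Only the current batch is in memory at any time, which mainly helps host RAM and startup time; as noted above, it does not change GPU memory usage.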

(Peter Xiao Guo) #18

Hi ptrblck,

Thank you very much! I will restructure my code today and post a more reader-friendly version. Hope you also enjoy tonight’s World Cup! :)

Thank you!


(Peter Xiao Guo) #19

Hi ptrblck,

Hope you are enjoying the World Cup matches these two days.

I removed the repetitive code and tried to clean it up.

For the training part, below is my code:

The validation part is similar, as below:

Any suggestion is appreciated!

Thank you very much!

(Peter Xiao Guo) #20

Hi ptrblck,

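I solved the problem posted here by using

@voxmenthe’s answer from a multi-GPU solution:

```
import torch

model = <specify model here>
# replicate the model on all visible GPUs; each forward pass splits the batch across them
model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
model = model.cuda()  # parameters must live on device_ids[0]
```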

I noticed you mentioned that it “splits the data/batch onto different GPUs” rather than doing model sharding… I’m puzzled by this statement. What’s the advantage of splitting the model itself across different GPUs? Does it mean it helps distribute the training load across multiple PCs (rather than only one PC)?

Thank you very much!
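To illustrate the distinction asked about here: nn.DataParallel keeps a full copy of the model on every GPU and splits each input batch along dimension 0, so each GPU only stores activations for its share of the samples; model sharding instead places different layers on different GPUs, so each GPU holds only part of the weights and activations. Both work within one machine; spreading training across several PCs is distributed training (torch.distributed), a separate mechanism. The sketch below only mimics the batch split with torch.chunk, a simplification of DataParallel’s scatter step rather than its actual code, and runs on CPU. The helper split_batch is hypothetical.

```python
import torch

def split_batch(batch, n_devices):
    """Roughly mimic how a data-parallel wrapper divides a batch: chunks along dim 0."""
    return list(torch.chunk(batch, n_devices, dim=0))

batch = torch.randn(8, 3, 64, 64)
chunks = split_batch(batch, 2)
print([c.shape[0] for c in chunks])  # [4, 4] samples per device

single = torch.randn(1, 3, 64, 64)
print(len(split_batch(single, 2)))   # 1: a single-sample batch cannot be split
```

In this illustration, the per-device activation memory shrinks with the per-device batch size, while the full set of weights is still replicated on every GPU.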