Hi, thanks for your quick response! Here is the model portion of the code with some added comments and print statements that I hope will be useful (see the screenshot below for a more legible version). Any tips or things you see that could reduce the memory usage would be very helpful!
#---------------------------------GPU-0------------------------------------
print("START CALC ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
#PART-1: EVALUATE NN
#x1 [natom1,60] -> hidden [natom1,32] -> param1 [natom1,8] (8 potential parameters per atom); b1.mm(M11) tiles the biases across atoms
param1=(torch.sigmoid(x1.mm(torch.t(w1))+torch.t(b1.mm(M11)))-0.5).mm(torch.t(w2))+torch.t(b2.mm(M11))
print(x1.shape,w1.shape,b1.shape,w2.shape,b2.shape,M11.shape,param1.shape) #size of NN related tensors
print("MEMORY CHECK-1 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
#PART-2: USE NN OUTPUT AS PARAM FOR PINN MODEL
I5=torch.ones([max_nn1,1]).type(dtype) #column of ones; right-multiplying by I5 sums over the neighbor index
print(natom1,max_nn1)
print(I5.shape,fc_rijk1.shape,rijk1.shape,cosjk1.shape,rij1.shape,fc_rij1.shape) #size of tensors used in PINN calc
print("MEMORY CHECK-2 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
#Sij: screening factor per (i,j) pair; product over the third-body index k -> [natom1*max_nn1,1]
Sij=(torch.prod(1.0-fc_rijk1*torch.exp(-((rijk1.view(natom1,max_nn1*max_nn1))*(param1[:,6].view(natom1,1))**2.0).view(natom1*max_nn1,max_nn1)),1)).view(natom1*max_nn1,1)
print("MEMORY CHECK-3 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
zij=T1*(fc_rij1.view(natom1*max_nn1,1))
print("MEMORY CHECK-4 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
#weight by the screening factor and the squared deviation of cos(theta_jk) from param1[:,5]
zij=zij*Sij*((cosjk1.view(natom1,max_nn1*max_nn1)-param1[:,5].view(natom1,1)).view(max_nn1*natom1,max_nn1))**2.0
print("MEMORY CHECK-5 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
#together with the transpose below, this block-transposes each atom's [max_nn1,max_nn1] block (swaps the two neighbor indices)
zij=torch.cat(torch.split(zij,max_nn1),1)
print("MEMORY CHECK-6 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
zij=torch.t(zij)
print("MEMORY CHECK-7 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
#bond order: bij = (1 + param1[:,4]^2 * sum_k zij)^(-1/2)
bij=(1.0+((param1[:,4].view(natom1,1))**2.0)*(zij.mm(I5).view(natom1,max_nn1)))**(-0.5)
print("MEMORY CHECK-8 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
del zij
print("MEMORY CHECK-9 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
#BEYOND THIS POINT ADDITIONAL MEMORY USE IS NEGLIGIBLE
Sij=Sij.view(natom1,max_nn1)
#Ep: many-body term ~ -param1[:,7]*sqrt(sum_j Sij*bij*fc); Epair: difference of two screened exponentials, summed over j
Ep=-param1[:,7].view(natom1,1)*((Sij*bij*fc_rij1).mm(I5)).pow(0.5)
Epair=0.5*((torch.exp(param1[:,0].view(natom1,1).repeat(1,max_nn1)-param1[:,1].view(natom1,1)*rij1)-Sij*bij*torch.exp(param1[:,2].view(natom1,1).repeat(1,max_nn1)-param1[:,3].view(natom1,1)*rij1))*fc_rij1).mm(I5)
print("MEMORY CHECK-10 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
E1=Ep+Epair
print("MEMORY CHECK-11 ",10.0**(-9.0)*torch.cuda.memory_allocated(0),' GB')
#-------------------------------------------------------------------------
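As a side note on the instrumentation: torch.cuda.memory_allocated only reports the current allocation, so short-lived temporaries inside a single line don't show up. I may also start tracking peaks between checks with a small helper along these lines (report_peak is just a name I made up):

def report_peak(tag, device=0):
    #current allocation plus the peak since the last reset, in GB;
    #the peak also catches temporaries that were freed before this call
    cur = 1e-9*torch.cuda.memory_allocated(device)
    peak = 1e-9*torch.cuda.max_memory_allocated(device)
    print(tag, " current ", cur, " GB, peak ", peak, " GB")
    torch.cuda.reset_peak_memory_stats(device) #fresh window for the next check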
I should note that the model produces the correct outputs (verified against a separate, independent MPI code), but the GPU memory use is making the PyTorch implementation difficult. I know that using smaller batches is an option, but I wanted to avoid that if possible because the current implementation is quite fast (a rough sketch of what I mean by batching is below).
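To be concrete, by smaller batches I mean chunking the atom dimension roughly like this. This is just a sketch: chunk is an arbitrary size, energy_model is a stand-in for PART-1/PART-2 above, and the per-atom tensors are assumed sliceable along dim 0:

chunk = 4096 #hypothetical chunk size
energies = []
for start in range(0, natom1, chunk):
    stop = min(start+chunk, natom1)
    #pair/triplet tensors of shape [natom1*max_nn1,...] would be sliced
    #as [start*max_nn1:stop*max_nn1] to stay aligned with the atom chunk
    energies.append(energy_model(x1[start:stop], rij1[start:stop], fc_rij1[start:stop],
                                 fc_rijk1[start*max_nn1:stop*max_nn1],
                                 rijk1[start*max_nn1:stop*max_nn1],
                                 cosjk1[start*max_nn1:stop*max_nn1]))
E1 = torch.cat(energies, 0) #per-atom energies for the full system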
The only tensors that require grad are w1, b1, w2, and b2. Here is the output of the above portion of the code for a single forward pass with requires_grad=True:
START CALC 5.777108480000001 GB
torch.Size([25609, 60]) torch.Size([32, 60]) torch.Size([32, 1]) torch.Size([8, 32]) torch.Size([8, 1]) torch.Size([1, 25609]) torch.Size([25609, 8])
MEMORY CHECK-1 5.784743936 GB
25609 111
torch.Size([111, 1]) torch.Size([2842599, 111]) torch.Size([2842599, 111]) torch.Size([2842599, 111]) torch.Size([25609, 111]) torch.Size([25609, 111])
MEMORY CHECK-2 5.7847444480000005 GB
MEMORY CHECK-3 8.320697344000001 GB
MEMORY CHECK-4 9.582920704000001 GB
MEMORY CHECK-5 14.631814144000002 GB
MEMORY CHECK-6 15.894037504000002 GB
MEMORY CHECK-7 15.894037504000002 GB
MEMORY CHECK-8 15.928350208000001 GB
MEMORY CHECK-9 15.928350208000001 GB
MEMORY CHECK-10 16.008687616 GB
MEMORY CHECK-11 16.008790528000002 GB
And here is the output when requires_grad=False:
START CALC 5.777108480000001 GB
torch.Size([25609, 60]) torch.Size([32, 60]) torch.Size([32, 1]) torch.Size([8, 32]) torch.Size([8, 1]) torch.Size([1, 25609]) torch.Size([25609, 8])
MEMORY CHECK-1 5.777928192 GB
25609 111
torch.Size([111, 1]) torch.Size([2842599, 111]) torch.Size([2842599, 111]) torch.Size([2842599, 111]) torch.Size([25609, 111]) torch.Size([25609, 111])
MEMORY CHECK-2 5.777928704000001 GB
MEMORY CHECK-3 5.789331968 GB
MEMORY CHECK-4 7.051555328 GB
MEMORY CHECK-5 7.051555328 GB
MEMORY CHECK-6 7.051555328 GB
MEMORY CHECK-7 7.051555328 GB
MEMORY CHECK-8 7.062958592 GB
MEMORY CHECK-9 5.800735232 GB
MEMORY CHECK-10 5.800941056 GB
MEMORY CHECK-11 5.801043968 GB
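My reading of the gap (roughly 16 GB vs 6 GB) is that with requires_grad=True autograd keeps the large intermediates alive for the backward pass: each [2842599, 111] float32 tensor is about 1.26 GB, which lines up with the jumps at checks 3 through 6, and it also explains why del zij frees nothing in the grad case (checks 8 to 9) but releases about 1.26 GB without grad. If that's right, one option I'm considering is torch.utils.checkpoint, which recomputes a wrapped segment during backward instead of storing its intermediates. A minimal sketch, with pinn_energy as a hypothetical function wrapping PART-2:

from torch.utils.checkpoint import checkpoint

def pinn_energy(param1, fc_rijk1, rijk1, cosjk1, rij1, fc_rij1):
    ... #PART-2 exactly as above (Sij, zij, bij, Ep, Epair), returning E1

#intermediates of pinn_energy are recomputed during backward rather than stored,
#trading one extra forward pass of PART-2 for most of the ~10 GB of saved activations
E1 = checkpoint(pinn_energy, param1, fc_rijk1, rijk1, cosjk1, rij1, fc_rij1)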
Also, I should note that this is a slightly different implementation than the one I mentioned in my previous comment; that implementation only used around 4 GB of memory (with requires_grad=False) but was about 2x slower, so I've slightly tweaked what I'm doing.
Finally, the code is somewhat hard to read as text, so here is a screenshot which might be easier to follow: