I have the following function inside a class:
def parabola(self):
# Initialize u1, u2, and u3 with zeros
u1 = torch.cat((torch.diff(self.ubar.detach(), dim=0).detach().clone(), torch.zeros_like(self.ubar[:1,:,:,:], dtype = torch.float32)), dim = 0)
u2 = torch.cat((torch.diff(self.ubar.detach(), dim=1).detach().clone(), torch.zeros_like(self.ubar[:,:1,:,:], dtype = torch.float32)), dim = 1)
u3 = torch.cat((torch.diff(self.ubar.detach(), dim=3).detach().clone(), torch.zeros_like(self.ubar[:,:,:,:1], dtype = torch.float32)), dim = 3)
# Stack the values of mu1 and mu2 for each value of z in the loop
mu1sum = torch.stack([self.mu1[:,:,:,self.k_indices[z]].sum(dim=-1) for z in range(self.l)], dim=-1)
mu2sum = torch.stack([self.mu2[:,:,:,self.k_indices[z]].sum(dim=-1) for z in range(self.l)], dim=-1)
# Calculate u1, u2, and u3 using the formulas in the original function
u1 = self.p1 + self.sigmap * (u1 + mu1sum)
u2 = self.p2 + self.sigmap * (u2 + mu2sum)
u3 = self.p3 + self.sigmap * u3
#B = self.bound(u1,u2,self.lmbda,z+1,self.l,img)
# Calculate B using self.bound and broadcast it along the last dimension
k_indices = torch.arange(1, self.l+1, dtype=torch.int64).repeat(self.h, self.w, self.nc, 1)
B = self.bound(u1, u2, self.lmbda, k_indices, self.l, self.f.unsqueeze(-1).repeat(1,1,1,self.l))
# Use mask to select elements where u3 < B
mask = u3 < B
img = self.f.unsqueeze(self.f.dim()-1).expand(mask.shape)[mask]
self.p1[~mask] = u1[~mask]
self.p2[~mask] = u2[~mask]
self.p3[~mask] = u3[~mask]
which is called in a loop to recursively update self.p1, self.p2, self.p3:
p1 = torch.zeros((self.h,self.w,self.nc,self.l,), dtype = torch.float32)
p2 = torch.zeros((self.h,self.w,self.nc,self.l,), dtype = torch.float32)
p3 = torch.zeros((self.h,self.w,self.nc,self.l,), dtype = torch.float32)
for iter in range(self.repeats):
p1, p2, p3 = self.parabola(p1, p2, p3, self.ubar) # project onto parabola (set K)s
All tensors are created on the GPU with a Mac M1 chip (“mps”). This loop is giving me substantial memory leak, with the GPU memory basically increasing linearly and by several order of magnitude over the course of ~500 iterations.
Any suggestions on what may be causing this? I have tried to detach and clone pretty much everything in this code, to no avail. Autograd is globally disabled.