I’m currently training a GAN that enforces a constraint on the singular values of the weight matrices, requiring them to be at most one. This means I manually reassign the weights at certain points during training to enforce a 1-Lipschitz constraint. After training for a while, the model crashes with an out-of-memory error during the SVD operation. Its initial memory allocation is around 7–8 GB, but it then balloons to 11 GB and crashes.
```python
def singular_value_clip(w):
    dim = w.shape
    if len(dim) > 2:
        w = w.reshape(dim[0], -1)
    u, s, v = torch.svd(w, some=True)
    s[s > 1] = 1
    return (u @ torch.diag(s) @ v.t()).view(dim)
```
Then during training:
```python
for epoch in tqdm(range(start_epoch, epochs)):
    if epoch % 5 == 0:
        # the discriminator only contains Conv3d, Conv2d, BatchNorm3d, and ReLU
        for module in list(dis.model3d.children()) + [dis.conv2d]:
            if type(module) == nn.Conv3d or type(module) == nn.Conv2d:
                module.weight.data = singular_value_clip(module.weight)
            elif type(module) == nn.BatchNorm3d:
                gamma = module.weight.data
                std = torch.sqrt(module.running_var)
                gamma[gamma > std] = std[gamma > std]
                gamma[gamma < 0.01 * std] = 0.01 * std[gamma < 0.01 * std]
                module.weight.data = gamma
```
I’ve even tried wrapping the contents of the if blocks in a `with torch.no_grad():` block, but that had no effect (and may well be improper anyway). Are there any glaring errors that might be causing this memory leak?
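One thing I suspect (but haven’t confirmed) is that passing `module.weight` — a parameter that requires grad — into the SVD records the whole operation in the autograd graph, and the tensor assigned to `.data` keeps that graph alive across epochs. Below is a minimal, self-contained sketch of the variant I’m considering, which performs the clip entirely under `torch.no_grad()` and writes the result back with an in-place `copy_` instead of reassigning `.data`. The standalone `Conv2d` here is just for illustration, not my actual discriminator:

```python
import torch
import torch.nn as nn

def singular_value_clip(w):
    # Flatten conv kernels to 2-D, clip singular values to <= 1, restore shape.
    dim = w.shape
    if len(dim) > 2:
        w = w.reshape(dim[0], -1)
    u, s, v = torch.svd(w, some=True)
    s = s.clamp(max=1.0)  # out-of-place clamp instead of in-place indexing
    return (u @ torch.diag(s) @ v.t()).view(dim)

# Illustrative usage: clip a Conv2d weight in place, with no autograd tracking.
conv = nn.Conv2d(3, 8, kernel_size=3)
with torch.no_grad():
    conv.weight.copy_(singular_value_clip(conv.weight))

# Sanity check: the largest singular value of the flattened weight is now <= 1
# (up to numerical tolerance).
w2d = conv.weight.reshape(conv.weight.shape[0], -1)
print(torch.linalg.matrix_norm(w2d, ord=2).item() <= 1.0 + 1e-4)
```

I’m not sure whether the `copy_` under `no_grad` is the idiomatic way to do this, or whether detaching the input before the SVD would be enough on its own.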