I’m currently training a GAN that enforces a constraint on the singular values of the weights so that they are less than or equal to one. This means manually reassigning the weights at certain points during training to enforce a 1-Lipschitz constraint. After a certain period of training the run crashes because it runs out of memory during the SVD operation: the initial allocation is around 7-8 GB, but it then balloons to 11 GB and crashes.
def singular_value_clip(w):
    dim = w.shape
    if len(dim) > 2:
        # flatten conv kernels to a 2-D (out_channels, rest) matrix
        w = w.reshape(dim[0], -1)
    u, s, v = torch.svd(w, some=True)
    s[s > 1] = 1
    return (u @ torch.diag(s) @ v.t()).view(dim)
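As a quick sanity check on the function itself (illustrative shapes only, not my real layer sizes), the clipped weight's singular values should all be at most 1 when flattened back to 2-D:

w = torch.randn(64, 32, 3, 3, 3)          # stand-in for a Conv3d weight
w_clipped = singular_value_clip(w)
# largest singular value of the flattened kernel should now be <= 1.0
print(torch.linalg.svdvals(w_clipped.reshape(64, -1)).max())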
Then during training:
for epoch in tqdm(range(start_epoch, epochs)):
    if epoch % 5 == 0:
        # discriminator only contains Conv3d, Conv2d, BatchNorm3d, and ReLU
        for module in list(dis.model3d.children()) + [dis.conv2d]:
            if type(module) == nn.Conv3d or type(module) == nn.Conv2d:
                module.weight.data = singular_value_clip(module.weight)
            elif type(module) == nn.BatchNorm3d:
                gamma = module.weight.data
                std = torch.sqrt(module.running_var)
                gamma[gamma > std] = std[gamma > std]
                gamma[gamma < 0.01 * std] = 0.01 * std[gamma < 0.01 * std]
                module.weight.data = gamma
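To narrow down where the allocation grows, I can log PyTorch's allocator counters around the clipping step. This helper is purely illustrative and not part of the loop above; memory_allocated and max_memory_allocated report the bytes currently held / peak bytes held by PyTorch's caching allocator:

def log_cuda_mem(tag):
    alloc = torch.cuda.memory_allocated() / 1024 ** 3
    peak = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f"{tag}: allocated {alloc:.2f} GiB, peak {peak:.2f} GiB")

# e.g. inside the epoch loop:
# log_cuda_mem(f"epoch {epoch} before clipping")
# ... weight clipping ...
# log_cuda_mem(f"epoch {epoch} after clipping")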
I’ve even tried enclosing the contents of the if blocks in a with torch.no_grad(): block (sketched below), but that has no effect (and may well just be improper). Are there any glaring errors that might be causing this memory leak?
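Concretely, the no_grad attempt looks like this: the bodies of the two branches from the loop above are wrapped, everything else is unchanged:

if type(module) == nn.Conv3d or type(module) == nn.Conv2d:
    with torch.no_grad():
        module.weight.data = singular_value_clip(module.weight)
elif type(module) == nn.BatchNorm3d:
    with torch.no_grad():
        gamma = module.weight.data
        std = torch.sqrt(module.running_var)
        gamma[gamma > std] = std[gamma > std]
        gamma[gamma < 0.01 * std] = 0.01 * std[gamma < 0.01 * std]
        module.weight.data = gamma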