Memory Leak During Singular Value Clipping on Weights

I’m currently training a GAN that constrains the singular values of the weights to be less than or equal to one. To enforce this 1-Lipschitz constraint I manually reassign the weights at certain points during training. After training for a while, the model crashes with an out-of-memory error during the SVD operation. Its initial memory allocation is around 7-8 GB, but it then balloons to 11 GB and crashes.

def singular_value_clip(w):
    dim = w.shape
    if len(dim) > 2:
        # flatten conv kernels into a 2D matrix before taking the SVD
        w = w.reshape(dim[0], -1)
    u, s, v = torch.svd(w, some=True)
    # clamp every singular value to at most 1
    s[s > 1] = 1
    return (u @ torch.diag(s) @ v.t()).view(dim)

Then during training:

for epoch in tqdm(range(start_epoch, epochs)):
    if epoch % 5 == 0:
        for module in list(dis.model3d.children()) + [dis.conv2d]:
            # discriminator only contains Conv3d, Conv2d, BatchNorm3d, and ReLU
            if type(module) == nn.Conv3d or type(module) == nn.Conv2d:
                module.weight.data = singular_value_clip(module.weight)
            elif type(module) == nn.BatchNorm3d:
                gamma = module.weight.data
                std = torch.sqrt(module.running_var)
                gamma[gamma > std] = std[gamma > std]
                gamma[gamma < 0.01 * std] = 0.01 * std[gamma < 0.01 * std]
                module.weight.data = gamma

I’ve even tried enclosing the contents of the if blocks in a with torch.no_grad(): block, but that has no effect (and may well be improper anyway). Are there any glaring errors that might be causing this memory leak?

Update: the CUDA memory keeps increasing, and I’ve also tried deleting GPU variables after use. Should I be clearing the cache or doing something similar to resolve this issue?
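
For reference, a per-epoch check like the one below (hypothetical, not from my actual script) is how I’d verify whether the allocator itself keeps growing; torch.cuda.empty_cache() only returns cached blocks to the driver, so it wouldn’t fix a genuine leak:

# hypothetical per-epoch check, not part of the original training loop
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
print(f"epoch {epoch}: allocated {allocated:.2f} GiB, reserved {reserved:.2f} GiB")
torch.cuda.empty_cache()  # frees cached blocks only; leaked references still hold memory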

Could you remove the usage of the .data attribute, as it might yield unwanted side effects?

When I did that in the past I got an error during the assignment; is there something else I should do when I drop the .data attribute?

Assuming you are manipulating the parameters before (or after) each iteration, you could wrap the code into a with torch.no_grad() block and use .copy_ to fill the parameters with the new values.
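
A minimal sketch of what that could look like, reusing the names from your snippet (dis, singular_value_clip, and the BatchNorm clipping thresholds are taken from your code, so adjust as needed):

with torch.no_grad():
    for module in list(dis.model3d.children()) + [dis.conv2d]:
        if isinstance(module, (nn.Conv3d, nn.Conv2d)):
            # run the clip without recording an autograd graph and write the
            # result into the existing parameter instead of rebinding .data
            module.weight.copy_(singular_value_clip(module.weight))
        elif isinstance(module, nn.BatchNorm3d):
            std = torch.sqrt(module.running_var)
            gamma = module.weight.clone()
            gamma[gamma > std] = std[gamma > std]
            gamma[gamma < 0.01 * std] = 0.01 * std[gamma < 0.01 * std]
            module.weight.copy_(gamma)

Inside no_grad the SVD doesn’t build a computation graph, and .copy_ keeps the original Parameter object (and any optimizer state tied to it) intact.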

Thanks for the help! It works now after this change, plus some aggressive use of the del command on loss/prediction values that are no longer needed.