Custom Linear layer is giving OOM

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# helper functions defined elsewhere in my code (not shown):
# sparse_soft_topk_mask_dykstra, get_mask_pseudo_diagonal_torch

class CustomFullyConnectedLayer(nn.Module):
    def __init__(self, in_features, out_features, device=None, sparsity=0.1, diagPos=[], alphaLR=0.01):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.total_permutations = max(in_features, out_features)
        self.diag_length = min(in_features, out_features)

        num_params = in_features * out_features
        req_params = int((1-sparsity) * num_params)
        K = math.ceil(req_params/min(in_features, out_features))

        self.K = K
        self.topkLR = alphaLR
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.V = nn.Parameter(torch.empty(self.total_permutations, self.diag_length, device=self.device, dtype=torch.float32, requires_grad=True))
        nn.init.kaiming_uniform_(self.V, a=math.sqrt(5))

        self.alpha = nn.Parameter(torch.empty(self.total_permutations, device=self.device, requires_grad=True))
        nn.init.constant_(self.alpha, 1/self.in_features)
        #pdb.set_trace()
        assert torch.all(self.alpha >= 0)

    def compute_weights(self):
        # soft top-k over alpha selects which pseudo-diagonals contribute to the weight
        self.alpha_topk = sparse_soft_topk_mask_dykstra(self.alpha, self.K, l=self.topkLR, num_iter=50).to(self.device)
        non_zero_alpha_indices = torch.nonzero(self.alpha_topk, as_tuple=False).squeeze()
        
        if non_zero_alpha_indices.dim() == 0:
            non_zero_alpha_indices = non_zero_alpha_indices.unsqueeze(0) 
        
        WSum = torch.zeros((self.out_features, self.in_features), device=self.device)

        # accumulate one scaled pseudo-diagonal component per selected index
        for i in non_zero_alpha_indices:
            mask1 = get_mask_pseudo_diagonal_torch((self.out_features, self.in_features), sparsity=0.99967, experimentType="randDiagOneLayer", diag_pos=i)
            mask1 = mask1.detach()
            if self.out_features > self.in_features:
                result = self.alpha_topk[i] * torch.matmul(mask1, torch.diag(self.V[i]).to(self.device))
            else:
                mask1 = mask1.T
                result = self.alpha_topk[i] * (torch.matmul(mask1, torch.diag(self.V[i])).T.to(self.device))

            WSum += result
        return WSum

    @property
    def weights(self):
        return self.compute_weights()

    def forward(self, x):
        x = x.to(self.device)
        W = self.weights
        #pdb.set_trace()    
        out = F.linear(x, W)
        return out

    def update_alpha_lr(self, new_alpha_lr):
        self.topkLR = new_alpha_lr
        #print("New learning rate for alpha is: ", self.topkLR) 

I have this custom linear layer with V and alpha as learnable parameters. When I use it to replace a single linear layer (in_features=758, out_features=2304) in a ViT-Base model, memory usage spikes. I am checking memory usage with:

print("Moving model to device")
    model.to(device=device)
    #pdb.set_trace()
    if args.channels_last:
        model.to(memory_format=torch.channels_last)
    
    print(f"Memory Allocated before summary: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
    print(f"Memory Reserved before summary: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")

For nn.Linear, the two numbers are 541 MB and 629 MB with batch size = 1; for my custom linear layer they are 625 MB and 6025 MB, also with batch size = 1. If I replace all the linear layers with the custom layer, I get an OOM error.

I am looking for suggestions to improve my code such that I can reduce memory usage.

FYI, mask1 and t are sparse tensors.

@ptrblck Do you have any input on this? I would appreciate any pointers.

I won’t be able to debug the issue as a few methods are undefined (e.g. sparse_soft_topk_mask_dykstra). However, I would start by looking into the for loop to check whether it increases the memory usage significantly in each iteration due to the intermediate forward activations that are stored.

Thank you for that. How do you recommend I check this? Just add:

print(f"Memory Allocated before summary: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
print(f"Memory Reserved before summary: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")

in the for loop?

And if the memory usage does increase with each loop iteration, is that expected behavior?

Yes, printing the memory usage inside the loop could work, or you could also check it before and after the loop.
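For example, something along these lines (a standalone sketch that only mimics the accumulation pattern in compute_weights; the shapes are placeholders and a dense random mask stands in for your mask helper), with the prints placed inside the loop:

import torch

device = "cuda"
out_features, in_features = 2304, 768
# V plays the role of the learnable diagonal values, so gradients are tracked
V = torch.nn.Parameter(torch.randn(16, min(in_features, out_features), device=device))

WSum = torch.zeros(out_features, in_features, device=device)
for i in range(V.shape[0]):
    # placeholder for get_mask_pseudo_diagonal_torch (dense random mask here)
    mask = (torch.rand(out_features, in_features, device=device) < 0.001).float()
    WSum = WSum + torch.matmul(mask, torch.diag(V[i]))
    print(f"iter {i}: allocated {torch.cuda.memory_allocated() / 1024**2:.1f} MB, "
          f"reserved {torch.cuda.memory_reserved() / 1024**2:.1f} MB")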

Yes, depending on the use case and on whether a computation graph is created that stores intermediates, as seen in this simple example:

import torch

print(f"Memory Allocated before summary: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
print(f"Memory Reserved before summary: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
# Memory Allocated before summary: 0.0 MB
# Memory Reserved before summary: 0.0 MB

x = torch.randn(1024*16, 1024*16, device="cuda")
print(f"Memory Allocated before summary: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
print(f"Memory Reserved before summary: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
# Memory Allocated before summary: 1024.0 MB
# Memory Reserved before summary: 1024.0 MB

# no memory increase (besides the randn_like and output allocation)
for _ in range(10):
    #x = torch.matmul(x, x)
    x = x * torch.randn_like(x)
print(f"Memory Allocated before summary: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
print(f"Memory Reserved before summary: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
# Memory Allocated before summary: 1024.0 MB
# Memory Reserved before summary: 3072.0 MB
    
# memory increases as the computation graph needs to store intermediates for the backward
x.requires_grad_(True)
for _ in range(10):
    #x = torch.matmul(x, x)
    x = x * torch.randn_like(x)
print(f"Memory Allocated before summary: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
print(f"Memory Reserved before summary: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
# Memory Allocated before summary: 12288.0 MB
# Memory Reserved before summary: 13312.0 MB
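As a cross-check on the layer itself, you could compare the peak memory of such an accumulation loop with and without gradient tracking; if the growth disappears under torch.no_grad(), the stored intermediates from the loop are the main contributor. A sketch, reusing the same placeholder loop rather than the actual helpers:

import torch

device = "cuda"
out_features, in_features = 2304, 768
V = torch.nn.Parameter(torch.randn(16, min(in_features, out_features), device=device))

def accumulate():
    # mirrors the structure of compute_weights with a placeholder mask
    WSum = torch.zeros(out_features, in_features, device=device)
    for i in range(V.shape[0]):
        mask = (torch.rand(out_features, in_features, device=device) < 0.001).float()
        WSum = WSum + torch.matmul(mask, torch.diag(V[i]))
    return WSum

torch.cuda.reset_peak_memory_stats()
out = accumulate()  # graph and saved intermediates are kept for backward
print(f"peak with grad:    {torch.cuda.max_memory_allocated() / 1024**2:.1f} MB")

del out  # drops the graph and its saved tensors
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    out = accumulate()  # no graph; each iteration's intermediates are freed
print(f"peak without grad: {torch.cuda.max_memory_allocated() / 1024**2:.1f} MB")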