import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# sparse_soft_topk_mask_dykstra and get_mask_pseudo_diagonal_torch are defined
# elsewhere (not shown here).


class CustomFullyConnectedLayer(nn.Module):
    def __init__(self, in_features, out_features, device=None, sparsity=0.1, diagPos=[], alphaLR=0.01):
        super(CustomFullyConnectedLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.total_permutations = max(in_features, out_features)
        self.diag_length = min(in_features, out_features)

        # Number of pseudo-diagonals (K) needed to reach the requested density.
        num_params = in_features * out_features
        req_params = int((1 - sparsity) * num_params)
        K = math.ceil(req_params / min(in_features, out_features))
        self.K = K
        self.topkLR = alphaLR

        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # V holds the values of each pseudo-diagonal; alpha weights the diagonals.
        self.V = nn.Parameter(torch.empty(self.total_permutations, self.diag_length,
                                          device=self.device, dtype=torch.float32))
        nn.init.kaiming_uniform_(self.V, a=math.sqrt(5))

        self.alpha = nn.Parameter(torch.empty(self.total_permutations, device=self.device))
        nn.init.constant_(self.alpha, 1 / self.in_features)
        assert torch.all(self.alpha >= 0)

    def compute_weights(self):
        # Soft top-k selection over alpha: keep (approximately) K diagonals.
        self.alpha_topk = sparse_soft_topk_mask_dykstra(self.alpha, self.K, l=self.topkLR, num_iter=50).to(self.device)

        non_zero_alpha_indices = torch.nonzero(self.alpha_topk, as_tuple=False).squeeze()
        if non_zero_alpha_indices.dim() == 0:
            non_zero_alpha_indices = non_zero_alpha_indices.unsqueeze(0)

        # W = sum_i alpha_topk[i] * P_i @ diag(V[i]), accumulated as a dense matrix.
        WSum = torch.zeros((self.out_features, self.in_features), device=self.device)
        for i in non_zero_alpha_indices:
            mask1 = get_mask_pseudo_diagonal_torch((self.out_features, self.in_features),
                                                   sparsity=0.99967,
                                                   experimentType="randDiagOneLayer",
                                                   diag_pos=i)
            mask1 = mask1.detach()
            if self.out_features > self.in_features:
                result = self.alpha_topk[i] * torch.matmul(mask1, torch.diag(self.V[i]).to(self.device))
            else:
                mask1 = mask1.T
                result = self.alpha_topk[i] * (torch.matmul(mask1, torch.diag(self.V[i])).T.to(self.device))
            WSum += result
        return WSum

    @property
    def weights(self):
        # The dense weight matrix is rebuilt on every access, i.e. on every forward pass.
        return self.compute_weights()

    def forward(self, x):
        x = x.to(self.device)
        W = self.weights
        out = F.linear(x, W)
        return out

    def update_alpha_lr(self, new_alpha_lr):
        self.topkLR = new_alpha_lr
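For reference, the layer is meant to be used like a bias-free nn.Linear. A minimal standalone check looks like this (arbitrary sizes; it assumes the two helper functions referenced above are importable):

layer = CustomFullyConnectedLayer(in_features=128, out_features=512)
x = torch.randn(4, 128, device=layer.device)
y = layer(x)
print(y.shape)  # torch.Size([4, 512])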
I have this custom linear layer with V and alpha as learnable parameters. When I use it to replace a single linear layer (in_features=768, out_features=2304) in a ViT-Base model, GPU memory usage spikes (a sketch of how I do the replacement is at the end of this post). I am checking memory usage with:
print("Moving model to device")
model.to(device=device)
#pdb.set_trace()
if args.channels_last:
model.to(memory_format=torch.channels_last)
print(f"Memory Allocated before summary: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
print(f"Memory Reserved before summary: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
For nn.Linear, the two numbers are 541 MB and 629 MB with batch size 1; for my custom layer they are 625 MB and 6025 MB, also with batch size 1. If I replace all the linear layers with the custom layer, I get an OOM error.
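To see where the spike actually happens, the peak counters can also be reset and read around a single forward pass (a minimal sketch; the 1x3x224x224 input size is an assumption):

torch.cuda.reset_peak_memory_stats(device)
x = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    _ = model(x)
print(f"Peak allocated during forward: {torch.cuda.max_memory_allocated(device) / (1024 ** 2)} MB")
print(f"Peak reserved during forward: {torch.cuda.max_memory_reserved(device) / (1024 ** 2)} MB")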
I am looking for suggestions to improve my code so that I can reduce memory usage.
FYI, mask1 and t (not shown above) are sparse tensors.
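In case it helps, this is roughly how the single-layer replacement is done (a minimal sketch; the timm model name and the blocks[0].attn.qkv path are assumptions, not necessarily my exact setup):

import timm

model = timm.create_model("vit_base_patch16_224", pretrained=False)
old = model.blocks[0].attn.qkv  # the qkv projection, an nn.Linear
model.blocks[0].attn.qkv = CustomFullyConnectedLayer(
    in_features=old.in_features,
    out_features=old.out_features,
)  # note: the custom layer has no bias term
model.to(device)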