I have this custom layer:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# sparse_soft_topk_mask_dykstra and get_mask_pseudo_diagonal_torch
# are imported from elsewhere in the project.

class CustomFullyConnectedLayer(nn.Module):
    def __init__(self, in_features, out_features, device=None, sparsity=0.1, diagPos=[], alphaLR=0.01):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.total_permutations = max(in_features, out_features)
        self.diag_length = min(in_features, out_features)
        print("Sparsity is: ", sparsity)

        # K is the number of diagonals needed so that K * diag_length covers
        # the requested number of non-zero parameters.
        num_params = in_features * out_features
        req_params = int((1 - sparsity) * num_params)
        self.K = math.ceil(req_params / self.diag_length)

        self.topkLR = alphaLR
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # One learnable diagonal per permutation; alpha weights the diagonals.
        # (requires_grad is implied by nn.Parameter.)
        self.V = nn.Parameter(torch.empty(self.total_permutations, self.diag_length, device=self.device, dtype=torch.float32))
        nn.init.kaiming_uniform_(self.V, a=math.sqrt(5))
        self.alpha = nn.Parameter(torch.empty(self.total_permutations, device=self.device))
        nn.init.constant_(self.alpha, 1 / self.in_features)
        assert torch.all(self.alpha >= 0)
    def compute_weights(self):
        # Soft top-K mask over alpha: approximately K entries stay non-zero.
        self.alpha_topk = sparse_soft_topk_mask_dykstra(self.alpha, self.K, l=self.topkLR, num_iter=50).to(self.device)
        non_zero_alpha_indices = torch.nonzero(self.alpha_topk, as_tuple=False).squeeze()
        if non_zero_alpha_indices.dim() == 0:
            non_zero_alpha_indices = non_zero_alpha_indices.unsqueeze(0)

        WSum = torch.zeros((self.out_features, self.in_features), device=self.device)
        for i in non_zero_alpha_indices:
            # Sparse COO mask selecting the i-th pseudo-diagonal.
            mask1 = get_mask_pseudo_diagonal_torch((self.out_features, self.in_features), sparsity=0.99967, experimentType="randDiagOneLayer", diag_pos=i)
            V_scaled = self.V[i] * self.alpha_topk[i]

            # Build a sparse COO matrix equivalent to torch.diag(V_scaled).
            n = V_scaled.size(0)
            indices = torch.arange(0, n, device=self.device).unsqueeze(0).repeat(2, 1)  # (i, i) index pairs
            V_scaled_sparse = torch.sparse_coo_tensor(indices, V_scaled, (n, n), device=self.device)

            with torch.cuda.amp.autocast(enabled=False):
                if self.out_features > self.in_features:
                    WSum += torch.sparse.mm(mask1, V_scaled_sparse.float())
                else:
                    WSum += torch.sparse.mm(mask1.T, V_scaled_sparse.float()).T
        return WSum
    @property
    def weights(self):
        return self.compute_weights()

    def forward(self, x):
        x = x.to(self.device)
        W = self.weights  # the weight matrix is recomputed on every forward pass
        return F.linear(x, W)

    def update_alpha_lr(self, new_alpha_lr):
        self.topkLR = new_alpha_lr
My mask is generated by the following function:
def get_mask_random_torch(mask_shape, sparsity, device='cuda'):
    num_elements = mask_shape[0] * mask_shape[1]
    # sparsity is the fraction of zeros (same convention as the layer above),
    # so keep (1 - sparsity) of the entries.
    num_ones = int((1 - sparsity) * num_elements)
    # Sample random flat indices and convert them to 2D coordinates.
    indices = torch.randperm(num_elements, device=device)[:num_ones]
    rows = indices // mask_shape[1]
    cols = indices % mask_shape[1]
    # Assemble the sparse COO mask from the sampled coordinates.
    sparse_indices = torch.stack([rows, cols], dim=0)
    values = torch.ones(num_ones, device=device)
    sparse_mask = torch.sparse_coo_tensor(sparse_indices, values, size=mask_shape, device=device)
    return sparse_mask
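In isolation, the forward sparse-sparse product seems to go through without complaint; for example (a small CPU sketch with made-up sizes):

mask = get_mask_random_torch((64, 32), sparsity=0.99, device='cpu')
v = torch.rand(32)
idx = torch.arange(32).unsqueeze(0).repeat(2, 1)
diag = torch.sparse_coo_tensor(idx, v, (32, 32))  # sparse COO diagonal
prod = torch.sparse.mm(mask, diag)                # sparse @ sparse: fine in forward
print(prod.to_dense().shape)                      # torch.Size([64, 32])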
For the time being, let us ignore the implementation of sparse_soft_topk_mask_dykstra; for reproduction purposes it can be stood in by a hard top-K mask like the sketch below.
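A hypothetical stand-in (not the real implementation, which is presumably a soft, differentiable relaxation):

def sparse_soft_topk_mask_dykstra(alpha, K, l=0.01, num_iter=50):
    # Stand-in for reproduction only: keeps the K largest entries of alpha
    # and zeroes the rest. l and num_iter are accepted for signature
    # compatibility but unused here.
    out = torch.zeros_like(alpha)
    top = torch.topk(alpha, K).indices
    out[top] = alpha[top]
    return out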
When I use this layer in a network, I am getting the following error:
File "/gpfs/fs2/scratch/atyagi2/pytorch-image-models/train.py", line 1067, in _backward
loss_scaler(
File "/gpfs/fs2/scratch/atyagi2/pytorch-image-models/timm/utils/cuda.py", line 62, in __call__
self._scaler.scale(loss).backward(create_graph=create_graph)
File "/scratch/atyagi2/vitCifar/lib/python3.9/site-packages/torch/_tensor.py", line 521, in backward
torch.autograd.backward(
File "/scratch/atyagi2/vitCifar/lib/python3.9/site-packages/torch/autograd/__init__.py", line 289, in backward
_engine_run_backward(
File "/scratch/atyagi2/vitCifar/lib/python3.9/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: mat2_.is_sparse() INTERNAL ASSERT FAILED at "../aten/src/ATen/native/sparse/cuda/SparseMatMul.cu":785, please report a bug to PyTorch.
If I understand correctly, torch.sparse.mm supports backpropagation when both matrices are in COO format, so I am unsure why I am getting this error.
My PyTorch version is 2.4.0+cu121, and which nvcc returns /software/cuda/12.1/bin/nvcc.
Am I right to assume that the error is caused by the torch.sparse.mm call? Any input on how this can be avoided?
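For reference, here is a minimal, self-contained sketch of what I believe is the failing pattern: a sparse COO x sparse COO product whose second operand carries gradients, accumulated into a dense tensor and then backpropagated (sizes are arbitrary):

import torch

device = "cuda"
mask = (torch.rand(8, 8, device=device) < 0.2).float().to_sparse()  # fixed sparse COO mask
v = torch.rand(8, device=device, requires_grad=True)                # dense leaf, like self.V[i]
idx = torch.arange(8, device=device).unsqueeze(0).repeat(2, 1)
diag = torch.sparse_coo_tensor(idx, v, (8, 8), device=device)       # sparse diagonal built from v

W = torch.zeros(8, 8, device=device) + torch.sparse.mm(mask, diag)  # densified by the add
W.sum().backward()  # backward through the sparse @ sparse product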