Need quick help with an optimizer.step() error

Hi! I am quite new to this, so please tell me if there is anything more I should share. I am running into an error with optimizer.step() in my code. I have implemented my neural network by creating my own sparse linear layers:

class SparseLinear(nn.Module):
    def __init__(self, input_size, output_size, sparse_indices, bias_indices, w_bool=True, b_bool=True):
        super(SparseLinear, self).__init__()
        # Sparse weight matrix; He-style init scaled by sqrt(2 / input_size).
        weight_values = torch.randn(sparse_indices.shape[1], device=device) * torch.sqrt(torch.tensor(2.0 / input_size))
        self.weight = nn.Parameter(torch.sparse.FloatTensor(sparse_indices, weight_values, [output_size, input_size]), requires_grad=w_bool)
        if b_bool:
            # Sparse bias matrix; H is a global matrix defined earlier in the notebook.
            bias_values = torch.randn(bias_indices.shape[1], device=device)
            self.bias = nn.Parameter(torch.sparse.FloatTensor(bias_indices, bias_values, [output_size, H.shape[1]]), requires_grad=b_bool)
        else:
            # Dense all-zero bias, frozen (requires_grad=False).
            self.bias = nn.Parameter(torch.zeros([output_size, H.shape[1]], device=device), requires_grad=b_bool)

    def forward(self, x, y):
        var = torch.sparse.mm(self.bias, y.t())
        return torch.sparse.addmm(var, self.weight, x.t()).t()
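
To show how the layer is used, here is a toy call. Everything below (device, H, the index tensors, the sizes) is made up for illustration; my real ones are built earlier in the notebook:

# Toy usage sketch -- illustrative values only, not my real globals:
device = torch.device("cpu")                        # stand-in for my global device
H = torch.ones(3, 6)                                # stand-in for the global H
sparse_idx = torch.tensor([[0, 1, 2], [0, 2, 4]])   # 3 nonzero weight positions
bias_idx = torch.tensor([[0, 1], [1, 3]])           # 2 nonzero bias positions
layer = SparseLinear(6, 3, sparse_idx, bias_idx)
x = torch.randn(4, 6)                               # batch of 4 inputs
out = layer(x, x)                                   # forward takes (x, y); here y = x
print(out.shape)                                    # torch.Size([4, 3])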

class Neural_network(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Neural_network, self).__init__()
        # The sparse_indices_* / bias_indices_* index tensors are globals built earlier.
        self.l1 = SparseLinear(input_size, hidden_size, sparse_indices_in, bias_indices_in, b_bool=False)
        self.l2 = SparseLinear(hidden_size, hidden_size, sparse_indices_mid, bias_indices_mid)
        self.l3 = SparseLinear(hidden_size, hidden_size, sparse_indices_mid, bias_indices_mid)
        self.l4 = SparseLinear(hidden_size, hidden_size, sparse_indices_mid, bias_indices_mid)
        self.l5 = SparseLinear(hidden_size, hidden_size, sparse_indices_mid, bias_indices_mid)
        self.l6 = SparseLinear(hidden_size, hidden_size, sparse_indices_mid, bias_indices_mid)
        self.l7 = SparseLinear(hidden_size, hidden_size, sparse_indices_mid, bias_indices_mid)
        self.l8 = SparseLinear(hidden_size, hidden_size, sparse_indices_mid, bias_indices_mid)
        self.l9 = SparseLinear(hidden_size, output_size, sparse_indices_out, bias_indices_out)

    def forward(self, x):
        # Each stage: sparse layer -> tanh -> check-node operation.
        out = self.l1(x, x)
        out = hyperbolic_tan(out)
        out = check_node_op(out)  # output of check nodes

        out = self.l2(out, x)
        out = hyperbolic_tan(out)
        out = check_node_op(out)

        out = self.l3(out, x)
        out = hyperbolic_tan(out)
        out = check_node_op(out)

        out = self.l4(out, x)
        out = hyperbolic_tan(out)
        out = check_node_op(out)

        out = self.l5(out, x)
        out = hyperbolic_tan(out)
        out = check_node_op(out)

        # Layers l6-l8 are currently disabled (5-iteration version).
        # out = self.l6(out, x); out = hyperbolic_tan(out); out = check_node_op(out)
        # out = self.l7(out, x); out = hyperbolic_tan(out); out = check_node_op(out)
        # out = self.l8(out, x); out = hyperbolic_tan(out); out = check_node_op(out)

        out = self.l9(out, x)
        return out

model = Neural_network(input_size,hidden_size,output_size)

criterion = nn.BCEWithLogitsLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  # SGD without momentum -- runs fine
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.8)  # SGD with momentum -- triggers the error

PATH = "Models_MTP/NotFC_Decoder_BCH_63_45_5iter_5db.pth"
def save_model(PATH):
    torch.save(model.state_dict(), PATH)

n_total_steps = len(train_dataloader)
best_epoch = 0
best_ber_sim = 0.2
loss_saved = 1
mean_losses = []
k = columns - rows

model = model.to(device)          # moved out of the loop; no need to do this every batch
criterion = criterion.to(device)

for epoch in range(num_epochs):
    losses = []
    for i, (feature, labels) in enumerate(train_dataloader):
        feature = feature.to(device)
        labels = labels.to(device)

        # Forward pass
        output = model(feature)
        loss = criterion(output, labels)
        losses.append(loss.item())

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        # if loss.item() < 0:
        #     print(output)
        #     print(labels)
        #     raise Exception('break here')
        optimizer.step()  # <-- this is the line that raises the RuntimeError

        if (i + 1) % 4 == 0:
            print(f'epoch {epoch+1}/{num_epochs} steps {i+1}/{n_total_steps} loss {loss.item():.4f}')

This is the error that I am getting:

RuntimeError                              Traceback (most recent call last)
Cell In [29], line 30
     25 loss.backward()
     26 # if (loss.item() < 0):
     27 #     print(output)
     28 #     print(labels)
     29 #     raise Exception('break here')
---> 30 optimizer.step()
     33 if (i+1)%(4) == 0:
     34     print(f'epoch {epoch+1}/{num_epochs} steps {i+1}/{n_total_steps} loss {loss.item():.4f}')

File ~/.local/lib/python3.10/site-packages/torch/optim/optimizer.py:113, in Optimizer._hook_for_profile.<locals>.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    111 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
    112 with torch.autograd.profiler.record_function(profile_name):
--> 113     return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/optim/sgd.py:146, in SGD.step(self, closure)
    143         else:
    144             momentum_buffer_list.append(state['momentum_buffer'])
--> 146 sgd(params_with_grad,
    147     d_p_list,
    148     momentum_buffer_list,
    149     weight_decay=group['weight_decay'],
    150     momentum=group['momentum'],
    151     lr=group['lr'],
    152     dampening=group['dampening'],
    153     nesterov=group['nesterov'],
    154     maximize=group['maximize'],
    155     has_sparse_grad=has_sparse_grad,
    156     foreach=group['foreach'])
    158 # update momentum_buffers in state
    159 for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):

File ~/.local/lib/python3.10/site-packages/torch/optim/sgd.py:197, in sgd(params, d_p_list, momentum_buffer_list, has_sparse_grad, foreach, weight_decay, momentum, lr, dampening, nesterov, maximize)
    194 else:
    195     func = _single_tensor_sgd
--> 197 func(params,
    198      d_p_list,
    199      momentum_buffer_list,
    200      weight_decay=weight_decay,
    201      momentum=momentum,
    202      lr=lr,
    203      dampening=dampening,
    204      nesterov=nesterov,
    205      has_sparse_grad=has_sparse_grad,
    206      maximize=maximize)

File ~/.local/lib/python3.10/site-packages/torch/optim/sgd.py:233, in _single_tensor_sgd(params, d_p_list, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, has_sparse_grad)
    231     momentum_buffer_list[i] = buf
    232 else:
--> 233     buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
    235 if nesterov:
    236     d_p = d_p.add(buf, alpha=momentum)

RuntimeError: set_indices_and_values_unsafe is not allowed on a Tensor created from .data or .detach().
If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
For example, change:
    x.data.set_(y)
to:
    with torch.no_grad():
        x.set_(y)

This error only appears when the optimizer is SGD with momentum; the plain SGD line (commented out above) runs without any issue. Can somebody suggest how I should go about debugging this?
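
In case it helps, here is a minimal sketch of what I suspect is going on: a sparse nn.Parameter combined with SGD's momentum buffer, independent of the rest of my model. The indices and shapes below are made up:

import torch
import torch.nn as nn

# Minimal repro sketch (my assumption): one 2x2 sparse parameter, two SGD steps.
idx = torch.tensor([[0, 1], [0, 1]])
w = nn.Parameter(torch.sparse_coo_tensor(idx, torch.randn(2), (2, 2)))
opt = torch.optim.SGD([w], lr=0.1, momentum=0.8)

for step in range(2):
    loss = torch.sparse.mm(w, torch.randn(2, 1)).sum()   # grad of w comes back sparse
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()   # step 0 creates the momentum buffer; step 1 should hit buf.mul_().add_()

If my reading of the traceback is right, the second opt.step() should fail on the same buf.mul_(momentum).add_(d_p, ...) line, while setting momentum=0 lets the loop finish, which matches what I see with my full model.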