torch.sparse.mm backward pass

My model runs a forward pass and then computes the gradients of the model output with respect to the model inputs, which I use to calculate an additional term in my loss function. Calculating this term involves a sparse matrix multiplication, which seems to cause an error when I call loss.backward(): RuntimeError: Expected object of backend CPU but got backend SparseCPU for argument #2 'mat2'. My model looks as follows:

    import torch
    from torch.autograd import grad

    def forward(self, input_data, fprimes):
        batch_size = self.batch_size
        model_pred = torch.zeros(batch_size, 1, device=self.device)
        for index, value in enumerate(self.values):
            model_inputs = input_data[value][0]
            # index_add_ needs an int64 index on the same device as model_pred
            contribution_index = torch.tensor(input_data[value][1], dtype=torch.long, device=self.device)
            # call the submodule directly instead of .forward() so that hooks run
            outputs = self.valuespecific_models[index](model_inputs)
            model_pred.index_add_(0, contribution_index, outputs)
            gradients = grad(
                model_pred,
                model_inputs,
                grad_outputs=torch.ones_like(model_pred),
                retain_graph=True,
                #create_graph=True
            )[0]
        # code that organizes the gradients and returns a dense matrix, dO_dI
        output_2 = torch.sparse.mm(fprimes.t(), dO_dI.t())

        return model_pred, output_2
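A side note on the loop: as written, gradients is overwritten on every iteration, so only the last group's gradient survives the loop, and each grad call traverses the whole graph accumulated so far. torch.autograd.grad also accepts a list of inputs, so a single call after the loop can return one gradient per group. Below is a minimal sketch with hypothetical stand-ins for self.valuespecific_models and input_data (group_models, group_inputs, group_index are made-up names). It uses create_graph=True on the assumption that the result feeds the loss. Because each toy group model is linear, every returned gradient row should equal that model's weight row, which makes the result easy to check.

```python
import torch
import torch.nn as nn
from torch.autograd import grad

torch.manual_seed(0)
batch_size = 3

# Hypothetical stand-ins for self.valuespecific_models and input_data:
# two group-specific models, each taking 2 input features.
group_models = nn.ModuleList([nn.Linear(2, 1), nn.Linear(2, 1)])
group_inputs = [
    torch.randn(4, 2, requires_grad=True),  # 4 input rows for group 0
    torch.randn(2, 2, requires_grad=True),  # 2 input rows for group 1
]
# Batch row that each input row contributes to (int64 indices).
group_index = [torch.tensor([0, 0, 1, 2]), torch.tensor([1, 2])]

model_pred = torch.zeros(batch_size, 1)
for model, x, idx in zip(group_models, group_inputs, group_index):
    # Sum each row's output into its batch entry.
    model_pred.index_add_(0, idx, model(x))

# One call returns one gradient per tensor in group_inputs,
# each with the same shape as that input.
grads = grad(
    model_pred,
    group_inputs,
    grad_outputs=torch.ones_like(model_pred),
    create_graph=True,
)
```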

fprimes, the second argument to forward, is a sparse matrix fed to the model.
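For comparison, here is a small check of what torch.sparse.mm should compute in this setup: a sparse COO matrix (a hypothetical 3 x 2 stand-in for fprimes) is transposed and multiplied with a dense matrix that requires grad (a stand-in for dO_dI.t()), and the result is compared against the same product using fprimes.to_dense(). Transposing once and coalescing up front is an assumed way of preparing fprimes, not something from the original code. On a recent PyTorch version, both the product and the gradient with respect to the dense argument should match the dense path.

```python
import torch

torch.manual_seed(0)

# Hypothetical 3 x 2 sparse stand-in for fprimes, in COO format.
fprimes = torch.sparse_coo_tensor(
    torch.tensor([[0, 2, 1], [0, 0, 1]]),
    torch.tensor([1.0, -0.5, 2.0]),
    size=(3, 2),
)
# Transpose once and coalesce up front (an assumed preparation step).
fprimes_t = fprimes.t().coalesce()  # sparse, 2 x 3

# Hypothetical dense operand standing in for dO_dI.t(), shape 3 x 4.
rhs = torch.randn(3, 4, requires_grad=True)

# Sparse path: sparse (2 x 3) @ dense (3 x 4) -> dense (2 x 4).
sparse_out = torch.sparse.mm(fprimes_t, rhs)
sparse_out.sum().backward()
grad_sparse = rhs.grad.clone()

# Dense path for comparison: the same product with fprimes densified.
rhs.grad = None
dense_out = fprimes.to_dense().t() @ rhs
dense_out.sum().backward()
grad_dense = rhs.grad.clone()
```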

When I make fprimes a dense matrix instead, training seems to work fine. Is there a way to make this work with sparse multiplication? Also, when computing the gradients, I’m unsure whether I need retain_graph=True or create_graph=True. I need one of them; otherwise I get an error saying I can’t backward through the graph a second time.
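On the retain_graph versus create_graph question, here is a minimal sketch with a hypothetical toy model (an nn.Sequential), random inputs and a made-up sparse fprimes, so none of it is the original model. It contrasts the two flags. With only retain_graph=True the graph stays alive, but the returned gradient does not require grad, so a loss term built from it cannot send anything back to the parameters. With create_graph=True the graph of the derivative itself is built, and retain_graph defaults to the value of create_graph, so setting both is unnecessary. The sparse product is then backpropagated through as part of the loss. This assumes a recent PyTorch version and is not a diagnosis of the backend error above.

```python
import torch
import torch.nn as nn
from torch.autograd import grad

torch.manual_seed(0)

# Hypothetical toy model and data: 4 samples, 3 input features, scalar output.
model = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 1))
inputs = torch.randn(4, 3, requires_grad=True)
# Hypothetical 3 x 2 sparse stand-in for fprimes.
fprimes = torch.sparse_coo_tensor(
    torch.tensor([[0, 2, 1], [0, 0, 1]]),
    torch.tensor([1.0, -0.5, 2.0]),
    size=(3, 2),
)

pred = model(inputs)

# retain_graph=True alone: the graph survives, but the gradient is a constant.
dO_dI_const = grad(pred, inputs, grad_outputs=torch.ones_like(pred), retain_graph=True)[0]

# create_graph=True: the gradient is differentiable (retain_graph defaults to True here).
dO_dI = grad(pred, inputs, grad_outputs=torch.ones_like(pred), create_graph=True)[0]

# Sparse (2 x 3) @ dense (3 x 4) -> dense (2 x 4), mirroring output_2 in forward.
output_2 = torch.sparse.mm(fprimes.t(), dO_dI.t())

# Loss built only from the gradient term, so any parameter gradient
# has to flow back through dO_dI.
loss = output_2.pow(2).mean()
loss.backward()
```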

Thanks!