torchmetrics: Multi label Confusion matrix different device error

I’m trying to calculate the confusion matrix using torchmetrics for my multi-label output, but I get the following error:

File "/home/antpc/.local/lib/python3.8/site-packages/torchmetrics/metric.py", line 394, in wrapped_func
    raise RuntimeError(
RuntimeError: Encountered different devices in metric calculation (see stacktrace for details).This could be due to the metric class not being on the same device as input.Instead of `metric=ConfusionMatrix(...)` try to do `metric=ConfusionMatrix(...).to(device)` where device corresponds to the device of the input.

My code:

from torchmetrics import ConfusionMatrix
def calculate_metrics(predictions, targets):
	cm = ConfusionMatrix(num_classes=34, multilabel=True)
	matrix = cm(predictions, targets)
	return matrix

Then I tried to change my code as:

from torchmetrics import ConfusionMatrix
def calculate_metrics(predictions, targets):
	cm = ConfusionMatrix(num_classes=34, multilabel=True).to(device='cpu')
	matrix = cm(predictions.detach().cpu(), targets.detach().cpu())
	return matrix

Still it shows the same error. Can anyone help me out with this?

Could you post a minimal, executable code snippet by adding the missing definitions, which would reproduce the issue, please?

The code is written in pytorch lightning.

from torch import optim, nn
import pytorch_lightning as pl
from torchmetrics import ConfusionMatrix

class ModelClassifier(pl.LightningModule):
    def __init__(self):
        super(ModelClassifier, self).__init__()
        self.model = nn.Linear(3*512*512 ,out_features=34)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.cm = ConfusionMatrix(num_classes=34, multilabel=True).to(device='cpu')

    def forward(self, x):
        batch_size, _, _, _ = x.size()
        x = x.view(batch_size,-1)
        x = self.model(x)
        return x
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.01)
        return optimizer
    
    def loss_func(self, pred, labels):
        return self.loss_fn(pred,labels)
    
    def calculate_metrics(self, pred, labels):
        confusion_matrix = self.cm(pred.detach().cpu(), labels.detach().cpu())
        return confusion_matrix

    def training_step(self, batch, batch_idx):
        inputs, labels = batch

        outputs = self(inputs)
        loss = self.loss_fn(outputs, labels.float())
        metrics = self.calculate_metrics(outputs, labels)
        
        return loss

model = ModelClassifier()

trainer = pl.Trainer(strategy='dp', max_epochs=150, gpus=8, fast_dev_run=True)

trainer.fit(model, train_loader)

Here the train_loader contains images of size: (3 X 512 X 512) with a batch size of 32.