'out=... arguments don't support automatic differentiation' when using num_workers > 0

Hello everyone, I’ve run into the following problem. When calling

trainloader = DataLoader(train_subset, batch_size=batch_size, shuffle = True, num_workers=8, worker_init_fn=np.random.seed(0), pin_memory=True)

the following error is raised:

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/iferfoglia/mambaforge/envs/conceptSTL/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/iferfoglia/mambaforge/envs/conceptSTL/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/iferfoglia/mambaforge/envs/conceptSTL/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 175, in default_collate
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/iferfoglia/mambaforge/envs/conceptSTL/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 175, in <listcomp>
    return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
  File "/home/iferfoglia/mambaforge/envs/conceptSTL/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 141, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

If I don’t set num_workers (so it defaults to 0), everything works.
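For reference, the error can be reproduced without a DataLoader. default_collate only passes a preallocated shared-memory buffer as out= when it runs inside a worker process (the torch.stack(batch, 0, out=out) line in the traceback); in the main process, with num_workers=0, out is None. That would explain why the problem only appears with workers. A minimal sketch, assuming the samples returned by the dataset already require grad:

```python
import torch

# Two "samples" as a Dataset might return them; the first one requires grad
# (assumption: this is what the custom dataset yields).
samples = [torch.randn(3, requires_grad=True), torch.randn(3)]

# Inside a worker, default_collate stacks into a preallocated buffer via out=.
buffer = torch.empty(len(samples), 3)
try:
    torch.stack(samples, 0, out=buffer)
    worker_style_ok = True
except RuntimeError as err:
    worker_style_ok = False
    print("worker-style collate fails:", err)

# In the main process (num_workers=0) out= is not used, so stacking succeeds
# and the stacked batch simply requires grad.
batch = torch.stack(samples, 0)
print("main-process collate ok, requires_grad =", batch.requires_grad)

# Detaching the samples first makes the out= path work as well.
torch.stack([s.detach() for s in samples], 0, out=buffer)
```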

The training loop:

for epoch in range(epochs):
    for batch, labels in trainloader:
        batch, labels = batch.to(device), labels.to(device)
        y_preds = model(batch)
        loss = criterion(y_preds, labels.float())
        # instead of optimizer.zero_grad()
        for param in model.parameters():
            param.grad = None
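Unrelated to the error, but worth checking in the DataLoader call: worker_init_fn=np.random.seed(0) calls np.random.seed(0) once, while the arguments are evaluated, and passes its return value (None) as worker_init_fn, so the workers are never seeded by it. worker_init_fn expects a callable that receives the worker id. A sketch following the seeding pattern from the PyTorch reproducibility notes; seed_worker and the toy TensorDataset are hypothetical stand-ins:

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# np.random.seed(0) is evaluated immediately and returns None,
# so passing it as worker_init_fn is the same as passing nothing.
print(np.random.seed(0))  # None


def seed_worker(worker_id):
    # Hypothetical helper: each worker derives its NumPy/random seed from the
    # per-worker base seed that PyTorch assigns to that worker process.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


g = torch.Generator()
g.manual_seed(0)  # controls shuffling and the per-worker base seeds

# Toy dataset standing in for train_subset.
dataset = TensorDataset(torch.randn(16, 4), torch.randint(0, 2, (16,)))
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2,
                    worker_init_fn=seed_worker, generator=g)
```

Because seed_worker is a module-level function, it can also be pickled if the workers are started with the spawn method.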

and the model:

class DCR_classifier(nn.Module):
    def __init__(self, input_shape):
        super().__init__()  # required before assigning submodules such as nn.Linear
        # Define your MLP layers for classification
        self.fc1 = nn.Linear(input_shape, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 1)
    def forward(self, x, traj_emb):
        x_new = x.view(x.size(0), -1) # (batch, nvar, points) -> (batch, nvar x points)
        combined_features = torch.cat((traj_emb, x_new), dim=1) # (batch, kemb_size + nvar x points)
        output = self.fc1(combined_features)
        output = F.relu(output)
        output = self.fc2(output)
        output = F.relu(output)
        output = self.fc3(output)
        return output

class DCR(nn.Module):
    def __init__(self, nvar, points, device):
        super().__init__()  # required before assigning self.classifier
        self.concepts = load_phis_dataset('concepts', nvar)
        self.kemb = get_kernel_embedding(self.concepts, nvar).to(device) # (concepts, kemb_size)
        _ = self.kemb.requires_grad_()
        self.classifier = DCR_classifier(self.kemb.shape[1] + (nvar*points))
        self.nvar = nvar
    def forward(self, x):
        # concept truth degrees
        rhos = get_robustness(x, self.concepts, time = False) # (trajectories, concepts)
        _ = rhos.requires_grad_()
        # embed trajectories in kernel space
        traj_emb = torch.matmul(rhos, self.kemb) # (trajectories, kemb_size)
        _ = traj_emb.requires_grad_()
        output = self.classifier(x, traj_emb)
        return output.squeeze(1)

Because I’m working with custom objects, I need to call requires_grad_(); otherwise it doesn’t work.
I can’t disclose all the called functions, but they shouldn’t be the problem.
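The requires_grad_() calls inside DCR.forward run in the main process, so the traceback suggests that the items returned by train_subset themselves already require grad (presumably because of how the dataset tensors are built; this is an assumption, since that code is not shown). One possible workaround under that assumption is a collate_fn that detaches tensors before the default collation. detach_collate and _detach are hypothetical names:

```python
import torch
from torch.utils.data.dataloader import default_collate


def _detach(obj):
    # Recursively detach tensors inside tuples/lists/dicts; leave other objects alone.
    if isinstance(obj, torch.Tensor):
        return obj.detach()
    if isinstance(obj, tuple):
        return tuple(_detach(o) for o in obj)
    if isinstance(obj, list):
        return [_detach(o) for o in obj]
    if isinstance(obj, dict):
        return {k: _detach(v) for k, v in obj.items()}
    return obj


def detach_collate(batch):
    # Hypothetical collate_fn: drop autograd history from every sample so that
    # default_collate can stack into its shared-memory out= buffer in workers.
    return default_collate([_detach(sample) for sample in batch])


# Samples shaped like (trajectory, label) pairs, with trajectories requiring grad.
samples = [(torch.randn(2, 5, requires_grad=True), torch.tensor(1)) for _ in range(4)]
batch, labels = detach_collate(samples)
print(batch.shape, batch.requires_grad)  # torch.Size([4, 2, 5]) False
```

It would be passed to the DataLoader as collate_fn=detach_collate; gradients on the inputs can then be re-enabled on the batch in the training loop.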

Thank you!

Could you describe this use case in more detail?
Calling .requires_grad_() on inputs would be needed if you want to train the inputs, e.g., in an adversarial training scenario. I don’t understand the “custom object” part and why you would need to backpropagate into it.

I call it because I later need to inspect the gradients with respect to the inputs to explain the model’s output. I am working with STL formulae, which are my “custom objects” and are not differentiable.
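If the goal is to inspect gradients with respect to the inputs, a common pattern is to keep the data pipeline free of autograd and enable gradients on the batch only after it leaves the DataLoader, inside the training or evaluation loop. A minimal sketch; the small nn.Sequential model is a stand-in assumption, since the real DCR model and get_robustness are not shown:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in classifier: (batch, nvar, points) -> one logit per sample.
nvar, points = 2, 5
model = nn.Sequential(nn.Flatten(), nn.Linear(nvar * points, 16), nn.ReLU(), nn.Linear(16, 1))
criterion = nn.BCEWithLogitsLoss()

# What the DataLoader would yield: plain tensors without autograd history.
batch = torch.randn(4, nvar, points)
labels = torch.randint(0, 2, (4,))

# Make the batch a leaf that tracks gradients *after* loading
# (and after .to(device), so that the leaf lives on the right device).
batch = batch.requires_grad_()
y_preds = model(batch).squeeze(1)
loss = criterion(y_preds, labels.float())
loss.backward()

# Input saliency: magnitude of the loss gradient for each input value.
saliency = batch.grad.abs()
print(saliency.shape)  # torch.Size([4, 2, 5])
```

One caveat for the DCR model, assuming get_robustness returns a tensor without autograd history: rhos.requires_grad_() then makes rhos a new leaf, so batch.grad would only reflect the direct x path through the classifier, not the path through the concepts.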

Thanks for the clarification. Could you post a minimal and executable code snippet reproducing the issue?