RuntimeError: Function MmBackward returned an invalid gradient at index 1 - got [128, 16] but expected shape compatible with [4, 16]

Hello there,
I’m trying to implement the DDQN algorithm in PyTorch (copy-pasted together from different sources). When I call loss.backward() I get the following error:

    RuntimeError: Function MmBackward returned an invalid gradient at index 1 - got [128, 16] but expected shape compatible with [4, 16]

I believe this is the relevant code:

        # sample a random batch of transitions and unzip it into per-field tuples
        transitions = self.memory.sample()
        batch_state, batch_next_state, batch_action, batch_reward = zip(*transitions)

        # concatenate the per-transition tensors into single batch tensors
        batch_state = torch.cat(batch_state)
        batch_next_state = torch.cat(batch_next_state)
        batch_action = torch.cat(batch_action).squeeze(1)
        batch_reward = torch.cat(batch_reward)

        # current Q values are the policy net's estimates for the actions that were taken
        current_q_values = self.policy_net(batch_state).gather(0, batch_action)

        # expected Q values use the target net's maximum Q value over actions for the next state
        max_next_q_values, _ = self.target_net(batch_next_state).detach().max(0)
        expected_q_values = batch_reward + (self.gamma * max_next_q_values)

        # Compute loss
        loss = self.criterion(current_q_values, expected_q_values)

        # Zero the gradients and backpropagate
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)

        # Optimizer step
        self.optimizer.step()
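
To narrow this down, my plan is to add shape prints right before the gather call, along these lines (the expected shapes in the comments are my assumptions about what the batching should produce, not verified output):

        print(batch_state.shape)                   # expecting [batch_size, state_dim]
        print(batch_action.shape)                  # expecting [batch_size] after the squeeze(1)
        print(batch_reward.shape)                  # expecting [batch_size]
        print(self.policy_net(batch_state).shape)  # expecting [batch_size, n_actions]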

Both policy_net and target_net are just two linear layers with a ReLU activation in between.
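In sketch form they look something like this (the layer sizes here are placeholders, not my actual values):

        import torch
        import torch.nn as nn

        class QNetwork(nn.Module):
            def __init__(self, state_dim, hidden_dim, n_actions):
                super().__init__()
                self.fc1 = nn.Linear(state_dim, hidden_dim)
                self.fc2 = nn.Linear(hidden_dim, n_actions)

            def forward(self, x):
                # two linear layers with a ReLU in between
                return self.fc2(torch.relu(self.fc1(x)))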
The error makes me think the computation graph is getting messed up somewhere, but I'm not sure how that can happen. I assume MmBackward refers to a torch.mm computation, but I never call torch.mm directly.
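From what I can tell, though, MmBackward is created by any 2-D matrix multiply, not only by explicit torch.mm calls; for example, this standalone snippet (unrelated to my code) produces one:

        import torch

        a = torch.randn(3, 5, requires_grad=True)
        b = torch.randn(5, 2, requires_grad=True)
        out = a @ b         # 2-D matmul, equivalent to torch.mm here
        print(out.grad_fn)  # <MmBackward0 ...> (printed as MmBackward on older versions)

So I guess the node comes from a matrix multiply somewhere inside my networks' forward pass.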
I am incredibly lost and any help would be much appreciated. Thanks in advance!