Unexpected behavior when using batch_jacobian with multiple inputs/outputs in quantum-classical neural network

I’m implementing a neural network that includes quantum layers (using PennyLane’s PyTorch interface) to solve ODEs. I want to encode several points at once and get several results of the ODE (one result per corresponding input) in a single run.
Currently I’m running a toy model where I try to learn u(x)=sin(x) using the loss term du_dx - cos(x).
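For reference, here is a minimal sketch of that toy loss for a network that treats every collocation point independently. The small MLP is a hypothetical stand-in for the hybrid model, and taking the mean of the squared residual is an assumption (the post only states the residual du_dx - cos(x)):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in model: maps each point x_i (shape (N, 1)) to u(x_i).
model = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))

# Collocation points as a column so that the batch dimension is explicit.
x = torch.linspace(0.0, 2 * torch.pi, 20).unsqueeze(-1).requires_grad_(True)

u = model(x)
# Each u_i depends only on x_i here, so the gradient of sum(u) w.r.t. x
# is exactly the per-point derivative du_i/dx_i.
(du_dx,) = torch.autograd.grad(u.sum(), x, create_graph=True)

residual = du_dx - torch.cos(x)
loss = residual.pow(2).mean()  # assumed reduction: mean squared residual
loss.backward()
```

Because the stand-in model never mixes different points, du_dx has the same shape as x and each entry is a genuine derivative along x.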

The network structure is:

import torch
import pennylane as qml
from torch import nn

class QuantumNetworkBaseAngleClass(nn.Module):
    def __init__(self, n_qubits=8, n_layers=2, x_scale_factor=1.0, x_scale_bias=0.0):
        super().__init__()
        self.n_qubits = n_qubits 
        self.n_layers = n_layers
        self.weight_shapes = None
        self.x_scale_factor = x_scale_factor
        self.x_scale_bias = x_scale_bias
        self.q_layer, self.qnode = self.create_hybrid_network()

        # Initialize the weights of the quantum layer
        for param in self.q_layer.parameters():
            if param.requires_grad:  # Only initialize trainable parameters
                nn.init.uniform_(param, 0.0, torch.pi)

    def create_hybrid_network(self):
        dev = qml.device("default.qubit", wires=self.n_qubits)
        layer = qml.StronglyEntanglingLayers     
        shape = layer.shape(n_layers=self.n_layers, n_wires=self.n_qubits)
        self.weight_shapes = {"weights": shape}

        @qml.qnode(dev, interface="torch", diff_method="backprop", differentiable=True, max_diff=2)
        def qnode(inputs, weights):
            scaled_inputs = (inputs - self.x_scale_bias) * self.x_scale_factor
            qml.AngleEmbedding(scaled_inputs, wires=range(self.n_qubits))
            layer(weights, wires=range(self.n_qubits))
            return [qml.expval(qml.PauliZ(wires=i)) for i in range(self.n_qubits)]
        return qml.qnn.TorchLayer(qnode, self.weight_shapes), qnode

    def forward(self, inputs):
        output = self.q_layer(inputs)
        return output

class QuantumNetworkAngleClass(nn.Module):
    def __init__(self, n_qubits=8, n_layers=2, x_scale_factor=1.0, x_scale_bias=0.0):
        super().__init__()
        self.fc_in = nn.Linear(n_qubits, n_qubits)
        self.activation = nn.Tanh()
        self.q_layer = QuantumNetworkBaseAngleClass(n_qubits=n_qubits, n_layers=n_layers,
                                                    x_scale_factor=x_scale_factor, x_scale_bias=x_scale_bias)
        self.fc_out = nn.Linear(n_qubits, n_qubits)

    def forward(self, inputs):
        out_fc_in = self.fc_in(inputs.unsqueeze(0))
        out_activation = self.activation(out_fc_in)
        out_q_layer = self.q_layer(out_activation)
        out_fc_out = self.fc_out(out_q_layer).squeeze(0).squeeze(0)
        return out_fc_out

The gradients:

        # Assumes: from torch.autograd.functional import jacobian
        for epoch in range(self.epochs):
            u_out_arr = []
            d_u_dx_arr = []
            loss_single_batch_arr = []
            loss_single_batch_no_beta_arr = []
            for collocation_points in self.dataloader:
                self.opt.zero_grad(set_to_none=False)  # check if adding False makes any difference
                collocation_points = collocation_points.squeeze(-1)
                u_out = self.model(collocation_points)
                u_out_arr.append(u_out)  # for plotting and derivative calculation check
                d_u_dx = torch.zeros(self.n_points)
                for i in range(self.num_batches):
                    jac_func = jacobian(self.model, collocation_points[i], create_graph=True, strict=True)
                    d_u_dx[i*self.batch_size:(i+1)*self.batch_size] = torch.diag(jac_func, 0)
                d_u_dx_arr.append(d_u_dx)
                loss, loss_no_beta = self.loss_fn(collocation_points, u=u_out, d_u_dx=d_u_dx,
                                                 target_fn=self.target_fn_dict)
                loss.backward()
                loss_single_batch_arr.append(loss)
                loss_single_batch_no_beta_arr.append(loss_no_beta)

                self.opt.step()

When computing derivatives using jacobian (only the diagonal because I need each output’s derivative with respect to the corresponding input):

  • With in_out_size=1: derivatives correctly correspond to spatial derivatives
  • With in_out_size=n_qubits: derivatives don’t match expected spatial derivatives

Question: Why does increasing input/output dimensions affect the derivative computation?

The same behavior is observed when I change the quantum layer to nn.Linear or when I use torch.gradient, so I don’t think it’s due to the quantum circuit or the kind of derivative.
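One possible reading of this, sketched below with plain torch, is that a layer such as nn.Linear(n, n) lets every output depend on every input point. The diagonal of the Jacobian is then only the partial derivative of output i with respect to input i with all other points held fixed, which need not equal the derivative of a single function u(x). The functions `pointwise` and `mixed` are hypothetical stand-ins, not the model from this post:

```python
import torch
from torch import nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
n = 4
x = torch.linspace(0.0, 1.0, n)

def pointwise(points):
    # Each output depends only on its own input point.
    return torch.sin(points)

mix = nn.Linear(n, n)  # stand-in for fc_in: mixes all points together

def mixed(points):
    # Each output depends on every input point through the linear layer.
    return torch.sin(mix(points))

jac_pointwise = jacobian(pointwise, x)  # shape (n, n)
jac_mixed = jacobian(mixed, x)          # shape (n, n)

def off_diagonal(jac):
    # Zero out the diagonal and keep the cross terms d out_i / d x_j, i != j.
    return jac - torch.diag(torch.diag(jac))

print("pointwise cross terms:", off_diagonal(jac_pointwise).abs().max().item())
print("mixed cross terms:    ", off_diagonal(jac_mixed).abs().max().item())
```

In the pointwise case the Jacobian is diagonal and its diagonal is cos(x). In the mixed case the cross terms are non-zero, so torch.diag discards part of the dependence on x.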

[Figures: results for the two cases]

Thanks in advance!

Hi ziv_chen,

I’m working on a similar project and am having quite similar problems. Have you verified that the qnode receives the data in the right format?

I’ve been trying to use the native implementation of qml.qnn.TorchLayer for a while now but always get unexpected behaviour:

  • I suppose you are using batched inputs, so your QuantumNetworkAngleClass receives a tensor containing multiple input vectors.
  • As I understand the PennyLane docs and read your code, the qnode is written to receive only one input vector. Calling the qnode for each vector in the batch should be handled by the torch framework (e.g. in this case your nn.Module implementation of QuantumNetworkBaseAngleClass with the torch quantum layer self.q_layer).
  • When I had similar problems in my project, I started to read out the inputs my qnode gets and, voilà, torch passes the whole batch of inputs to my qnode (instead of one input at a time), resulting in the qnode returning seemingly random values.
  • Since the returns from the qnode are batched again by the torch framework (putting all return values in a tensor and reformatting it to have the “correct” shape), the rest of the model continues using the values as intended and does not detect any errors.

As of now, I have not found a way for torch to only call the qnode once for every input, but at least that might be where your problems come from (since if you are only using one input and one qubit, the batched input tensor is processed one at a time and thus no mismatch of the return values can mess up the result).

→ Maybe try printing the output of the qnode and the output of self.q_layer and see if there’s a mismatch, as was the case with my code.

(With a batch_size of 2 and 8 qubits, I want my qnode to receive and return a tensor of shape (8,), but torch passes a tensor of shape (2,8), so my qnode returns shape (2,) 8 times, which is converted into a tensor of shape (2,8) after the qnode execution.)
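One way to check which shapes a layer actually receives and returns, without editing the circuit, is a forward hook. The sketch below uses a plain nn.Linear as a hypothetical stand-in for self.q_layer; the same hook could presumably be registered on any nn.Module, including a TorchLayer:

```python
import torch
from torch import nn

seen_shapes = []

def log_shapes(module, inputs, output):
    # inputs is a tuple of positional arguments passed to forward().
    seen_shapes.append((tuple(inputs[0].shape), tuple(output.shape)))

layer = nn.Linear(8, 8)  # stand-in for the quantum layer
handle = layer.register_forward_hook(log_shapes)

layer(torch.rand(2, 8))  # a batch of 2 input vectors
layer(torch.rand(8))     # a single input vector

handle.remove()  # stop logging once the check is done
print(seen_shapes)
```

If the logged input shape is (2, 8) where (8,) was expected, the layer is being called with the whole batch at once.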

Good luck and I’d love to hear back!


Hi @Timmi315,

Thanks for the thorough reply!
I tried what you suggested, and I have a few findings and a few follow-up questions.
(TL;DR: still couldn’t get it to work…)

Findings:

  1. I used the same number of points and qubits, so I don’t have batches (4 in this example). The inputs had the type and shape Tensor: (4,). It is very hard to inspect the output of the qnode itself, because when I print from within the circuit, it just shows [expval(Z(0)), expval(Z(1)), expval(Z(2)), expval(Z(3))].
    I dug deeper into the code using the debugger: the return value of execute within the qnode is a tuple of length 4 (each element is a tensor), but later, still within the qnode, it gets converted to a list of tensors, one tensor per measurement.
    When torch evaluates the qnode (using self._evaluate_qnode(inputs)), it receives a Tensor: (4,).
  2. When I use the circuit as it is with the same seed, inputs and weights, I get the same output from the qnode, where the type is a list of tensors.
  3. Even without batches, the results still show the same behavior as before (even if I use more points, like 12).

Follow-up questions:

  1. Comparing to your example, do you mean that my qnode returns shape (1,) 4 times and I get (4,) after conversion, or that it just does not apply in my case?
  2. Do you have any idea how to proceed from here, given what I’ve found?

Thanks!

By the way, it is possible to send batches to the qnode. This is called parameter broadcasting (see “How to execute quantum circuits in collections and batches” on the PennyLane Blog), and it can speed things up quite dramatically.