pennylane.labs.phox

Phase optimization with JAX (PHOX)

CircuitConfig(gates, observables, n_samples, ...)

Configuration data for an IQP circuit simulation.

MMDConfig(bandwidth, n_ops[, wires, ...])

Hyperparameters for Maximum Mean Discrepancy (MMD) loss calculation.

build_expval_func(config)

Factory that returns a flexible pure function for computing expectation values.

bitflip_expval(generators, params, ops)

Compute expectation value for the Bitflip noise model.

mmd_loss(params, circuit_config, mmd_config, ...)

Estimate MMD loss using configuration dataclasses.

median_heuristic(samples)

Compute a robust median-distance heuristic for RBF bandwidth selection.

train(optimizer, loss, stepsize, n_iters, ...)

Main training function.

training_iterator(optimizer, loss, stepsize, ...)

Generator that yields training results in batches of size 'unroll_steps'.

TrainingOptions([unroll_steps, val_kwargs, ...])

Configuration options for training.

TrainingResult(final_params, losses, ...)

Container for final training results.

BatchResult(params, state, key, key_val, ...)

Result from a single batch (unrolled chunk) of training steps.

Circuit construction utilities

create_lattice_gates(rows, cols[, distance, ...])

Generates gates based on nearest-neighbor interactions on a 2D lattice.

create_local_gates(n_qubits[, max_weight])

Generates a gate dictionary for the Phox simulator containing all gates whose generators have Pauli weight less than or equal to max_weight.

create_random_gates(n_qubits, n_gates[, ...])

Generates a dictionary of random gates.

generate_pauli_observables(n_qubits[, ...])

Generates a batch of Pauli observables represented as integers (I=0, X=1, Y=2, Z=3).
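The integer encoding used by generate_pauli_observables (I=0, X=1, Y=2, Z=3) can be illustrated with a small standalone helper (hypothetical, not part of the library). It also computes the Pauli weight of an observable, i.e. the number of non-identity factors, which is the quantity bounded by max_weight in create_local_gates.

```python
import numpy as np

# Hypothetical helper illustrating the I=0, X=1, Y=2, Z=3 encoding.
PAULI_CODES = {"I": 0, "X": 1, "Y": 2, "Z": 3}

def encode_pauli_string(pauli: str) -> np.ndarray:
    """Map a Pauli string like 'XIZ' to its integer array [1, 0, 3]."""
    return np.array([PAULI_CODES[p] for p in pauli])

def pauli_weight(encoded: np.ndarray) -> int:
    """Pauli weight = number of non-identity (non-zero) factors."""
    return int(np.count_nonzero(encoded))

obs = encode_pauli_string("ZIZ")  # two-body Z correlator on qubits 0 and 2
print(obs, pauli_weight(obs))     # [3 0 3] 2
```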

Workflow

pennylane.labs.phox provides a compact toolkit for constructing and simulating phase optimization circuits with JAX. The usual workflow is:

  1. Use helpers in pennylane.labs.phox.utils to assemble gates and observables.

  2. Configure the circuit with CircuitConfig.

  3. Build an expectation-value function with build_expval_func() and evaluate it for different parameter sets.

import jax

from pennylane.labs.phox import (
    CircuitConfig,
    build_expval_func,
    create_lattice_gates,
    generate_pauli_observables,
)

n_rows, n_cols = 3, 3
n_qubits = n_rows * n_cols

gates = create_lattice_gates(n_rows, n_cols, distance=1, max_weight=2)
observables = generate_pauli_observables(n_qubits, orders=[2], bases=["Z"])

key = jax.random.PRNGKey(0)
params = jax.random.uniform(key, shape=(len(gates),))

config = CircuitConfig(
    gates=gates,
    observables=observables,
    n_samples=4000,
    key=key,
    n_qubits=n_qubits,
)

expval_fn = jax.jit(build_expval_func(config))
expvals, std_errs = expval_fn(params)

Training

Below is a small training loop that minimizes the sum of all two-body Z correlators on the same 3x3 lattice. The loss function reuses the compiled expval_fn from above.

import jax.numpy as jnp

from pennylane.labs.phox import TrainingOptions, train

def loss_fn(params):
    expvals, _ = expval_fn(params)
    return jnp.sum(expvals)

result = train(
    optimizer="Adam",
    loss=loss_fn,
    stepsize=0.05,
    n_iters=200,
    loss_kwargs={"params": params},
    options=TrainingOptions(unroll_steps=10, random_state=1234),
)

print("Final loss:", float(result.losses[-1]))
print("Optimized parameters:", result.final_params)
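The unroll_steps option batches optimization steps so that each chunk of updates compiles and runs as a single unit, which is the same pattern training_iterator yields results in. The idea can be sketched generically with plain gradient descent on a toy quadratic loss (an illustration of the unrolling pattern only, not the library's implementation):

```python
import jax
import jax.numpy as jnp

STEPSIZE = 0.05
UNROLL_STEPS = 10

def toy_loss(p):
    # Simple quadratic with its minimum at p = 1.
    return jnp.sum((p - 1.0) ** 2)

@jax.jit
def run_chunk(p):
    # One compiled chunk of UNROLL_STEPS gradient-descent updates.
    def step(p, _):
        p = p - STEPSIZE * jax.grad(toy_loss)(p)
        return p, toy_loss(p)
    return jax.lax.scan(step, p, None, length=UNROLL_STEPS)

params = jnp.zeros(3)
losses = []
for _ in range(20):  # 20 chunks of 10 steps = 200 iterations total
    params, chunk_losses = run_chunk(params)
    losses.extend(chunk_losses.tolist())

print(float(losses[-1]))  # close to zero once params approach 1.0
```

Compiling per chunk rather than per step amortizes dispatch overhead while still letting the Python loop inspect intermediate losses between chunks.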

Maximum Mean Discrepancy (MMD) Loss

To train a circuit to reproduce a target probability distribution (e.g., a dataset of bitstrings), use the built-in MMD loss utilities. mmd_loss plugs directly into the train function, with its hyperparameters collected in the MMDConfig dataclass.
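For intuition: the squared MMD between samples from distributions p and q under a kernel k is E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)], and the median heuristic chooses the RBF bandwidth as the median pairwise distance between samples. A minimal NumPy sketch of both ideas (the library's mmd_loss and median_heuristic may use different conventions; the function names here are illustrative):

```python
import numpy as np

def pairwise_sq_dists(X, Y):
    # Squared Euclidean distance between every pair of rows.
    return ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)

def median_bandwidth(samples):
    # Median heuristic: median pairwise distance among the samples.
    d2 = pairwise_sq_dists(samples, samples)
    off_diag = d2[~np.eye(len(samples), dtype=bool)]
    return float(np.sqrt(np.median(off_diag)))

def mmd2(X, Y, bandwidth):
    # Biased (V-statistic) estimator of squared MMD with an RBF kernel.
    k = lambda A, B: np.exp(-pairwise_sq_dists(A, B) / (2 * bandwidth**2))
    return k(X, X).mean() + k(Y, Y).mean() - 2 * k(X, Y).mean()

rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(200, 2))
Y = rng.normal(0.0, 1.0, size=(200, 2))  # same distribution as X
Z = rng.normal(3.0, 1.0, size=(200, 2))  # shifted distribution

bw = median_bandwidth(X)
print(mmd2(X, Y, bw) < mmd2(X, Z, bw))  # True: matching samples score lower
```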

import numpy as np
from pennylane.labs.phox import MMDConfig, mmd_loss, median_heuristic

np.random.seed(42)
target_data = np.random.binomial(1, 0.5, size=(500, n_qubits))

bandwidth = median_heuristic(target_data)
mmd_config = MMDConfig(bandwidth=bandwidth, n_ops=100)

loss_kwargs = {
    "params": params,
    "circuit_config": config,
    "mmd_config": mmd_config,
    "target_data": target_data,
}

mmd_result = train(
    optimizer="Adam",
    loss=mmd_loss,
    stepsize=0.01,
    n_iters=100,
    loss_kwargs=loss_kwargs,
    options=TrainingOptions(unroll_steps=10)
)

print("Final MMD loss:", float(mmd_result.losses[-1]))