qml.labs.phox¶
Phase optimization with JAX (PHOX)¶
- Configuration data for an IQP circuit simulation.
- Hyperparameters for Maximum Mean Discrepancy (MMD) loss calculation.
- Factory that returns a flexible pure function for computing expectation values.
- Compute the expectation value for the bitflip noise model.
- Estimate the MMD loss using configuration dataclasses.
- Compute a robust median-distance heuristic for RBF bandwidth selection.
- Main training function.
- Generator that yields training results in batches of size unroll_steps.
- Configuration options for training.
- Container for final training results.
- Result from a single batch (unrolled chunk) of training steps.
Circuit construction utilities¶
- Generates gates based on nearest-neighbor interactions on a 2D lattice.
- Generates a gate dictionary for the Phox simulator containing all gates whose generators have Pauli weight less than or equal to max_weight.
- Generates a dictionary of random gates.
- Generates a batch of Pauli observables represented as integers (I=0, X=1, Y=2, Z=3).
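As an illustration of the integer encoding named above (I=0, X=1, Y=2, Z=3), here is a tiny standalone helper; `encode_pauli_string` is a hypothetical name used for illustration, not part of the module:

```python
import numpy as np

# Integer encoding used for Pauli observables: I=0, X=1, Y=2, Z=3,
# one entry per qubit.
PAULI_CODES = {"I": 0, "X": 1, "Y": 2, "Z": 3}

def encode_pauli_string(pauli_string):
    """Map a Pauli string like 'ZZI' to an integer array [3, 3, 0]."""
    return np.array([PAULI_CODES[p] for p in pauli_string])

print(encode_pauli_string("ZZI"))  # [3 3 0]
```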
Workflow¶
pennylane.labs.phox provides a compact toolkit for constructing and
simulating phase optimization circuits with JAX. The usual workflow is:
1. Use helpers in pennylane.labs.phox.utils to assemble gates and observables.
2. Configure the circuit with CircuitConfig.
3. Build an expectation-value function with build_expval_func() and evaluate it for different parameter sets.
import jax
from pennylane.labs.phox import (
    CircuitConfig,
    build_expval_func,
    create_lattice_gates,
    generate_pauli_observables,
)
# A 3x3 lattice of qubits
n_rows, n_cols = 3, 3
n_qubits = n_rows * n_cols

# Nearest-neighbor gates whose generators have Pauli weight at most 2
gates = create_lattice_gates(n_rows, n_cols, distance=1, max_weight=2)

# All two-body Z correlators
observables = generate_pauli_observables(n_qubits, orders=[2], bases=["Z"])

# Random initial parameters, one per gate
key = jax.random.PRNGKey(0)
params = jax.random.uniform(key, shape=(len(gates),))

config = CircuitConfig(
    gates=gates,
    observables=observables,
    n_samples=4000,
    key=key,
    n_qubits=n_qubits,
)

# Compile once, then evaluate for any parameter vector
expval_fn = jax.jit(build_expval_func(config))
expvals, std_errs = expval_fn(params)
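The second return value, std_errs, reflects that the expectation values are Monte Carlo estimates over n_samples shots: the standard error of a sample mean shrinks as 1/sqrt(n_samples). A self-contained NumPy sketch of that relationship, independent of phox:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 4000

# Pretend these are per-shot eigenvalue outcomes (+1/-1) of an observable
# measured with probabilities 0.6 and 0.4, so the true expectation is 0.2.
outcomes = rng.choice([1.0, -1.0], size=n_samples, p=[0.6, 0.4])

expval = outcomes.mean()
std_err = outcomes.std(ddof=1) / np.sqrt(n_samples)

print(expval, std_err)  # estimate near 0.2; error shrinks as 1/sqrt(n_samples)
```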
Training¶
Below is a small training loop that minimizes the sum of all two-body Z
correlators on the same 3x3 lattice. The loss function reuses the
compiled expval_fn from above.
import jax.numpy as jnp
from pennylane.labs.phox import TrainingOptions, train
def loss_fn(current_params):
    expvals, _ = expval_fn(current_params)
    return jnp.sum(expvals)

result = train(
    optimizer="Adam",
    loss=loss_fn,
    stepsize=0.05,
    n_iters=200,
    loss_kwargs={"params": params},
    options=TrainingOptions(unroll_steps=10, random_state=1234),
)
print("Final loss:", float(result.losses[-1]))
print("Optimized parameters:", result.final_params)
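For context, optimizer="Adam" with stepsize=0.05 refers to the standard Adam update rule. The sketch below implements that rule in plain NumPy on a toy quadratic loss; it is an illustration of the optimizer, not phox's internal implementation:

```python
import numpy as np

def adam_minimize(grad_fn, params, stepsize=0.05, n_iters=200,
                  beta1=0.9, beta2=0.999, eps=1e-8):
    """Standard Adam update rule (illustration only)."""
    m = np.zeros_like(params)  # first-moment (mean) estimate
    v = np.zeros_like(params)  # second-moment (uncentered variance) estimate
    for t in range(1, n_iters + 1):
        g = grad_fn(params)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**t)  # bias correction
        v_hat = v / (1 - beta2**t)
        params = params - stepsize * m_hat / (np.sqrt(v_hat) + eps)
    return params

# Minimize f(x) = sum((x - 1)^2); its gradient is 2 * (x - 1).
x_opt = adam_minimize(lambda x: 2 * (x - 1.0), np.array([5.0, -3.0]))
print(x_opt)  # approaches [1, 1]
```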
Maximum Mean Discrepancy (MMD) Loss¶
To train a circuit to reproduce a target probability distribution (e.g., a dataset
of bitstrings), you can use the built-in MMD loss utilities. These integrate with
the train function through the MMDConfig dataclass.
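For reference, the MMD loss is built on the standard kernel two-sample estimator MMD^2(X, Y) = mean k(x, x') + mean k(y, y') - 2 mean k(x, y), typically with an RBF kernel. A NumPy sketch of the (biased) estimator, shown for intuition rather than as phox's exact implementation:

```python
import numpy as np

def rbf_kernel(a, b, bandwidth):
    """Gaussian (RBF) kernel matrix between sample sets a and b."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2 * bandwidth**2))

def mmd_squared(x, y, bandwidth):
    """Biased MMD^2 estimate between samples x and y."""
    return (rbf_kernel(x, x, bandwidth).mean()
            + rbf_kernel(y, y, bandwidth).mean()
            - 2 * rbf_kernel(x, y, bandwidth).mean())

rng = np.random.default_rng(0)
same = rng.normal(size=(200, 2))
same2 = rng.normal(size=(200, 2))          # same distribution as `same`
shifted = rng.normal(loc=2.0, size=(200, 2))  # shifted distribution

mmd_close = mmd_squared(same, same2, 1.0)   # near zero
mmd_far = mmd_squared(same, shifted, 1.0)   # clearly larger
print(mmd_close, mmd_far)
```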
import numpy as np
from pennylane.labs.phox import MMDConfig, mmd_loss, median_heuristic
np.random.seed(42)
target_data = np.random.binomial(1, 0.5, size=(500, n_qubits))
bandwidth = median_heuristic(target_data)
mmd_config = MMDConfig(bandwidth=bandwidth, n_ops=100)
loss_kwargs = {
    "params": params,
    "circuit_config": config,
    "mmd_config": mmd_config,
    "target_data": target_data,
}

mmd_result = train(
    optimizer="Adam",
    loss=mmd_loss,
    stepsize=0.01,
    n_iters=100,
    loss_kwargs=loss_kwargs,
    options=TrainingOptions(unroll_steps=10),
)
print("Final MMD loss:", float(mmd_result.losses[-1]))
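The bandwidth passed to MMDConfig above came from median_heuristic. The common definition of the median heuristic is the median of pairwise sample distances; below is a NumPy sketch of that form (phox's exact variant may differ, e.g. in using squared distances):

```python
import numpy as np

def median_pairwise_distance(samples):
    """Median Euclidean distance over all distinct sample pairs."""
    diffs = samples[:, None, :] - samples[None, :, :]
    dists = np.sqrt((diffs**2).sum(axis=-1))
    # Keep only the strict upper triangle: each distinct pair once.
    iu = np.triu_indices(len(samples), k=1)
    return np.median(dists[iu])

rng = np.random.default_rng(42)
data = rng.binomial(1, 0.5, size=(100, 9)).astype(float)
print(median_pairwise_distance(data))
```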