In [1]:
!pip install pennylane
!pip install jax

Collecting pennylane
  Downloading pennylane-0.43.0-py3-none-any.whl.metadata (11 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray==0.8.0 (from pennylane)
  Downloading autoray-0.8.0-py3-none-any.whl.metadata (6.1 kB)
Collecting pennylane-lightning>=0.43 (from pennylane)
  Downloading pennylane_lightning-0.43.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning>=0.43->pennylane)
  Downloading scipy_openblas32-0.3.30.0.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.1/57.1

In [6]:
import pennylane.numpy as np
import pennylane as qml
import jax
import jax.numpy as jnp
from pennylane.devices import preprocess as pp

# PennyLane's `default.qubit` device with two wires; it executes the QNode below.
dev = qml.device("default.qubit", wires=2)


@qml.qnode(dev, interface="jax", diff_method="adjoint")
def circuit(param):
    """Rotate wire 0, apply a CNOT onto wire 1, and measure wire 1.

    Args:
        param: Rotation angle handed to the RX gate acting on wire 0.

    Returns:
        The expectation value of the Pauli-Z observable on wire 1.
    """
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])

    # Observable that is read out on the second wire.
    z_on_wire_1 = qml.PauliZ(wires=1)
    return qml.expval(z_on_wire_1)

# Settings for the plain gradient-descent loop below.
# STEP_SIZE = 1.0 reproduces the bare `param -= grad` update exactly.
STEP_SIZE = 1.0
NUM_STEPS = 100

print("\nGradient Descent")
print("---------------")

# jax.grad builds a function returning d(circuit)/d(param).
grad_circuit = jax.grad(circuit)
print(f"grad_circuit(jnp.pi / 2): {grad_circuit(jnp.pi / 2):0.3f}")

# Starting angle for the optimisation.
param = 0.42

print(f"Initial param: {param:0.3f}")
print(f"Initial cost: {circuit(param):0.3f}")

for _ in range(NUM_STEPS):
    # Move against the gradient so the circuit's expectation value decreases.
    param = param - STEP_SIZE * grad_circuit(param)

print(f"Tuned param: {param:0.3f}")
print(f"Tuned cost: {circuit(param):0.3f}")



Gradient Descent
---------------
grad_circuit(jnp.pi / 2): -1.000
Initial param: 0.420
Initial cost: 0.913
Tuned param: 3.142
Tuned cost: -1.000


In [8]:
# Show the device's `capabilities` attribute as this cell's output.
# NOTE(review): exploratory inspection only; nothing visible here consumes the value.
dev.capabilities