In [1]:
!pip install pennylane
!pip install jax

Collecting pennylane
  Downloading pennylane-0.43.0-py3-none-any.whl.metadata (11 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray==0.8.0 (from pennylane)
  Downloading autoray-0.8.0-py3-none-any.whl.metadata (6.1 kB)
Collecting pennylane-lightning>=0.43 (from pennylane)
  Downloading pennylane_lightning-0.43.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning>=0.43->pennylane)
  Downloading scipy_openblas32-0.3.30.0.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.1/57.1

In [2]:
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])

    return qml.expval(qml.PauliZ(1))

In [None]:
import jax
import jax.numpy as jnp
from jax import random, grad

# data
key = random.PRNGKey(42)
x = random.normal(key, (10, 3))
y = jnp.array([[1.0], [0.0], [1.0], [0.0], [1.0],
               [0.0], [1.0], [0.0], [1.0], [0.0]])

# init weights
def init_params(key):
    k1, k2 = random.split(key)
    W1 = random.normal(k1, (3, 4))
    b1 = jnp.zeros(4)
    W2 = random.normal(k2, (4, 1))
    b2 = jnp.zeros(1)
    return (W1, b1, W2, b2)

# model
def forward(params, x):
    W1, b1, W2, b2 = params
    h = jnp.tanh(x @ W1 + b1)
    h = jax.nn.sigmoid(h @ W2 + circuit(b2 * jnp.pi))
    return h


# loss
def loss_fn(params, x, y):
    preds = forward(params, x)
    return jnp.mean((preds - y) ** 2)

# training step
@jax.jit
def update(params, x, y, lr=0.1):
    grads = grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# train
params = init_params(key)
for i in range(200):
    params = update(params, x, y)
    if i % 5 == 0:
        print(i, loss_fn(params, x, y))


0 0.30217844
5 0.29236612
10 0.2849971
15 0.27915093
20 0.27423224
25 0.26988873
30 0.26590854
35 0.26215717
40 0.25854278
45 0.25499853
50 0.2514727
55 0.24792306
60 0.2443141
65 0.24061473
70 0.23679791
75 0.23283985
80 0.22872034
85 0.2244227
90 0.21993451
95 0.21524821
100 0.21036112
105 0.20527649
110 0.2000031
115 0.19455588
120 0.188955
125 0.18322562
130 0.17739694
135 0.17150114
140 0.16557209
145 0.15964386
150 0.1537499
155 0.14792159
160 0.14218764
165 0.13657312
170 0.13109954
175 0.12578449
180 0.12064164
185 0.11568117
190 0.110910036
195 0.10633228


In [None]:
forward(params, x[0,:])

Array([0.8505869], dtype=float32)