In [5]:
!pip install pennylane
!pip install jax==0.5.3



In [21]:
import pennylane.numpy as np
import pennylane as qml
import jax

dev = qml.device("default.mixed", wires=2)

@qml.qnode(dev)
def circuit(param):
    qml.RX(param, wires=0)
    qml.DepolarizingChannel(p=0.1, wires=1)
    qml.CNOT(wires=[0, 1])

    return qml.expval(qml.PauliZ(1))

In [22]:
print("\nGradient Descent")
print("---------------")


grad_circuit = jax.grad(circuit)

param = 0.42

print(f"Initial param: {param:0.3f}")
print(f"Initial cost: {circuit(param):0.3f}")

for _ in range(100):
    param -= grad_circuit(param)

print(f"Tuned param: {param:0.3f}")
print(f"Tuned cost: {circuit(param):0.3f}")



Gradient Descent
---------------
Initial param: 0.420
Initial cost: 0.913
Tuned param: 3.142
Tuned cost: -1.000


In [23]:
import jax
import jax.numpy as jnp
from jax import random, grad

# data
key = random.PRNGKey(0)
x = random.normal(key, (10, 3))
y = jnp.array([[1.0], [0.0], [1.0], [0.0], [1.0],
               [0.0], [1.0], [0.0], [1.0], [0.0]])

# init weights
def init_params(key):
    k1, k2 = random.split(key)
    W1 = random.normal(k1, (3, 4))
    b1 = jnp.zeros(4)
    W2 = random.normal(k2, (4, 1))
    b2 = jnp.zeros(1)
    return (W1, b1, W2, b2)

# model
def forward(params, x):
    W1, b1, W2, b2 = params
    h = jnp.tanh(x @ W1 + b1)
    h = jax.nn.sigmoid(h @ W2 + b2) * jnp.pi
    return circuit(jnp.mean(h))


# loss
def loss_fn(params, x, y):
    preds = forward(params, x)
    return jnp.mean((preds - y) ** 2)

# training step
@jax.jit
def update(params, x, y, lr=0.1):
    grads = grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# train
params = init_params(key)
for i in range(200):
    params = update(params, x, y)
    if i % 50 == 0:
        print(i, loss_fn(params, x, y))


0 0.4422666
50 0.25000006
100 0.25
150 0.25
