In [3]:
!pip install pennylane
!pip install jax



In [4]:
import pennylane.numpy as np
import pennylane as qml
import jax
import jax.numpy as jnp
from pennylane.devices import preprocess as pp

# Two-qubit state-vector simulator (wire labels 0 and 1).
dev = qml.device("default.qubit", wires=2)


@qml.qnode(dev)
def circuit(param):
    """Expectation of Pauli-Z on wire 1 after RX(param) on wire 0 and CNOT(0 -> 1).

    The CNOT copies wire 0's computational-basis populations onto wire 1, so the
    returned value equals cos(param): +1 at param = 0, -1 at param = pi.
    """
    qml.RX(param, wires=[0])
    qml.CNOT(wires=[0, 1])
    observable = qml.PauliZ(wires=1)
    return qml.expval(observable)


print("\nGradient Descent")
print("---------------")

# Plain gradient descent on circuit(param) = cos(param), whose minimum (-1)
# lies at param = pi. Hyperparameters are named so they can be tuned in one place.
STEP_SIZE = 1.0  # learning rate; 1.0 reproduces the un-scaled update `param -= grad`
N_STEPS = 100
INITIAL_PARAM = 0.42

# jax.grad differentiates the QNode with respect to its first argument.
grad_circuit = jax.grad(circuit)
# Sanity check: d/dθ cos(θ) at θ = π/2 is -sin(π/2) = -1.
print(f"grad_circuit(jnp.pi / 2): {grad_circuit(jnp.pi / 2):0.3f}")

param = INITIAL_PARAM

print(f"Initial param: {param:0.3f}")
print(f"Initial cost: {circuit(param):0.3f}")

for _ in range(N_STEPS):
    # After the first update `param` is a jax array (the type jax.grad returns).
    param -= STEP_SIZE * grad_circuit(param)

print(f"Tuned param: {param:0.3f}")
print(f"Tuned cost: {circuit(param):0.3f}")



Gradient Descent
---------------
grad_circuit(jnp.pi / 2): -1.000
Initial param: 0.420
Initial cost: 0.913
Tuned param: 3.142
Tuned cost: -1.000
