Comments (8)

refraction-ray commented on July 23, 2024

Please attach a minimal demo that reproduces the problem.

from tensorcircuit.

Marsmmz commented on July 23, 2024

Hi, here is a demo that reproduces it:

import tensorcircuit as tc
import tensorflow as tf
import numpy as np

tc.set_backend("tensorflow")
tc.set_dtype("complex64")
def Hamiltonian(c: tc.MPSCircuit, n: int):
    e = 0.0
    for i in range(n):
        e += -1 * tf.cast(c.expectation_ps(z=[i]), tf.float64)
    return -tc.backend.real(e)


def vqe(params, n):
    circuit = tc.MPSCircuit(n)
    circuit.set_split_rules({"max_singular_values": 50})
    
    for i in range(n):
        circuit.rx(i,theta=params[i][0])
        circuit.ry(i,theta=params[i][1])
        circuit.rz(i,theta=params[i][2])
    
    energy = Hamiltonian(circuit, n)
    return energy

vqe_vvag = tc.backend.jit(
    tc.backend.vectorized_value_and_grad(vqe, vectorized_argnums = (0,)), static_argnums=(1,)
)




if __name__=="__main__":
    batch = 16
    n = 8
    maxiter = 100
    params = tf.Variable(
            initial_value=tf.concat(
                [tf.random.normal(shape=[int(batch/4), n, 3], mean=0, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
                tf.random.normal(shape=[int(batch/4), n, 3], mean=np.pi/4, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
                tf.random.normal(shape=[int(batch/4), n, 3], mean=np.pi/2, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
                tf.random.normal(shape=[int(batch/4), n, 3], mean=np.pi*3/4, stddev=0.2, dtype=getattr(tf, tc.rdtypestr))
                ],0)
        )
    opt = tf.keras.optimizers.legacy.Adam(1e-2)
    for i in range(maxiter):
        energy, grad = vqe_vvag(params, n)
        opt.apply_gradients([(grad, params)])
        print(energy)

Thanks a lot!

refraction-ray commented on July 23, 2024

Thanks for providing the demo. However, I can run it successfully with no error. My environment info is attached below:

>>> tc.about()
OS info: macOS-10.15.7-x86_64-i386-64bit
Python version: 3.10.0
Numpy version: 1.24.3
Scipy version: 1.10.1
Pandas version: 2.0.3
TensorNetwork version: 0.5.0
Cotengra version: 0.6.0
TensorFlow version: 2.13.0
TensorFlow GPU: []
TensorFlow CUDA infos: {'is_cuda_build': False, 'is_rocm_build': False, 'is_tensorrt_build': False}
Jax version: 0.4.14
Jax installation doesn't support GPU
JaxLib version: 0.4.14
PyTorch version: 2.0.1
PyTorch GPU support: False
PyTorch GPUs: []
Cupy is not installed
Qiskit version: 0.45.1
Cirq version: 1.2.0
TensorCircuit version 0.12.0

refraction-ray commented on July 23, 2024

Ah, you mean the warning. I do see it, but I believe it doesn't affect the results. I will investigate further whether the warning has any negative effect on jit and whether we can get rid of it.

Checked now! The warning is related to vmap, not to jit. If we use value_and_grad instead of vvag (vectorized_value_and_grad), the warning is gone. The warning appears because there is no vectorized implementation of QR in TensorFlow.

refraction-ray commented on July 23, 2024

If you feel tf is not fast enough, you can try the following snippet with your actual circuit and hyperparameters to determine which backend (tf or jax) is more suitable:

import tensorcircuit as tc
import numpy as np
import time

tc.set_dtype("complex64")


def Hamiltonian(c: tc.MPSCircuit, n: int):
    e = 0.0
    for i in range(n):
        e += -1 * c.expectation_ps(z=[i])
    return -tc.backend.real(e)


def vqe(params, n):
    circuit = tc.MPSCircuit(n)
    circuit.set_split_rules({"max_singular_values": 50})

    for i in range(n):
        circuit.rx(i, theta=params[i][0])
        circuit.ry(i, theta=params[i][1])
        circuit.rz(i, theta=params[i][2])
    for i in range(n-1):
        circuit.cx(i, i+1)

    energy = Hamiltonian(circuit, n)
    return energy


if __name__ == "__main__":
    batch = 16
    n = 16
    maxiter = 100
    params0 = np.random.uniform(size=[batch, n, 3])

    for b in ["tensorflow", "jax"]:
        with tc.runtime_backend(b):
            vqe_vvag = tc.backend.jit(
                tc.backend.vectorized_value_and_grad(vqe, vectorized_argnums=(0,)),
                static_argnums=(1,),
            )
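            print("benchmarking backend: %s" % b)
            # First call: includes tracing and jit compilation time.
            time0 = time.time()
            params = tc.backend.convert_to_tensor(params0)
            energy, grad = vqe_vvag(params, n)
            print(energy, grad)
            print("jit time", time.time() - time0)
            # Subsequent calls reuse the compiled function; average over 5 runs.
            # Note: JAX dispatches asynchronously, so this loop may under-report
            # the run time unless the results are blocked on.
            time0 = time.time()
            for _ in range(5):
                energy, grad = vqe_vvag(params, n)
            print("running time", (time.time() - time0) / 5)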

Marsmmz commented on July 23, 2024

Aha, I see. Thanks a lot!
So it seems that we can't use vvag for speeding up with tf as backend.

Marsmmz commented on July 23, 2024

I will close this issue, many thanks!

refraction-ray commented on July 23, 2024

> Aha, I see. Thanks a lot! So it seems that we can't use vvag for speeding up with tf as backend.

On this point, I don't know. You could run some microbenchmarks comparing vvag over the batch against a naive for loop with the tf backend. It is also possible that other operations are still vectorized, in which case vvag may still be more efficient than a for loop.
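A rough sketch of such a microbenchmark is below. It is only a starting point, not a tuned benchmark: the helper names time_call and compare_vvag_vs_loop are made up for illustration, the circuit is copied from the snippet earlier in this thread, and the naive baseline simply calls a jitted value_and_grad once per batch element. The default sizes are arbitrary assumptions, so swap in your real circuit and hyperparameters.

```python
import time


def time_call(fn, *args, repeats=5):
    """Return the mean wall-clock seconds per call of fn(*args).

    One untimed warm-up call runs first so that tracing and jit
    compilation are excluded from the measurement.
    """
    fn(*args)  # warm-up call, not timed
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats


def compare_vvag_vs_loop(batch=16, n=8, repeats=5):
    """Time vvag over the batch against a Python loop of value_and_grad.

    Hypothetical helper: batch, n and repeats are illustrative defaults.
    """
    # Imported lazily so that time_call can be used without these packages.
    import numpy as np
    import tensorcircuit as tc

    K = tc.set_backend("tensorflow")
    tc.set_dtype("complex64")

    def vqe(params, n):
        # Same ansatz and observable as the benchmark snippet in this thread.
        c = tc.MPSCircuit(n)
        c.set_split_rules({"max_singular_values": 50})
        for i in range(n):
            c.rx(i, theta=params[i][0])
            c.ry(i, theta=params[i][1])
            c.rz(i, theta=params[i][2])
        for i in range(n - 1):
            c.cx(i, i + 1)
        e = 0.0
        for i in range(n):
            e += -1 * c.expectation_ps(z=[i])
        return -K.real(e)

    # Vectorized version: one call handles the whole batch.
    vvag = K.jit(
        K.vectorized_value_and_grad(vqe, vectorized_argnums=(0,)),
        static_argnums=(1,),
    )
    # Non-vectorized version: handles a single parameter set per call.
    vag = K.jit(K.value_and_grad(vqe), static_argnums=(1,))

    def loop(params, n):
        # Naive baseline: one value_and_grad call per batch element.
        return [vag(params[i], n) for i in range(batch)]

    params0 = np.random.uniform(size=[batch, n, 3]).astype(np.float32)
    params = K.convert_to_tensor(params0)
    return {
        "vvag": time_call(vvag, params, n, repeats=repeats),
        "loop": time_call(loop, params, n, repeats=repeats),
    }
```

Calling print(compare_vvag_vs_loop()) prints the mean seconds per call for both variants. Which one wins probably depends on the batch size, the qubit count and the bond dimension, so it is worth repeating the measurement for a few settings.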
