Comments (5)

Andrew-Luo1 commented on June 10, 2024

Since JAX arrays are immutable, x.at[i].set(v) returns a new array rather than modifying x in place. In the lines

mjx_data.qpos.at[0].set(r0[0]), mjx_data.qpos.at[1].set(r0[1]), mjx_data.qpos.at[2].set(r0[2])
mjx_data.qvel.at[0].set(v0[0]), mjx_data.qvel.at[1].set(v0[1]), mjx_data.qvel.at[2].set(v0[2])

the returned arrays are discarded, so mjx_data never changes. Perhaps try capturing the results and writing them back into mjx_data, for example with mjx_data.replace(qpos=..., qvel=...).
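A minimal sketch of that behaviour, using a plain JAX array as a stand-in for mjx_data.qpos (the array length of 7 and the variable names are assumptions for illustration, not taken from the original code):

```python
import jax.numpy as jnp

qpos = jnp.zeros(7)                    # stand-in for mjx_data.qpos of a free joint
r0 = jnp.array([5.425, 0.0, 10.0])

# The result of .at[...].set(...) is thrown away here, so qpos keeps its old values.
qpos.at[0:3].set(r0)
unchanged = qpos

# Binding the returned array to a name keeps the updated copy.
updated = qpos.at[0:3].set(r0)

print(unchanged[:3])
print(updated[:3])
```

The same applies one level up: the updated array still has to be put back into the data structure that is passed to the simulator.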

from mujoco.

Andrew-Luo1 commented on June 10, 2024

Hi, this is an interesting problem! I believe that you're not using the gradient in the right way. The simulation function mjx.step implements $x_{k+1} = f(x_k, a_k)$ for state x and action a. I've checked the correctness of the gradients with respect to x and a, but I haven't tried taking the derivative wrt time.

I suspect that what's happening is this: since the simulator is discrete, the end time only determines how many steps are taken. The result is therefore piecewise constant in time, and its derivative wrt time is 0 almost everywhere; the system only changes with time on a set of measure zero.

If your goal is simply to sanity-check the gradients, perhaps try with state and action rather than time?
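The effect can be reproduced without MJX. The sketch below is a hand-written stand-in, not mjx.step: it assumes a ball in free fall integrated with semi-implicit Euler steps, and the constants G and DT and the function final_height are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

G = 9.81    # gravitational acceleration (assumed)
DT = 0.002  # fixed step size (assumed)

def final_height(z0, v0, T):
    # Number of whole steps that fit into T; piecewise constant in T.
    n = jnp.floor(T / DT)
    # Closed form of n semi-implicit Euler steps of free fall.
    return z0 + v0 * n * DT - 0.5 * G * DT**2 * n * (n + 1)

# T only enters through the step count, so its gradient vanishes.
dz_dT = jax.grad(final_height, argnums=2)(10.0, 0.0, 0.5)
# The initial height enters smoothly, so its gradient is well defined.
dz_dz0 = jax.grad(final_height, argnums=0)(10.0, 0.0, 0.5)

print(dz_dT, dz_dz0)
```

Here the gradient wrt T is zero because jnp.floor has zero derivative, while the gradient wrt the initial state is informative.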

KieDani commented on June 10, 2024

Thank you for your answer.
I understand now why it is not possible to get the gradient with respect to time. But what I really want is the gradient with respect to r0 (or v0): I want to find the optimal initial conditions r0 and v0 that imitate a movement, using a gradient-based optimization algorithm. So what I actually want is the following:

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import mujoco
from mujoco import mjx

XML=r'''
<mujoco>
  <worldbody>
    <body pos="5.425 0 1">
      <freejoint/>
      <geom size=".12" mass="0.6" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
'''

def simulation(r0, v0, T):
    '''calculates the trajectory of the ball in 3D world coordinates
    r0: np.array, shape (3,), initial position of the ball
    v0: np.array, shape (3,), initial velocity of the ball
    T: np.array, shape (1,), end time
    ------
    Returns:
        coords3D: np.array, shape (T, 3), trajectory of the ball in 3D world coordinates
    '''
    model = mujoco.MjModel.from_xml_string(XML)
    mjx_model = mjx.put_model(model)
    data = mujoco.MjData(model)
    # data.qpos[0], data.qpos[1], data.qpos[2] = r0[0], r0[1], r0[2]
    # data.qvel[0], data.qvel[1], data.qvel[2] = v0[0], v0[1], v0[2]
    mjx_data = mjx.put_data(model, data)
    mjx_data.qpos.at[0].set(r0[0]), mjx_data.qpos.at[1].set(r0[1]), mjx_data.qpos.at[2].set(r0[2])
    mjx_data.qvel.at[0].set(v0[0]), mjx_data.qvel.at[1].set(v0[1]), mjx_data.qvel.at[2].set(v0[2])
    jit_step = jax.jit(mjx.step)
    duration = T
    while mjx_data.time <= duration:
      mjx_data = jit_step(mjx_model, mjx_data)
    return mjx_data.qpos[2]

v0 = jnp.array([0., 0., 0.], dtype=jnp.float64)
T = jnp.array(2, dtype=jnp.float64)
wrapper_fn = lambda r0: simulation(r0, v0, T)
r0 = jnp.array([5.425, 0, 10], dtype=jnp.float64)
pos = wrapper_fn(r0)
print(pos) # should be r0[2] - 0.5 * g * T**2 + v0[2] * T


grad_fn = jax.grad(wrapper_fn)
grad = grad_fn(r0)
print(grad) # should be 1

The prints are:

-18.639619999999876
[0. 0. 0.]

As far as I know, the correct gradient would be [0, 0, 1]. So it seems that the derivative with respect to r0 does not work correctly either, probably also due to the discretization. Any idea how to solve this problem? I do need the gradients dz(t)/dr0 and dz(t)/dv0.
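For reference, the expected values can be sketched from the closed-form free-fall solution, assuming no contact and ignoring the discretization error of the integrator (the component notation $r_{0,z}$, $v_{0,z}$ is introduced here for illustration):

$$z(T) = r_{0,z} + v_{0,z}\,T - \tfrac{1}{2}\,g\,T^2$$

$$\frac{\partial z(T)}{\partial r_0} = (0,\ 0,\ 1), \qquad \frac{\partial z(T)}{\partial v_0} = (0,\ 0,\ T)$$

Neither gradient depends on $g$, so any nonzero result with the right structure is a useful sanity check.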

KieDani commented on June 10, 2024

Thank you. Now I am able to obtain the correct gradient.
If anyone else is interested in such a problem, here is a working code example:

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import mujoco
from mujoco import mjx

def simulation(r0, v0, T):
    '''simulates the ball and returns its final height
    r0: jnp.array, shape (3,), initial position of the ball
    v0: jnp.array, shape (3,), initial velocity of the ball
    T: jnp.array, scalar, end time
    ------
    Returns:
        z: scalar, z coordinate (qpos[2]) of the ball once the simulated time exceeds T
    '''
    # XML is the model string defined in my previous post
    model = mujoco.MjModel.from_xml_string(XML)
    mjx_model = mjx.put_model(model)
    data = mujoco.MjData(model)
    mjx_data = mjx.put_data(model, data)
    qpos = mjx_data.qpos.at[0:3].set(r0)
    qvel = mjx_data.qvel.at[0:3].set(v0)
    mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
    jit_step = jax.jit(mjx.step)
    duration = T
    while mjx_data.time <= duration:
        mjx_data = jit_step(mjx_model, mjx_data)
    #mjx_data = jax.lax.while_loop(lambda x: x.time <= duration, lambda x: jit_step(mjx_model, x), mjx_data)
    return mjx_data.qpos[2]

v0 = jnp.array([0., 0., 0.], dtype=jnp.float64)
T = jnp.array(0.5, dtype=jnp.float64)
wrapper_fn = lambda r0: simulation(r0, v0, T)
r0 = jnp.array([5.425, 0, 10], dtype=jnp.float64)
pos = wrapper_fn(r0)
print(pos) # should be r0[2] - 0.5 * g * T**2 + v0[2] * T

# d/dr0
grad_fn = jax.grad(wrapper_fn)
grad = grad_fn(r0)
print(grad) # should be [0, 0, 1]

wrapper_fn2 = lambda v0: simulation(r0, v0, T)
pos = wrapper_fn2(v0)
print(pos) # should be r0[2] - 0.5 * g * T**2 + v0[2] * T

# d/dv0
grad_fn2 = jax.grad(wrapper_fn2)
grad = grad_fn2(v0)
print(grad) # should be [0, 0, T]

The outputs are:

8.768844999999992
[0. 0. 1.]
8.768844999999992
[0.  0.  0.5]
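A note on the commented-out jax.lax.while_loop line: jax.lax.while_loop does not support reverse-mode differentiation, so jax.grad would fail on it. If the number of steps is known in advance, a fixed-length jax.lax.scan is one possible alternative. The sketch below illustrates the idea with a hypothetical point-mass step function standing in for mjx.step; G, DT, step and n_steps are assumptions for illustration, not part of MJX.

```python
import jax
import jax.numpy as jnp

G = jnp.array([0.0, 0.0, -9.81])  # gravity vector (assumed)
DT = 0.002                         # step size (assumed)

def step(state, _):
    # Hypothetical stand-in for mjx.step: semi-implicit Euler free fall.
    pos, vel = state
    vel = vel + G * DT
    pos = pos + vel * DT
    return (pos, vel), None

def simulation(r0, v0, n_steps):
    # n_steps is a plain Python int, so the loop length is fixed at trace time.
    (pos, _), _ = jax.lax.scan(step, (r0, v0), None, length=n_steps)
    return pos[2]

r0 = jnp.array([5.425, 0.0, 10.0])
v0 = jnp.zeros(3)
n_steps = 250  # corresponds to T = 0.5 at DT = 0.002

grad_r0 = jax.grad(simulation, argnums=0)(r0, v0, n_steps)
grad_v0 = jax.grad(simulation, argnums=1)(r0, v0, n_steps)
print(grad_r0, grad_v0)
```

With a real MJX model, the carry would be mjx_data and the body of step would call mjx.step(mjx_model, mjx_data); whether that fits a given setup is something to verify rather than assume.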

yuvaltassa commented on June 10, 2024

Thanks @Andrew-Luo1 !
