Comments (5)
Since JAX doesn't modify data structures in place, x.at[...].set(...) returns a new copy rather than modifying x. Perhaps try assigning the results of the six .set(...) calls back into mjx_data, since they are currently discarded in the lines
mjx_data.qpos.at[0].set(r0[0]), mjx_data.qpos.at[1].set(r0[1]), mjx_data.qpos.at[2].set(r0[2])
mjx_data.qvel.at[0].set(v0[0]), mjx_data.qvel.at[1].set(v0[1]), mjx_data.qvel.at[2].set(v0[2])
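For reference, here is a minimal sketch of that behaviour on a plain JAX array. The names qpos, unchanged, and updated are made up for illustration and are not taken from the MJX code:

```python
import jax.numpy as jnp

qpos = jnp.zeros(3)
r0 = jnp.array([5.425, 0.0, 10.0])

# The returned array is thrown away here, so qpos keeps its old values.
qpos.at[0:3].set(r0)
unchanged = qpos

# Binding the returned array to a name is what keeps the update.
updated = qpos.at[0:3].set(r0)

print(unchanged)
print(updated)
```

The same applies to fields of an MJX data object: the new array has to be stored somewhere, otherwise the simulation still starts from the old values.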
from mujoco.
Hi, this is an interesting problem! I believe that you're not using the gradient in the right way. The simulation function mjx.step
implements
I suspect that, because the simulator is discrete, the derivative of the system with respect to time is zero almost everywhere: the state only changes at a measure-zero set of time instants.
If your goal is simply to sanity-check the gradients, perhaps try differentiating with respect to state and action rather than time?
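To illustrate the point, here is a toy stand-in for a fixed-step simulator, not MJX itself. It assumes a ball in free fall, a step size DT, and a step count obtained by flooring T / DT; the function name z_at and all constants are assumptions chosen for this sketch:

```python
import jax
import jax.numpy as jnp

DT = 0.002  # assumed step size of the toy simulator
G = 9.81    # gravitational acceleration

def z_at(T, z0, v0):
    # Number of whole steps taken up to time T. floor is piecewise
    # constant, so it contributes a zero derivative with respect to T.
    n = jnp.floor(T / DT)
    # Height after n semi-implicit Euler steps of free fall.
    return z0 + v0 * n * DT - 0.5 * G * DT**2 * n * (n + 1)

dz_dT = jax.grad(z_at, argnums=0)(0.501, 10.0, 0.0)
dz_dz0 = jax.grad(z_at, argnums=1)(0.501, 10.0, 0.0)
dz_dv0 = jax.grad(z_at, argnums=2)(0.501, 10.0, 0.0)

print(dz_dT, dz_dz0, dz_dv0)
```

In this sketch the gradient with respect to T is zero, while the gradients with respect to the initial state are not, which is why checking against state is more informative.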
from mujoco.
Thank you for your answer.
I understand now why it is not possible to get the gradient with respect to time. But what I really want is the gradient with respect to r0 (or v0), because I want to find the optimal initial conditions r0 and v0 to imitate a movement via a gradient-based optimization algorithm. Thus I actually want the following:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import mujoco
from mujoco import mjx

XML=r'''
<mujoco>
  <worldbody>
    <body pos="5.425 0 1">
      <freejoint/>
      <geom size=".12" mass="0.6" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
'''

def simulation(r0, v0, T):
    '''calculates the trajectory of the ball in 3D world coordinates
    r0: np.array, shape (3,), initial position of the ball
    v0: np.array, shape (3,), initial velocity of the ball
    T: np.array, shape (1,), end time
    ------
    Returns:
    coords3D: np.array, shape (T, 3), trajectory of the ball in 3D world coordinates
    '''
    model = mujoco.MjModel.from_xml_string(XML)
    mjx_model = mjx.put_model(model)
    data = mujoco.MjData(model)
    # data.qpos[0], data.qpos[1], data.qpos[2] = r0[0], r0[1], r0[2]
    # data.qvel[0], data.qvel[1], data.qvel[2] = v0[0], v0[1], v0[2]
    mjx_data = mjx.put_data(model, data)
    mjx_data.qpos.at[0].set(r0[0]), mjx_data.qpos.at[1].set(r0[1]), mjx_data.qpos.at[2].set(r0[2])
    mjx_data.qvel.at[0].set(v0[0]), mjx_data.qvel.at[1].set(v0[1]), mjx_data.qvel.at[2].set(v0[2])
    jit_step = jax.jit(mjx.step)
    duration = T
    while mjx_data.time <= duration:
        mjx_data = jit_step(mjx_model, mjx_data)
    return mjx_data.qpos[2]

v0 = jnp.array([0., 0., 0.], dtype=jnp.float64)
T = jnp.array(2, dtype=jnp.float64)
wrapper_fn = lambda r0: simulation(r0, v0, T)
r0 = jnp.array([5.425, 0, 10], dtype=jnp.float64)
pos = wrapper_fn(r0)
print(pos)  # should be r0[2] - 0.5 * g * T**2 + v0[2] * T
grad_fn = jax.grad(wrapper_fn)
grad = grad_fn(r0)
print(grad)  # should be 1
The prints are:
-18.639619999999876
[0. 0. 0.]
The correct solution for the gradient would be [0, 0, 1] as far as I know. Thus, it seems like the derivative does not work correctly either, probably also due to the discretization. Any idea how to solve this problem? I need the gradients dz(t)/dr0 and dz(t)/dv0.
from mujoco.
Thank you. Now I am able to obtain the correct gradient.
If anyone else is interested in such a problem, here is a working code example:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import mujoco
from mujoco import mjx

# XML is the same model string as in the snippet from my previous comment.

def simulation(r0, v0, T):
    '''simulates the ball and returns its final height
    r0: array, shape (3,), initial position of the ball
    v0: array, shape (3,), initial velocity of the ball
    T: scalar array, end time
    ------
    Returns:
    z: scalar, z coordinate (qpos[2]) of the ball at the end of the simulation
    '''
    model = mujoco.MjModel.from_xml_string(XML)
    mjx_model = mjx.put_model(model)
    data = mujoco.MjData(model)
    mjx_data = mjx.put_data(model, data)
    # .at[...].set(...) returns new arrays; store them via replace()
    qpos = mjx_data.qpos.at[0:3].set(r0)
    qvel = mjx_data.qvel.at[0:3].set(v0)
    mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
    jit_step = jax.jit(mjx.step)
    duration = T
    while mjx_data.time <= duration:
        mjx_data = jit_step(mjx_model, mjx_data)
    #mjx_data = jax.lax.while_loop(lambda x: x.time <= duration, lambda x: jit_step(mjx_model, x), mjx_data)
    return mjx_data.qpos[2]

v0 = jnp.array([0., 0., 0.], dtype=jnp.float64)
T = jnp.array(0.5, dtype=jnp.float64)
wrapper_fn = lambda r0: simulation(r0, v0, T)
r0 = jnp.array([5.425, 0, 10], dtype=jnp.float64)
pos = wrapper_fn(r0)
print(pos)  # should be r0[2] - 0.5 * g * T**2 + v0[2] * T

# d/dr0
grad_fn = jax.grad(wrapper_fn)
grad = grad_fn(r0)
print(grad)  # should be [0, 0, 1]

wrapper_fn2 = lambda v0: simulation(r0, v0, T)
pos = wrapper_fn2(v0)
print(pos)  # should be r0[2] - 0.5 * g * T**2 + v0[2] * T

# d/dv0
grad_fn2 = jax.grad(wrapper_fn2)
grad = grad_fn2(v0)
print(grad)  # should be [0, 0, T]
The outputs are:
8.768844999999992
[0. 0. 1.]
8.768844999999992
[0. 0. 0.5]
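As a cross-check that does not need MuJoCo, the printed numbers can be reproduced with a hand-written semi-implicit Euler integrator in plain JAX. This is only a sketch under assumptions: MuJoCo's default timestep (0.002) and gravity (9.81 along -z), no damping, and a step count of 250, which is inferred from the printed height rather than read from the simulator. The names final_height, DT, G, and N_STEPS are made up for this example:

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

DT = 0.002     # MuJoCo's default timestep
G = 9.81       # magnitude of MuJoCo's default gravity along -z
N_STEPS = 250  # assumed step count, inferred from the printed height

def final_height(z0, vz0):
    z0 = jnp.asarray(z0, dtype=jnp.float64)
    vz0 = jnp.asarray(vz0, dtype=jnp.float64)

    def step(state, _):
        z, vz = state
        vz = vz - G * DT  # update the velocity first
        z = z + vz * DT   # then move with the updated velocity
        return (z, vz), None

    (z, _), _ = jax.lax.scan(step, (z0, vz0), None, length=N_STEPS)
    return z

z_end = final_height(10.0, 0.0)
dz_dz0, dz_dvz0 = jax.grad(final_height, argnums=(0, 1))(10.0, 0.0)
print(z_end, dz_dz0, dz_dvz0)
```

Under these assumptions the height agrees with 8.768845 and the two gradients with 1 and 0.5. The small gap to the continuous formula r0[2] - 0.5 * g * T**2 (which gives 8.77375) comes from the discretization, not from the gradient computation.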
from mujoco.
Thanks @Andrew-Luo1 !
from mujoco.