Merge branch 'feat/mjx_comparison' into feat/compression
mattephi committed Sep 5, 2024
2 parents c9246ce + 2f513a1 commit f2d60e0
Showing 2 changed files with 50 additions and 22 deletions.
29 changes: 21 additions & 8 deletions examples/03_pinocchio.py
@@ -1,3 +1,13 @@
"""
This script compares the performance between sequential calls of CasADi
and its converted JAX counterpart. The comparison is made between a sequential
evaluation using CasADi and a vectorized evaluation using JAX for the forward
kinematics of a Panda robot's end-effector.
Disclaimer: Performance results may vary depending on hardware configuration and system load.
For optimal performance, consider installing CUDA-enabled JAX.
"""

import timeit
import jax
import casadi as ca
@@ -31,14 +41,17 @@

jax_fn = convert(fk, compile=True)


# Function to generate random inputs
def generate_random_inputs(N):
return np.random.rand(N, model.nq)


# Casadi: Sequential Evaluation
def casadi_sequential_evaluation(q_vals):
return [fk(q) for q in q_vals]


# JAX: Vectorized Evaluation
jax_fn_vectorized = jax.jit(jax.vmap(jax_fn)) # Vectorize the function

@@ -48,12 +61,12 @@ def casadi_sequential_evaluation(q_vals):
jax_q_vals_test = jnp.array(q_vals_test).reshape(N_test, model.nq, 1) # Create a batch of 1000

print(f"Casadi sequential evaluation ({N_test} times):")
casadi_results_test = np.array(casadi_sequential_evaluation(q_vals_test))[:,:,0]
casadi_results_test = np.array(casadi_sequential_evaluation(q_vals_test))[:, :, 0]
print(f"First result: {casadi_results_test[0]}")
print(f"Last result: {casadi_results_test[-1]}")
print(f"Shape: {casadi_results_test.shape}")
print(f"\nJAX vectorized evaluation ({N_test} times):")
jax_results_test = np.array(jax_fn_vectorized(jax_q_vals_test))[0,:,:,0]
jax_results_test = np.array(jax_fn_vectorized(jax_q_vals_test))[0, :, :, 0]
print(f"First result: {jax_results_test[0]}")
print(f"Last result: {jax_results_test[-1]}")
print(f"Shape: {jax_results_test.shape}")
@@ -73,24 +86,24 @@ def casadi_sequential_evaluation(q_vals):
# call with same dimensions as target input to avoid re-compiling
q_vals = generate_random_inputs(N)
jax_q_vals = jnp.array(q_vals).reshape(N, model.nq, 1)
np.array(jax_fn_vectorized(jax_q_vals))[0,:,:,0]
np.array(jax_fn_vectorized(jax_q_vals))[0, :, :, 0]

# Generate new random inputs for performance comparison
q_vals = generate_random_inputs(N)
jax_q_vals = jnp.array(q_vals).reshape(N, model.nq, 1)

print(f"Casadi sequential evaluation ({N} times):")
casadi_time = timeit.timeit(lambda: np.array(casadi_sequential_evaluation(q_vals))[:,:,0], number=1)
casadi_time = timeit.timeit(lambda: np.array(casadi_sequential_evaluation(q_vals))[:, :, 0], number=1)
print(f"Time: {casadi_time:.4f} seconds")

print(f"\nJAX vectorized evaluation ({N} times):")
jax_time = timeit.timeit(lambda: np.array(jax_fn_vectorized(jax_q_vals))[0,:,:,0], number=1)
jax_time = timeit.timeit(lambda: np.array(jax_fn_vectorized(jax_q_vals))[0, :, :, 0], number=1)
print(f"Time: {jax_time:.4f} seconds")

print(f"\nSpeedup factor: {casadi_time / jax_time:.2f}x")

# Verify results
print("\nVerifying performance test results:")
casadi_results = np.array(casadi_sequential_evaluation(q_vals[:100]))[:,:,0]
jax_results = np.array(jax_fn_vectorized(jax_q_vals[:100]))[0,:,:,0]
print("First 100 results match:", np.allclose(casadi_results, jax_results, atol=1e-6))
casadi_results = np.array(casadi_sequential_evaluation(q_vals[:100]))[:, :, 0]
jax_results = np.array(jax_fn_vectorized(jax_q_vals[:100]))[0, :, :, 0]
print("First 100 results match:", np.allclose(casadi_results, jax_results, atol=1e-6))
43 changes: 29 additions & 14 deletions examples/04_mjx.py
@@ -1,3 +1,14 @@
"""
This script compares the performance of forward kinematics computations using CasADi, JAX, and MuJoCo XLA (MJX).
The comparison is done for the IIWA14 robot model.
Disclaimer: Performance results may vary depending on hardware configuration and system load.
For optimal performance, consider installing CUDA-enabled JAX.
MuJoCo XLA (MJX) documentation: https://mujoco.readthedocs.io/en/stable/mjx.html
Pinocchio documentation: https://github.com/stack-of-tasks/pinocchio
"""

import timeit
import jax
import casadi as ca
@@ -30,19 +41,20 @@
omf = cdata.oMf[model.getFrameId("link7")]
fk = ca.Function("fk", [q], [omf.translation])

# translate the casadi function to jax
# print(translate(fk, add_import=True, add_jit=True))

# convert the casadi function to jax
jax_fn = convert(fk, compile=True)


# Function to generate random inputs
def generate_random_inputs(N):
return np.random.rand(N, model.nq)


# Casadi: Sequential Evaluation
def casadi_sequential_evaluation(q_vals):
return [fk(q) for q in q_vals]


# JAX: Vectorized Evaluation
jax_fn_vectorized = jax.jit(jax.vmap(jax_fn)) # Vectorize the function

@@ -52,14 +64,17 @@ def casadi_sequential_evaluation(q_vals):
mj_data = mujoco.MjData(mj_model)
mjx_model = mjx.put_model(mj_model)


@jax.jit
def mjx_fk(joint_pos):
mjx_data = mjx.make_data(mjx_model)
mjx_data = mjx_data.replace(qpos=joint_pos)
mjx_data = mjx.fwd_position(mjx_model, mjx_data)
mjx_data = mjx.kinematics(mjx_model, mjx_data)
return mjx_data.xpos[-1] # Assuming the last body is the end-effector

mjx_fn_vectorized = jax.jit(jax.vmap(mjx_fk)) # Corrected: JIT-compiled vectorized function

# Corrected: JIT-compiled vectorized function
mjx_fn_vectorized = jax.jit(jax.vmap(mjx_fk))

# Evaluate the function performance for a batch
N_test = 100 # Small number for initial test
@@ -68,13 +83,13 @@ def mjx_fk(joint_pos):
mjx_q_vals_test = jnp.array(q_vals_test)

print(f"Casadi evaluation (batch of {N_test}):")
casadi_results_test = np.array(casadi_sequential_evaluation(q_vals_test))[:,:,0]
casadi_results_test = np.array(casadi_sequential_evaluation(q_vals_test))[:, :, 0]
print(f"First result: {casadi_results_test[0]}")
print(f"Last result: {casadi_results_test[-1]}")
print(f"Shape: {casadi_results_test.shape}")

print(f"\nJAX evaluation (batch of {N_test}):")
jax_results_test = np.array(jax_fn_vectorized(jax_q_vals_test))[0,:,:,0]
jax_results_test = np.array(jax_fn_vectorized(jax_q_vals_test))[0, :, :, 0]
print(f"First result: {jax_results_test[0]}")
print(f"Last result: {jax_results_test[-1]}")
print(f"Shape: {jax_results_test.shape}")
@@ -96,23 +111,23 @@ def mjx_fk(joint_pos):

# call with same dimensions as target input to avoid re-compiling
q_vals = generate_random_inputs(N)

jax_q_vals = jnp.array(q_vals).reshape(N, model.nq, 1)
mjx_q_vals = jnp.array(q_vals)
np.array(jax_fn_vectorized(jax_q_vals))
np.array(mjx_fn_vectorized(mjx_q_vals))

_ = jax_fn_vectorized(jax_q_vals)
_ = mjx_fn_vectorized(mjx_q_vals)
# Generate new random inputs for performance comparison
q_vals = generate_random_inputs(N)
jax_q_vals = jnp.array(q_vals).reshape(N, model.nq, 1)
mjx_q_vals = jnp.array(q_vals)


print(f"Casadi sequential evaluation ({N} times):")
casadi_time = timeit.timeit(lambda: np.array(casadi_sequential_evaluation(q_vals))[:,:,0], number=1)
casadi_time = timeit.timeit(lambda: np.array(casadi_sequential_evaluation(q_vals))[:, :, 0], number=1)
print(f"Time: {casadi_time:.4f} seconds")

print(f"\nJAX vectorized evaluation ({N} times):")
jax_time = timeit.timeit(lambda: np.array(jax_fn_vectorized(jax_q_vals))[0,:,:,0], number=1)
jax_time = timeit.timeit(lambda: np.array(jax_fn_vectorized(jax_q_vals))[0, :, :, 0], number=1)
print(f"Time: {jax_time:.4f} seconds")

print(f"\nMJX vectorized evaluation ({N} times):")
@@ -126,8 +141,8 @@ def mjx_fk(joint_pos):

# Verify results
print("\nVerifying performance test results:")
casadi_results = np.array(casadi_sequential_evaluation(q_vals[:100]))[:,:,0]
jax_results = np.array(jax_fn_vectorized(jax_q_vals[:100]))[0,:,:,0]
casadi_results = np.array(casadi_sequential_evaluation(q_vals[:100]))[:, :, 0]
jax_results = np.array(jax_fn_vectorized(jax_q_vals[:100]))[0, :, :, 0]
mjx_results = np.array(mjx_fn_vectorized(mjx_q_vals[:100]))
print("First 100 JAX and Casadi results match:", np.allclose(casadi_results, jax_results, atol=1e-6))
print("First 100 MJX and Casadi results match:", np.allclose(casadi_results, mjx_results, atol=1e-6))
