diff --git a/jaxadi/_convert.py b/jaxadi/_convert.py index e0ab5e7..a2dcd52 100644 --- a/jaxadi/_convert.py +++ b/jaxadi/_convert.py @@ -6,6 +6,7 @@ from ._graph import translate as graph_translate from ._expand import translate as expand_translate from ._compile import compile as compile_fn +from ._preprocess import densify def convert(casadi_fn: Function, translate=None, compile=False) -> Callable[..., Any]: @@ -21,6 +22,8 @@ def convert(casadi_fn: Function, translate=None, compile=False) -> Callable[..., if translate is None: translate = graph_translate + casadi_fn = densify(casadi_fn) + jax_str = translate(casadi_fn) jax_fn = declare(jax_str) diff --git a/jaxadi/_preprocess.py b/jaxadi/_preprocess.py new file mode 100644 index 0000000..f5ec3ae --- /dev/null +++ b/jaxadi/_preprocess.py @@ -0,0 +1,14 @@ +from casadi import densify as cs_densify +from casadi import Function + + +def densify(func: Function): + _i = func.sx_in() + _o = func(*_i) + if not isinstance(_o, tuple): + _o = [_o] + _dense_o = [] + for i, o in enumerate(_o): + _dense_o.append(cs_densify(o)) + _func = Function(func.name(), _i, _dense_o) + return _func diff --git a/tests/test_casadi_equality.py b/tests/test_casadi_equality.py index 95e3df3..3fc2457 100644 --- a/tests/test_casadi_equality.py +++ b/tests/test_casadi_equality.py @@ -5,6 +5,9 @@ from jaxadi import convert +from jaxadi import graph_translate +from jaxadi import expand_translate + # Set a fixed seed for reproducibility np.random.seed(42) @@ -35,6 +38,29 @@ def test_simo_trig(): compare_results(casadi_f, jax_f, x_val) +def test_all_zeros(): + X = ca.SX.sym("x", 2) + A = np.zeros((2, 2)) + Y = ca.jacobian(A @ X, X) + + casadi_f = ca.Function("foo", [X], [Y]) + jax_f = convert(casadi_f) + x_val = np.random.randn(2, 1) + compare_results(casadi_f, jax_f, x_val) + + +def test_structural_zeros(): + X = ca.SX.sym("x", 2) + A = np.ones((2, 2)) + A[1, :] = 0.0 + Y = ca.jacobian(A @ X, X) + + casadi_f = ca.Function("foo", [X], [ca.densify(Y)]) + jax_f = convert(casadi_f, translate=expand_translate) + x_val = np.random.randn(2, 1) + compare_results(casadi_f, jax_f, x_val) + + def test_simo_poly(): x = ca.SX.sym("x", 1, 1) casadi_f = ca.Function("simo_poly", [x], [x**2, x**3, ca.sqrt(x)])