diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index eb0a340de..63fde8167 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -635,7 +635,7 @@ def _spatial_norm_callback( spatial = TaggedArray._extract_data(adata, attr=attr, key=key) logger.info(f"Normalizing spatial coordinates of `{term}`.") - spatial = (spatial - spatial.mean()) / spatial.std() + spatial = (spatial - spatial.mean(axis=0)) / spatial.std() return TaggedArray(spatial, tag=Tag.POINT_CLOUD) @staticmethod diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 0ef586bc0..fe3d9824a 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -51,8 +51,7 @@ def test_prepare_sequential( ap = ap.prepare(batch_key="batch", joint_attr=joint_attr, normalize_spatial=normalize_spatial) assert len(ap) == 2 if normalize_spatial: - np.testing.assert_allclose(ap[("1", "2")].x.data_src.std(), ap[("0", "1")].x.data_src.std(), atol=1e-15) - np.testing.assert_allclose(ap[("1", "2")].x.data_src.std(), 1.0, atol=1e-15) + np.testing.assert_allclose(ap[("1", "2")].x.data_src.std(), ap[("0", "1")].y.data_src.std(), atol=1e-15) np.testing.assert_allclose(ap[("1", "2")].x.data_src.mean(), 0, atol=1e-15) np.testing.assert_allclose(ap[("0", "1")].x.data_src.mean(), 0, atol=1e-15) @@ -75,7 +74,6 @@ def test_prepare_star(self, adata_space_rotate: AnnData, reference: str): assert ref == reference assert isinstance(ap[prob_key], ap._base_problem_type) - @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), [(1, 0.9, -1, None), (1, 0.5, 10, "random"), (1, 0.5, 10, "rank2"), (0.1, 0.1, -1, None)], @@ -102,8 +100,7 @@ def test_solve_balanced( ) for prob_key in ap: assert ap[prob_key].solution.rank == rank - if initializer != "random": # TODO: is this valid? - assert ap[prob_key].solution.converged + assert ap[prob_key].solution.converged # TODO(michalk8): use np.testing assert np.allclose(*(sol.cost for sol in ap.solutions.values())) diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 485ea52ff..e0a4e99ed 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -96,7 +96,6 @@ def test_prepare_varnames(self, adata_mapping: AnnData, var_names: Optional[List assert prob.x.data_src.shape == (n_obs, x_n_var) assert prob.y.data_src.shape == (n_obs, y_n_var) - @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), [(1e-2, 0.9, -1, None), (2, 0.5, 10, "random"), (2, 0.5, 10, "rank2"), (2, 0.1, -1, None)], diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index a22d5cd35..b95853863 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -60,7 +60,6 @@ def test_solve_balanced(self, adata_spatio_temporal: AnnData): assert isinstance(subsol, BaseSolverOutput) assert key in expected_keys - @pytest.mark.skip(reason="unbalanced does not work yet") def test_solve_unbalanced(self, adata_spatio_temporal: AnnData): taus = [9e-1, 1e-2] problem1 = SpatioTemporalProblem(adata=adata_spatio_temporal) @@ -76,8 +75,8 @@ def test_solve_unbalanced(self, adata_spatio_temporal: AnnData): assert problem2[0, 1].a is not None assert problem2[0, 1].b is not None - problem1 = problem1.solve(epsilon=1, tau_a=taus[0], tau_b=taus[0], max_iterations=10000) - problem2 = problem2.solve(epsilon=1, tau_a=taus[1], tau_b=taus[1], max_iterations=10000) + problem1 = problem1.solve(epsilon=1, tau_a=taus[0], tau_b=taus[0]) + problem2 = problem2.solve(epsilon=1, tau_a=taus[1], tau_b=taus[1]) assert problem1[0, 1].solution.a is not None assert problem1[0, 1].solution.b is not None diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 43923f285..01a0067c8 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -78,8 +78,8 @@ def test_solve_unbalanced(self, adata_time: AnnData): assert problem2[0, 1].a is not None assert problem2[0, 1].b is not None - problem1 = problem1.solve(epsilon=1, tau_a=taus[0], tau_b=taus[0], max_iterations=10000) - problem2 = problem2.solve(epsilon=1, tau_a=taus[1], tau_b=taus[1], max_iterations=10000) + problem1 = problem1.solve(epsilon=1, tau_a=taus[0], tau_b=taus[0]) + problem2 = problem2.solve(epsilon=1, tau_a=taus[1], tau_b=taus[1]) assert problem1[0, 1].solution.a is not None assert problem1[0, 1].solution.b is not None