From ed00ac4cfe7fc436dcf19e2709e2a5e7322f848d Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 22 May 2024 10:46:18 +0200 Subject: [PATCH 1/4] center scale --- docs/notebooks | 2 +- src/moscot/base/problems/problem.py | 2 +- tests/problems/space/test_alignment_problem.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/notebooks b/docs/notebooks index 47ec2fb98..620d8f2fe 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 47ec2fb986a7781ff2b43ecb90120c9a4c46df75 +Subproject commit 620d8f2feb7fb724aef335aec63ec79b09e7e1dd diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index b249a46f6..c2e4311bf 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -630,7 +630,7 @@ def _spatial_norm_callback( spatial = TaggedArray._extract_data(adata, attr=attr, key=key) logger.info(f"Normalizing spatial coordinates of `{term}`.") - spatial = (spatial - spatial.mean()) / spatial.std() + spatial = (spatial - spatial.mean(axis=0)) / spatial.std() return TaggedArray(spatial, tag=Tag.POINT_CLOUD) @staticmethod diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index d6a78c063..9eb393227 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -51,8 +51,7 @@ def test_prepare_sequential( ap = ap.prepare(batch_key="batch", joint_attr=joint_attr, normalize_spatial=normalize_spatial) assert len(ap) == 2 if normalize_spatial: - np.testing.assert_allclose(ap[("1", "2")].x.data_src.std(), ap[("0", "1")].x.data_src.std(), atol=1e-15) - np.testing.assert_allclose(ap[("1", "2")].x.data_src.std(), 1.0, atol=1e-15) + np.testing.assert_allclose(ap[("1", "2")].x.data_src.std(), ap[("0", "1")].y.data_src.std(), atol=1e-15) np.testing.assert_allclose(ap[("1", "2")].x.data_src.mean(), 0, atol=1e-15) np.testing.assert_allclose(ap[("0", "1")].x.data_src.mean(), 0, atol=1e-15) From fe37d7b74b6ed9934d62602cd2d6d273a67635b3 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 17 Jun 2024 17:12:52 +0200 Subject: [PATCH 2/4] update --- docs/notebooks | 2 +- tests/problems/space/test_alignment_problem.py | 1 - tests/problems/space/test_mapping_problem.py | 1 - .../spatio_temporal/test_spatio_temporal_problem.py | 6 +++--- tests/problems/time/test_temporal_problem.py | 4 ++-- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/docs/notebooks b/docs/notebooks index 620d8f2fe..c2e091c4d 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 620d8f2feb7fb724aef335aec63ec79b09e7e1dd +Subproject commit c2e091c4d6c754d6aedcc7a2fe4cb69262f49918 diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 9eb393227..9bed13662 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -74,7 +74,6 @@ def test_prepare_star(self, adata_space_rotate: AnnData, reference: str): assert ref == reference assert isinstance(ap[prob_key], ap._base_problem_type) - @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), [(1, 0.9, -1, None), (1, 0.5, 10, "random"), (1, 0.5, 10, "rank2"), (0.1, 0.1, -1, None)], diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 903c02611..a824f035d 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -94,7 +94,6 @@ def test_prepare_varnames(self, adata_mapping: AnnData, var_names: Optional[List assert prob.x.data_src.shape == (n_obs, x_n_var) assert prob.y.data_src.shape == (n_obs, y_n_var) - @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), [(1e-2, 0.9, -1, None), (2, 0.5, 10, "random"), (2, 0.5, 10, "rank2"), (2, 0.1, -1, None)], diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index a22d5cd35..0ff9bc011 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -60,7 +60,7 @@ def test_solve_balanced(self, adata_spatio_temporal: AnnData): assert isinstance(subsol, BaseSolverOutput) assert key in expected_keys - @pytest.mark.skip(reason="unbalanced does not work yet") + # @pytest.mark.skip(reason="unbalanced does not work yet") def test_solve_unbalanced(self, adata_spatio_temporal: AnnData): taus = [9e-1, 1e-2] problem1 = SpatioTemporalProblem(adata=adata_spatio_temporal) @@ -76,8 +76,8 @@ def test_solve_unbalanced(self, adata_spatio_temporal: AnnData): assert problem2[0, 1].a is not None assert problem2[0, 1].b is not None - problem1 = problem1.solve(epsilon=1, tau_a=taus[0], tau_b=taus[0], max_iterations=10000) - problem2 = problem2.solve(epsilon=1, tau_a=taus[1], tau_b=taus[1], max_iterations=10000) + problem1 = problem1.solve(epsilon=1, tau_a=taus[0], tau_b=taus[0]) + problem2 = problem2.solve(epsilon=1, tau_a=taus[1], tau_b=taus[1]) assert problem1[0, 1].solution.a is not None assert problem1[0, 1].solution.b is not None diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 7858eb613..58283edc9 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -78,8 +78,8 @@ def test_solve_unbalanced(self, adata_time: AnnData): assert problem2[0, 1].a is not None assert problem2[0, 1].b is not None - problem1 = problem1.solve(epsilon=1, tau_a=taus[0], tau_b=taus[0], max_iterations=10000) - problem2 = problem2.solve(epsilon=1, tau_a=taus[1], tau_b=taus[1], max_iterations=10000) + problem1 = problem1.solve(epsilon=1, tau_a=taus[0], tau_b=taus[0]) + problem2 = problem2.solve(epsilon=1, tau_a=taus[1], tau_b=taus[1]) assert problem1[0, 1].solution.a is not None assert problem1[0, 1].solution.b is not None From fd1be932974180572f7644f0821ef5f876c0e6a6 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 17 Jun 2024 17:15:13 +0200 Subject: [PATCH 3/4] remove comment --- tests/problems/spatio_temporal/test_spatio_temporal_problem.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index 0ff9bc011..b95853863 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -60,7 +60,6 @@ def test_solve_balanced(self, adata_spatio_temporal: AnnData): assert isinstance(subsol, BaseSolverOutput) assert key in expected_keys - # @pytest.mark.skip(reason="unbalanced does not work yet") def test_solve_unbalanced(self, adata_spatio_temporal: AnnData): taus = [9e-1, 1e-2] problem1 = SpatioTemporalProblem(adata=adata_spatio_temporal) From 516eb9b8fb446d2480ea68da8c141148b0c53685 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 17 Jun 2024 17:19:31 +0200 Subject: [PATCH 4/4] update --- tests/problems/space/test_alignment_problem.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 9bed13662..769e2c60e 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -100,8 +100,7 @@ def test_solve_balanced( ) for prob_key in ap: assert ap[prob_key].solution.rank == rank - if initializer != "random": # TODO: is this valid? - assert ap[prob_key].solution.converged + assert ap[prob_key].solution.converged # TODO(michalk8): use np.testing assert np.allclose(*(sol.cost for sol in ap.solutions.values()))