diff --git a/strawberry_django/relay.py b/strawberry_django/relay.py index 19649ca4..9309299c 100644 --- a/strawberry_django/relay.py +++ b/strawberry_django/relay.py @@ -188,11 +188,15 @@ def resolve_connection_from_cache( for node in result ] has_previous_page = ( - nodes[0]._strawberry_row_number > 1 # type: ignore + result[0]._strawberry_row_number > 1 # type: ignore + if result + else False + ) + has_next_page = ( + result[-1]._strawberry_row_number < result[-1]._strawberry_total_count # type: ignore if result else False ) - has_next_page = result._strawberry_row_number < result if result else False return cls( edges=edges, diff --git a/tests/relay/test_nested_pagination.py b/tests/relay/test_nested_pagination.py new file mode 100644 index 00000000..94fb8c07 --- /dev/null +++ b/tests/relay/test_nested_pagination.py @@ -0,0 +1,62 @@ +import pytest +from strawberry.relay import to_base64 +from strawberry.relay.types import PREFIX + +from strawberry_django.optimizer import DjangoOptimizerExtension +from tests import utils +from tests.projects.faker import IssueFactory, MilestoneFactory + + +@pytest.mark.django_db(transaction=True) +def test_nested_pagination(gql_client: utils.GraphQLTestClient): + # Nested pagination with the same arguments for the parent and child connections + query = """ + query testNestedConnectionPagination($first: Int, $after: String) { + milestoneConn(first: $first, after: $after) { + edges { + node { + id + issuesWithFilters(first: $first, after: $after) { + edges { + node { + id + } + } + } + } + } + } + } + """ + + # Create 4 milestones, each with 4 issues + nested_data = { + milestone: IssueFactory.create_batch(4, milestone=milestone) + for milestone in MilestoneFactory.create_batch(4) + } + + # Run the nested pagination query + # We expect only 2 database queries if the optimizer is enabled, otherwise 3 (N+1) + with utils.assert_num_queries(2 if DjangoOptimizerExtension.enabled.get() else 3): + result = gql_client.query(query, {"first": 2, "after": to_base64(PREFIX, 0)}) + + # We expect the 2nd and 3rd milestones each with their 2nd and 3rd issues + assert not result.errors + assert result.data == { + "milestoneConn": { + "edges": [ + { + "node": { + "id": to_base64("MilestoneType", milestone.id), + "issuesWithFilters": { + "edges": [ + {"node": {"id": to_base64("IssueType", issue.id)}} + for issue in issues[1:3] + ] + }, + } + } + for milestone, issues in list(nested_data.items())[1:3] + ] + } + }