Skip to content

Commit

Permalink
fix(optimizer): Fix nested pagination optimization for m2m relations
Browse files Browse the repository at this point in the history
When annotating something from the relation into the prefetch queryset
for a m2m relation, Django will mistakenly not reuse the existing join
and end up resulting in the generation of spurious results.

There's an ongoing fix for this i this ticket: https://code.djangoproject.com/ticket/35677

This is monkey patching older versions of Django which doesn't contain
the fix, and most likely won't (Django usually only backports
security issues), to fix the issue.

Thanks @SupImDos for providing an MRE in the form of a test for this!

Fix #650
  • Loading branch information
bellini666 committed Dec 25, 2024
1 parent 48b54fd commit 49a8908
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 3 deletions.
5 changes: 5 additions & 0 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,11 @@ def __init__(
self.enable_nested_relations_prefetch = enable_nested_relations_prefetch
self.prefetch_custom_queryset = prefetch_custom_queryset

if enable_nested_relations_prefetch:
from strawberry_django.utils.patches import apply_pagination_fix

apply_pagination_fix()

def on_execute(self) -> Generator[None]:
token = optimizer.set(self)
try:
Expand Down
1 change: 1 addition & 0 deletions strawberry_django/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def apply_window_pagination(
).get_order_by()
]

existing_aliases = set(queryset.query.alias_map)
queryset = queryset.annotate(
_strawberry_row_number=_PaginationWindow(
RowNumber(),
Expand Down
78 changes: 78 additions & 0 deletions strawberry_django/utils/patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import django
from django.db import (
DEFAULT_DB_ALIAS,
NotSupportedError,
connections,
)
from django.db.models import Q, Window
from django.db.models.fields import related_descriptors
from django.db.models.functions import RowNumber
from django.db.models.lookups import GreaterThan, LessThanOrEqual
from django.db.models.sql import Query
from django.db.models.sql.constants import INNER
from django.db.models.sql.where import AND


def apply_pagination_fix():
"""Apply pagination fix for Django 5.1 or older.
This is based on the fix in this patch, which is going to be included in Django 5.2:
https://code.djangoproject.com/ticket/35677#comment:9
If can safely be removed when Django 5.2 is the minimum version we support
"""
if django.VERSION >= (5, 2):
return

# This is a copy of the function, exactly as it exists on Django 4.2, 5.0 and 5.1
# (there are no differences in this function between these versions)
def _filter_prefetch_queryset(queryset, field_name, instances):
predicate = Q(**{f"{field_name}__in": instances})
db = queryset._db or DEFAULT_DB_ALIAS
if queryset.query.is_sliced:
if not connections[db].features.supports_over_clause:
raise NotSupportedError(
"Prefetching from a limited queryset is only supported on backends "
"that support window functions."
)
low_mark, high_mark = queryset.query.low_mark, queryset.query.high_mark
order_by = [
expr for expr, _ in queryset.query.get_compiler(using=db).get_order_by()
]
window = Window(RowNumber(), partition_by=field_name, order_by=order_by)
predicate &= GreaterThan(window, low_mark)
if high_mark is not None:
predicate &= LessThanOrEqual(window, high_mark)
queryset.query.clear_limits()

# >> ORIGINAL CODE
# return queryset.filter(predicate) # noqa: ERA001
# << ORIGINAL CODE
# >> PATCHED CODE
queryset.query.add_q(predicate, reuse_all_aliases=True)
return queryset
# << PATCHED CODE

related_descriptors._filter_prefetch_queryset = _filter_prefetch_queryset # type: ignore

# This is a copy of the function, exactly as it exists on Django 4.2, 5.0 and 5.1
# (there are no differences in this function between these versions)
def add_q(self, q_object, reuse_all_aliases=False):
existing_inner = {
a for a in self.alias_map if self.alias_map[a].join_type == INNER
}
# >> ORIGINAL CODE
# clause, _ = self._add_q(q_object, self.used_aliases) # noqa: ERA001
# << ORIGINAL CODE
# >> PATCHED CODE
if reuse_all_aliases: # noqa: SIM108
can_reuse = set(self.alias_map)
else:
can_reuse = self.used_aliases
clause, _ = self._add_q(q_object, can_reuse)
# << PATCHED CODE
if clause:
self.where.add(clause, AND)
self.demote_joins(existing_inner)

Query.add_q = add_q
10 changes: 8 additions & 2 deletions tests/projects/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,18 @@ class Meta:


class Issue(NamedModel):
comments: "RelatedManager[Issue]"
issue_assignees: "RelatedManager[Assignee]"
class Meta: # type: ignore
ordering = ("id",)

class Kind(models.TextChoices):
"""Issue kind options."""

BUG = "b", "Bug"
FEATURE = "f", "Feature"

comments: "RelatedManager[Issue]"
issue_assignees: "RelatedManager[Assignee]"

id = models.BigAutoField(
verbose_name="ID",
primary_key=True,
Expand Down Expand Up @@ -203,6 +206,9 @@ class Meta:


class Tag(NamedModel):
class Meta: # type: ignore
ordering = ("id",)

issues: "RelatedManager[Issue]"

id = models.BigAutoField(
Expand Down
78 changes: 77 additions & 1 deletion tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@
from strawberry.types import ExecutionResult

import strawberry_django
from strawberry_django.optimizer import DjangoOptimizerExtension
from strawberry_django.pagination import (
OffsetPaginationInput,
apply,
apply_window_pagination,
)
from tests import models, utils
from tests.projects.faker import MilestoneFactory, ProjectFactory
from tests.projects.faker import (
IssueFactory,
MilestoneFactory,
ProjectFactory,
TagFactory,
)


@strawberry_django.type(models.Fruit, pagination=True)
Expand Down Expand Up @@ -145,3 +151,73 @@ def test_apply_window_pagination_with_no_limites(limit):
assert first_fruit.name == "fruit2"
assert first_fruit._strawberry_row_number == 3 # type: ignore
assert first_fruit._strawberry_total_count == 10 # type: ignore


@pytest.mark.django_db(transaction=True)
def test_nested_pagination_m2m(gql_client: utils.GraphQLTestClient):
# Create 2 tags and 3 issues
tags = [TagFactory(name=f"Tag {i + 1}") for i in range(2)]
issues = [IssueFactory(name=f"Issue {i + 1}") for i in range(3)]
# Assign issues 1 and 2 to the 1st tag
# Assign issues 2 and 3 to the 2nd tag
# This means that both tags share the 2nd issue
tags[0].issues.set(issues[:2])
tags[1].issues.set(issues[1:])
# Query the tags with their issues
# We expect only 2 database queries if the optimizer is enabled, otherwise 3 (N+1)
with utils.assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 6):
result = gql_client.query(
"""
query {
tagConn {
totalCount
edges {
node {
name
issues {
totalCount
edges {
node {
name
}
}
}
}
}
}
}
"""
)
# Check the results
assert not result.errors
assert result.data == {
"tagConn": {
"totalCount": 2,
"edges": [
{
"node": {
"name": "Tag 1",
"issues": {
"totalCount": 2,
"edges": [
{"node": {"name": "Issue 1"}},
{"node": {"name": "Issue 2"}},
],
},
}
},
{
"node": {
"name": "Tag 2",
"issues": {
"totalCount": 2,
"edges": [
{"node": {"name": "Issue 2"}},
{"node": {"name": "Issue 3"}},
],
},
}
},
],
}
}

0 comments on commit 49a8908

Please sign in to comment.