From 58d9df21d1edf030a3745fbdb9580ea89d55a7c7 Mon Sep 17 00:00:00 2001 From: Phil Starkey Date: Tue, 10 Dec 2024 11:37:54 +1100 Subject: [PATCH 1/8] Add support for nested creation/update in mutations. This also has the benefit of consistently calling `full_clean()` before creating related instances. This does remove the `get_or_create()` calls and instead uses `create` only. The expectation here is that `key_attr` could and should be used to indicate what field should be used as the unique identifier, and not something hard coded that could have unintended side effects when creating related instances that don't have unique constraints and expect new instances to always be created. --- strawberry_django/mutations/resolvers.py | 90 ++++++++++++++++++++---- 1 file changed, 75 insertions(+), 15 deletions(-) diff --git a/strawberry_django/mutations/resolvers.py b/strawberry_django/mutations/resolvers.py index 50c76677..521ffc12 100644 --- a/strawberry_django/mutations/resolvers.py +++ b/strawberry_django/mutations/resolvers.py @@ -14,6 +14,7 @@ import strawberry from django.db import models, transaction +from django.db.models.manager import BaseManager from django.db.models.base import Model from django.db.models.fields.related import ManyToManyField from django.db.models.fields.reverse_related import ( @@ -222,6 +223,7 @@ def prepare_create_update( data: dict[str, Any], key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, + exclude_m2m: list[str] | None = None, ) -> tuple[ Model, dict[str, object], @@ -237,6 +239,7 @@ def prepare_create_update( fields = get_model_fields(model) m2m: list[tuple[ManyToManyField | ForeignObjectRel, Any]] = [] direct_field_values: dict[str, object] = {} + exclude_m2m = exclude_m2m or [] if dataclasses.is_dataclass(data): data = vars(data) @@ -256,6 +259,8 @@ def prepare_create_update( # (but only if the instance is already saved and we are updating it) value = False # noqa: PLW2901 elif isinstance(field, (ManyToManyField, ForeignObjectRel)): + if name in exclude_m2m: + continue # m2m will be processed later m2m.append((field, value)) direct_field_value = False @@ -309,6 +314,7 @@ def create( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> _M: ... @@ -321,6 +327,7 @@ def create( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> list[_M]: ... @@ -333,12 +340,43 @@ def create( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, +) -> list[_M] | _M: + return _create( + info, + model._default_manager, + data, + key_attr=key_attr, + full_clean=full_clean, + pre_save_hook=pre_save_hook, + exclude_m2m=exclude_m2m, + ) + + +@transaction.atomic +def _create( + info: Info, + manager: BaseManager, + data: dict[str, Any] | list[dict[str, Any]], + *, + key_attr: str | None = None, + full_clean: bool | FullCleanOptions = True, + pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> list[_M] | _M: + model = manager.model # Before creating your instance, verify this is not a bulk create # if so, add them one by one. Otherwise, get to work. if isinstance(data, list): return [ - create(info, model, d, key_attr=key_attr, full_clean=full_clean) + create( + info, + model, + d, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=exclude_m2m, + ) for d in data ] @@ -365,6 +403,7 @@ def create( data=data, full_clean=full_clean, key_attr=key_attr, + exclude_m2m=exclude_m2m, ) # Creating the instance directly via create() without full-clean will @@ -376,7 +415,7 @@ def create( # Create the instance using the manager create method to respect # manager create overrides. This also ensures support for proxy-models. - instance = model._default_manager.create(**create_kwargs) + instance = manager.create(**create_kwargs) for field, value in m2m: update_m2m(info, instance, field, value, key_attr) @@ -393,6 +432,7 @@ def update( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> _M: ... @@ -405,6 +445,7 @@ def update( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> list[_M]: ... @@ -417,6 +458,7 @@ def update( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> _M | list[_M]: # Unwrap lazy objects since they have a proxy __iter__ method that will make # them iterables even if the wrapped object isn't @@ -433,6 +475,7 @@ def update( key_attr=key_attr, full_clean=full_clean, pre_save_hook=pre_save_hook, + exclude_m2m=exclude_m2m, ) for instance in instances ] @@ -443,6 +486,7 @@ def update( data=data, key_attr=key_attr, full_clean=full_clean, + exclude_m2m=exclude_m2m, ) if pre_save_hook is not None: @@ -554,15 +598,22 @@ def update_m2m( use_remove = True if isinstance(field, ManyToManyField): manager = cast("RelatedManager", getattr(instance, field.attname)) + reverse_field_name = field.remote_field.related_name # type: ignore else: assert isinstance(field, (ManyToManyRel, ManyToOneRel)) accessor_name = field.get_accessor_name() + reverse_field_name = field.field.name assert accessor_name manager = cast("RelatedManager", getattr(instance, accessor_name)) if field.one_to_many: # remove if field is nullable, otherwise delete use_remove = field.remote_field.null is True + # Create a data dict containing the reference to the instance and exclude it from + # nested m2m creation (to break circular references) + ref_instance_data = {reverse_field_name: instance} + exclude_m2m = [reverse_field_name] + to_add = [] to_remove = [] to_delete = [] @@ -621,14 +672,17 @@ def update_m2m( existing.discard(obj) else: - if key_attr not in data: # we have a Input Type - obj, _ = manager.get_or_create(**data) - else: - data.pop(key_attr) - obj = manager.create(**data) - - if full_clean: - obj.full_clean(**full_clean_options) + # If we've reached here, the key_attr should be UNSET or missing. So + # let's remove it if it is there. + data.pop(key_attr, None) + obj = _create( + info, + manager, + data | ref_instance_data, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=exclude_m2m, + ) existing.discard(obj) for remaining in existing: @@ -656,11 +710,17 @@ def update_m2m( data.pop(key_attr, None) to_add.append(obj) elif data: - if key_attr not in data: - manager.get_or_create(**data) - else: - data.pop(key_attr) - manager.create(**data) + # If we've reached here, the key_attr should be UNSET or missing. So + # let's remove it if it is there. + data.pop(key_attr, None) + _create( + info, + manager, + data | ref_instance_data, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=exclude_m2m, + ) else: raise AssertionError From 21e86d0dc2dbd819f89c7b846f79f3bfddb112f2 Mon Sep 17 00:00:00 2001 From: Phil Starkey Date: Tue, 10 Dec 2024 11:37:54 +1100 Subject: [PATCH 2/8] Formatting --- strawberry_django/mutations/resolvers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/strawberry_django/mutations/resolvers.py b/strawberry_django/mutations/resolvers.py index 521ffc12..20375ced 100644 --- a/strawberry_django/mutations/resolvers.py +++ b/strawberry_django/mutations/resolvers.py @@ -14,7 +14,6 @@ import strawberry from django.db import models, transaction -from django.db.models.manager import BaseManager from django.db.models.base import Model from django.db.models.fields.related import ManyToManyField from django.db.models.fields.reverse_related import ( @@ -45,7 +44,11 @@ ) if TYPE_CHECKING: - from django.db.models.manager import ManyToManyRelatedManager, RelatedManager + from django.db.models.manager import ( + BaseManager, + ManyToManyRelatedManager, + RelatedManager, + ) from strawberry.types.info import Info From 7fbd72344698989f9593c9f77efd9d8c27ce1f04 Mon Sep 17 00:00:00 2001 From: Phil Starkey Date: Tue, 10 Dec 2024 11:37:54 +1100 Subject: [PATCH 3/8] First test (heavily based on one from an existing PR) --- tests/projects/schema.py | 6 + tests/projects/snapshots/schema.gql | 5 + .../snapshots/schema_with_inheritance.gql | 5 + tests/test_input_mutations.py | 136 ++++++++++++++++++ 4 files changed, 152 insertions(+) diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 669ab286..27408d1b 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -335,6 +335,11 @@ class MilestoneIssueInput: name: strawberry.auto +@strawberry_django.partial(Issue) +class MilestoneIssueInputPartial: + name: strawberry.auto + + @strawberry_django.partial(Project) class ProjectInputPartial(NodeInputPartial): name: strawberry.auto @@ -351,6 +356,7 @@ class MilestoneInput: @strawberry_django.partial(Milestone) class MilestoneInputPartial(NodeInputPartial): name: strawberry.auto + issues: Optional[list[MilestoneIssueInputPartial]] @strawberry.type diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index b5f31d33..a883126b 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -365,12 +365,17 @@ input MilestoneInput { input MilestoneInputPartial { id: GlobalID name: String + issues: [MilestoneIssueInputPartial!] } input MilestoneIssueInput { name: String! } +input MilestoneIssueInputPartial { + name: String +} + input MilestoneOrder { name: Ordering project: ProjectOrder diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 352c8e1b..f344defc 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -159,6 +159,11 @@ input MilestoneFilter { input MilestoneInputPartial { id: GlobalID name: String + issues: [MilestoneIssueInputPartial!] +} + +input MilestoneIssueInputPartial { + name: String } input MilestoneOrder { diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index f48d54e3..dbfe3837 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -377,6 +377,142 @@ def test_input_create_with_m2m_mutation(db, gql_client: GraphQLTestClient): } +@pytest.mark.django_db(transaction=True) +def test_input_update_mutation_with_multiple_level_nested_creation( + db, gql_client: GraphQLTestClient +): + query = """ + mutation UpdateProject ($input: ProjectInputPartial!) { + updateProject (input: $input) { + __typename + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ProjectType { + id + name + milestones { + id + name + issues { + id + name + } + } + } + } + } + """ + + project = ProjectFactory.create(name="Some Project") + + res = gql_client.query( + query, + { + "input": { + "id": to_base64("ProjectType", project.pk), + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + }, + { + "name": "Another Issue", + }, + { + "name": "Third issue", + }, + ], + }, + ], + }, + }, + ) + + assert res.data + assert isinstance(res.data["updateProject"], dict) + + project_typename, project_pk = from_base64(res.data["updateProject"].pop("id")) + assert project_typename == "ProjectType" + assert project.pk == int(project_pk) + + milestones = Milestone.objects.all() + assert len(milestones) == 2 + assert len(res.data["updateProject"]["milestones"]) == 2 + + some_milestone = res.data["updateProject"]["milestones"][0] + milestone_typename, milestone_pk = from_base64(some_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[0] == Milestone.objects.get(pk=milestone_pk) + + another_milestone = res.data["updateProject"]["milestones"][1] + milestone_typename, milestone_pk = from_base64(another_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[1] == Milestone.objects.get(pk=milestone_pk) + + issues = Issue.objects.all() + assert len(issues) == 4 + assert len(some_milestone["issues"]) == 1 + assert len(another_milestone["issues"]) == 3 + + # Issues for first milestone + fetched_issue = some_milestone["issues"][0] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[0] == Issue.objects.get(pk=issue_pk) + # Issues for second milestone + for i in range(3): + fetched_issue = another_milestone["issues"][i] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[i + 1] == Issue.objects.get(pk=issue_pk) + + assert res.data == { + "updateProject": { + "__typename": "ProjectType", + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + }, + { + "name": "Another Issue", + }, + { + "name": "Third issue", + }, + ], + }, + ], + }, + } + + @pytest.mark.django_db(transaction=True) def test_input_update_mutation(db, gql_client: GraphQLTestClient): query = """ From 05412b8c3093898e3113e8004e4bb06d5dc7471e Mon Sep 17 00:00:00 2001 From: Phil Starkey Date: Tue, 10 Dec 2024 11:37:54 +1100 Subject: [PATCH 4/8] Update new test with m2m creation/use --- tests/projects/schema.py | 1 + tests/projects/snapshots/schema.gql | 1 + .../snapshots/schema_with_inheritance.gql | 6 +++ tests/test_input_mutations.py | 51 ++++++++++++++++++- 4 files changed, 58 insertions(+), 1 deletion(-) diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 27408d1b..68654edb 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -338,6 +338,7 @@ class MilestoneIssueInput: @strawberry_django.partial(Issue) class MilestoneIssueInputPartial: name: strawberry.auto + tags: Optional[list[TagInputPartial]] @strawberry_django.partial(Project) diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index a883126b..ad6fd007 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -374,6 +374,7 @@ input MilestoneIssueInput { input MilestoneIssueInputPartial { name: String + tags: [TagInputPartial!] } input MilestoneOrder { diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index f344defc..30611fef 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -164,6 +164,7 @@ input MilestoneInputPartial { input MilestoneIssueInputPartial { name: String + tags: [TagInputPartial!] } input MilestoneOrder { @@ -406,6 +407,11 @@ input StrFilterLookup { iRegex: String } +input TagInputPartial { + id: GlobalID + name: String +} + type TagType implements Node & Named { """The Globally Unique ID of this object""" id: GlobalID! diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index dbfe3837..d9318629 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -10,7 +10,7 @@ TagFactory, UserFactory, ) -from .projects.models import Issue, Milestone, Project +from .projects.models import Issue, Milestone, Project, Tag @pytest.mark.django_db(transaction=True) @@ -401,6 +401,9 @@ def test_input_update_mutation_with_multiple_level_nested_creation( issues { id name + tags { + name + } } } } @@ -410,6 +413,9 @@ def test_input_update_mutation_with_multiple_level_nested_creation( project = ProjectFactory.create(name="Some Project") + shared_tag = TagFactory.create(name="Shared Tag") + shared_tag_id = to_base64("TagType", shared_tag.pk) + res = gql_client.query( query, { @@ -421,6 +427,12 @@ def test_input_update_mutation_with_multiple_level_nested_creation( "issues": [ { "name": "Some Issue", + "tags": [ + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + {"id": shared_tag_id}, + ], } ], }, @@ -429,12 +441,24 @@ def test_input_update_mutation_with_multiple_level_nested_creation( "issues": [ { "name": "Some Issue", + "tags": [ + {"name": "Tag 4"}, + {"id": shared_tag_id}, + ], }, { "name": "Another Issue", + "tags": [ + {"name": "Tag 5"}, + {"id": shared_tag_id}, + ], }, { "name": "Third issue", + "tags": [ + {"name": "Tag 6"}, + {"id": shared_tag_id}, + ], }, ], }, @@ -481,6 +505,13 @@ def test_input_update_mutation_with_multiple_level_nested_creation( assert issue_typename == "IssueType" assert issues[i + 1] == Issue.objects.get(pk=issue_pk) + tags = Tag.objects.all() + assert len(tags) == 7 + assert len(issues[0].tags.all()) == 4 # 3 new tags + shared tag + assert len(issues[1].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[2].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[3].tags.all()) == 2 # 1 new tag + shared tag + assert res.data == { "updateProject": { "__typename": "ProjectType", @@ -491,6 +522,12 @@ def test_input_update_mutation_with_multiple_level_nested_creation( "issues": [ { "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + ], } ], }, @@ -499,12 +536,24 @@ def test_input_update_mutation_with_multiple_level_nested_creation( "issues": [ { "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 4"}, + ], }, { "name": "Another Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 5"}, + ], }, { "name": "Third issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 6"}, + ], }, ], }, From e16d70ecaf0b0135e8ebfef82ef27b9b2ba09ade Mon Sep 17 00:00:00 2001 From: Phil Starkey Date: Tue, 10 Dec 2024 11:37:54 +1100 Subject: [PATCH 5/8] Add test for nested creation when creating a new resource --- tests/projects/schema.py | 5 + tests/projects/snapshots/schema.gql | 3 + tests/test_input_mutations.py | 186 ++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+) diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 68654edb..2262834a 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -526,6 +526,11 @@ class Mutation: argument_name="input", key_attr="name", ) + create_project_with_milestones: ProjectType = mutations.create( + ProjectInputPartial, + handle_django_errors=True, + argument_name="input", + ) update_project: ProjectType = mutations.update( ProjectInputPartial, handle_django_errors=True, diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index ad6fd007..ef728570 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -95,6 +95,8 @@ input CreateProjectInput { union CreateProjectPayload = ProjectType | OperationInfo +union CreateProjectWithMilestonesPayload = ProjectType | OperationInfo + input CreateQuizInput { title: String! fullCleanOptions: Boolean! = false @@ -437,6 +439,7 @@ type Mutation { updateIssueWithKeyAttr(input: IssueInputPartialWithoutId!): UpdateIssueWithKeyAttrPayload! deleteIssue(input: NodeInput!): DeleteIssuePayload! deleteIssueWithKeyAttr(input: MilestoneIssueInput!): DeleteIssueWithKeyAttrPayload! + createProjectWithMilestones(input: ProjectInputPartial!): CreateProjectWithMilestonesPayload! updateProject(input: ProjectInputPartial!): UpdateProjectPayload! createMilestone(input: MilestoneInput!): CreateMilestonePayload! createProject( diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index d9318629..9eb9a89b 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -377,6 +377,192 @@ def test_input_create_with_m2m_mutation(db, gql_client: GraphQLTestClient): } +@pytest.mark.django_db(transaction=True) +def test_input_create_mutation_with_multiple_level_nested_creation( + db, gql_client: GraphQLTestClient +): + query = """ + mutation createProjectWithMilestones ($input: ProjectInputPartial!) { + createProjectWithMilestones (input: $input) { + __typename + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ProjectType { + id + name + milestones { + id + name + issues { + id + name + tags { + name + } + } + } + } + } + } + """ + + shared_tag = TagFactory.create(name="Shared Tag") + shared_tag_id = to_base64("TagType", shared_tag.pk) + + res = gql_client.query( + query, + { + "input": { + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + {"id": shared_tag_id}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 4"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Tag 5"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Tag 6"}, + {"id": shared_tag_id}, + ], + }, + ], + }, + ], + }, + }, + ) + + assert res.data + assert isinstance(res.data["createProjectWithMilestones"], dict) + + projects = Project.objects.all() + project_typename, project_pk = from_base64( + res.data["createProjectWithMilestones"].pop("id") + ) + assert project_typename == "ProjectType" + assert projects[0] == Project.objects.get(pk=project_pk) + + milestones = Milestone.objects.all() + assert len(milestones) == 2 + assert len(res.data["createProjectWithMilestones"]["milestones"]) == 2 + + some_milestone = res.data["createProjectWithMilestones"]["milestones"][0] + milestone_typename, milestone_pk = from_base64(some_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[0] == Milestone.objects.get(pk=milestone_pk) + + another_milestone = res.data["createProjectWithMilestones"]["milestones"][1] + milestone_typename, milestone_pk = from_base64(another_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[1] == Milestone.objects.get(pk=milestone_pk) + + issues = Issue.objects.all() + assert len(issues) == 4 + assert len(some_milestone["issues"]) == 1 + assert len(another_milestone["issues"]) == 3 + + # Issues for first milestone + fetched_issue = some_milestone["issues"][0] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[0] == Issue.objects.get(pk=issue_pk) + # Issues for second milestone + for i in range(3): + fetched_issue = another_milestone["issues"][i] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[i + 1] == Issue.objects.get(pk=issue_pk) + + tags = Tag.objects.all() + assert len(tags) == 7 + assert len(issues[0].tags.all()) == 4 # 3 new tags + shared tag + assert len(issues[1].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[2].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[3].tags.all()) == 2 # 1 new tag + shared tag + + assert res.data == { + "createProjectWithMilestones": { + "__typename": "ProjectType", + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 4"}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 5"}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 6"}, + ], + }, + ], + }, + ], + }, + } + + @pytest.mark.django_db(transaction=True) def test_input_update_mutation_with_multiple_level_nested_creation( db, gql_client: GraphQLTestClient From ecee13673998f4760df575cf560100b9e4be434f Mon Sep 17 00:00:00 2001 From: Phil Starkey Date: Tue, 10 Dec 2024 11:37:54 +1100 Subject: [PATCH 6/8] Add test for full_clean being called when performing nested creation or resources --- tests/test_input_mutations.py | 105 ++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index 9eb9a89b..ef35057e 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -1,4 +1,7 @@ +from unittest.mock import patch + import pytest +from django.core.exceptions import ValidationError from strawberry.relay import from_base64, to_base64 from tests.utils import GraphQLTestClient, assert_num_queries @@ -748,6 +751,108 @@ def test_input_update_mutation_with_multiple_level_nested_creation( } +@pytest.mark.parametrize("mock_model", ["Milestone", "Issue", "Tag"]) +@pytest.mark.django_db(transaction=True) +def test_input_create_mutation_with_nested_calls_nested_full_clean( + db, gql_client: GraphQLTestClient, mock_model: str +): + query = """ + mutation createProjectWithMilestones ($input: ProjectInputPartial!) { + createProjectWithMilestones (input: $input) { + __typename + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ProjectType { + id + name + milestones { + id + name + issues { + id + name + tags { + name + } + } + } + } + } + } + """ + + shared_tag = TagFactory.create(name="Shared Tag") + shared_tag_id = to_base64("TagType", shared_tag.pk) + + with patch( + f"tests.projects.models.{mock_model}.clean", + side_effect=ValidationError({"name": ValidationError("Invalid name")}), + ) as mocked_full_clean: + res = gql_client.query( + query, + { + "input": { + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + {"id": shared_tag_id}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 4"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Tag 5"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Tag 6"}, + {"id": shared_tag_id}, + ], + }, + ], + }, + ], + }, + }, + ) + + assert res.data + assert isinstance(res.data["createProjectWithMilestones"], dict) + assert res.data["createProjectWithMilestones"]["__typename"] == "OperationInfo" + assert mocked_full_clean.call_count == 1 + assert res.data["createProjectWithMilestones"]["messages"] == [ + {"field": "name", "kind": "VALIDATION", "message": "Invalid name"} + ] + + @pytest.mark.django_db(transaction=True) def test_input_update_mutation(db, gql_client: GraphQLTestClient): query = """ From b73c0fbe57100af5d5ba255a1f448a769fe4f15b Mon Sep 17 00:00:00 2001 From: Phil Starkey Date: Tue, 10 Dec 2024 11:37:54 +1100 Subject: [PATCH 7/8] Remove unecessary `@transaction.atomic()` call --- strawberry_django/mutations/resolvers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/strawberry_django/mutations/resolvers.py b/strawberry_django/mutations/resolvers.py index 20375ced..bc5d8b28 100644 --- a/strawberry_django/mutations/resolvers.py +++ b/strawberry_django/mutations/resolvers.py @@ -334,7 +334,6 @@ def create( ) -> list[_M]: ... -@transaction.atomic def create( info: Info, model: type[_M], From e719aa7b59c3191c96518d9fbfb495ae4b55ad7f Mon Sep 17 00:00:00 2001 From: Phil Starkey Date: Mon, 16 Dec 2024 15:05:49 +1100 Subject: [PATCH 8/8] Add support for nested creation of ForeignKeys --- strawberry_django/mutations/resolvers.py | 43 +++++++++++--- tests/projects/schema.py | 1 + tests/projects/snapshots/schema.gql | 1 + .../snapshots/schema_with_inheritance.gql | 7 +++ tests/test_input_mutations.py | 57 +++++++++++++------ 5 files changed, 85 insertions(+), 24 deletions(-) diff --git a/strawberry_django/mutations/resolvers.py b/strawberry_django/mutations/resolvers.py index bc5d8b28..53a58cbe 100644 --- a/strawberry_django/mutations/resolvers.py +++ b/strawberry_django/mutations/resolvers.py @@ -15,6 +15,7 @@ import strawberry from django.db import models, transaction from django.db.models.base import Model +from django.db.models.fields import Field from django.db.models.fields.related import ManyToManyField from django.db.models.fields.reverse_related import ( ForeignObjectRel, @@ -92,6 +93,7 @@ def _parse_data( value: Any, *, key_attr: str | None = None, + full_clean: bool | FullCleanOptions = True, ): obj, data = _parse_pk(value, model, key_attr=key_attr) parsed_data = {} @@ -101,10 +103,21 @@ def _parse_data( continue if isinstance(v, ParsedObject): - if v.pk is None: - v = create(info, model, v.data or {}) # noqa: PLW2901 + if v.pk in {None, UNSET}: + related_field = cast("Field", get_model_fields(model).get(k)) + related_model = related_field.related_model + v = create( # noqa: PLW2901 + info, + cast("type[Model]", related_model), + v.data or {}, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=[related_field.name], + ) elif isinstance(v.pk, models.Model) and v.data: - v = update(info, v.pk, v.data, key_attr=key_attr) # noqa: PLW2901 + v = update( # noqa: PLW2901 + info, v.pk, v.data, key_attr=key_attr, full_clean=full_clean + ) else: v = v.pk # noqa: PLW2901 @@ -277,14 +290,19 @@ def prepare_create_update( cast("type[Model]", field.related_model), value, key_attr=key_attr, + full_clean=full_clean, ) if value is None and not value_data: value = None # noqa: PLW2901 # If foreign object is not found, then create it - elif value is None: - value = field.related_model._default_manager.create( # noqa: PLW2901 - **value_data, + elif value in {None, UNSET}: + value = create( # noqa: PLW2901 + info, + field.related_model, + value_data, + key_attr=key_attr, + full_clean=full_clean, ) # If foreign object does not need updating, then skip it @@ -634,7 +652,11 @@ def update_m2m( need_remove_cache = need_remove_cache or bool(values) for v in values: obj, data = _parse_data( - info, cast("type[Model]", manager.model), v, key_attr=key_attr + info, + cast("type[Model]", manager.model), + v, + key_attr=key_attr, + full_clean=full_clean, ) if obj: data.pop(key_attr, None) @@ -701,6 +723,7 @@ def update_m2m( cast("type[Model]", manager.model), v, key_attr=key_attr, + full_clean=full_clean, ) if obj and data: data.pop(key_attr, None) @@ -729,7 +752,11 @@ def update_m2m( need_remove_cache = need_remove_cache or bool(value.remove) for v in value.remove or []: obj, data = _parse_data( - info, cast("type[Model]", manager.model), v, key_attr=key_attr + info, + cast("type[Model]", manager.model), + v, + key_attr=key_attr, + full_clean=full_clean, ) data.pop(key_attr, None) assert not data diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 2262834a..f61ae7c8 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -358,6 +358,7 @@ class MilestoneInput: class MilestoneInputPartial(NodeInputPartial): name: strawberry.auto issues: Optional[list[MilestoneIssueInputPartial]] + project: Optional[ProjectInputPartial] @strawberry.type diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index ef728570..777b1c27 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -368,6 +368,7 @@ input MilestoneInputPartial { id: GlobalID name: String issues: [MilestoneIssueInputPartial!] + project: ProjectInputPartial } input MilestoneIssueInput { diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 30611fef..a99f50e4 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -160,6 +160,7 @@ input MilestoneInputPartial { id: GlobalID name: String issues: [MilestoneIssueInputPartial!] + project: ProjectInputPartial } input MilestoneIssueInputPartial { @@ -310,6 +311,12 @@ type PageInfo { endCursor: String } +input ProjectInputPartial { + id: GlobalID + name: String + milestones: [MilestoneInputPartial!] +} + input ProjectOrder { id: Ordering name: Ordering diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index ef35057e..ce405202 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -248,8 +248,8 @@ def test_input_create_mutation(db, gql_client: GraphQLTestClient): @pytest.mark.django_db(transaction=True) def test_input_create_mutation_nested_creation(db, gql_client: GraphQLTestClient): query = """ - mutation CreateMilestone ($input: MilestoneInput!) { - createMilestone (input: $input) { + mutation CreateIssue ($input: IssueInput!) { + createIssue (input: $input) { __typename ... on OperationInfo { messages { @@ -258,45 +258,70 @@ def test_input_create_mutation_nested_creation(db, gql_client: GraphQLTestClient message } } - ... on MilestoneType { + ... on IssueType { id name - project { + milestone { id name + project { + id + name + } } } } } """ assert not Project.objects.filter(name="New Project").exists() + assert not Milestone.objects.filter(name="New Milestone").exists() + assert not Issue.objects.filter(name="New Issue").exists() res = gql_client.query( query, { "input": { - "name": "Some Milestone", - "project": { - "name": "New Project", + "name": "New Issue", + "milestone": { + "name": "New Milestone", + "project": { + "name": "New Project", + }, }, }, }, ) + assert res.data - assert isinstance(res.data["createMilestone"], dict) + assert isinstance(res.data["createIssue"], dict) - typename, _pk = from_base64(res.data["createMilestone"].pop("id")) - assert typename == "MilestoneType" + typename, pk = from_base64(res.data["createIssue"].get("id")) + + assert typename == "IssueType" + issue = Issue.objects.get(pk=pk) + assert issue.name == "New Issue" + + milestone = Milestone.objects.get(name="New Milestone") + assert milestone.name == "New Milestone" project = Project.objects.get(name="New Project") + assert project.name == "New Project" + + assert milestone.project_id == project.pk + assert issue.milestone_id == milestone.pk assert res.data == { - "createMilestone": { - "__typename": "MilestoneType", - "name": "Some Milestone", - "project": { - "id": to_base64("ProjectType", project.pk), - "name": project.name, + "createIssue": { + "__typename": "IssueType", + "id": to_base64("IssueType", issue.pk), + "name": "New Issue", + "milestone": { + "id": to_base64("MilestoneType", milestone.pk), + "name": "New Milestone", + "project": { + "id": to_base64("ProjectType", project.pk), + "name": "New Project", + }, }, }, }