Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-level nested create/update with model full_clean() #659

139 changes: 114 additions & 25 deletions strawberry_django/mutations/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import strawberry
from django.db import models, transaction
from django.db.models.base import Model
from django.db.models.fields import Field
from django.db.models.fields.related import ManyToManyField
from django.db.models.fields.reverse_related import (
ForeignObjectRel,
Expand Down Expand Up @@ -44,7 +45,11 @@
)

if TYPE_CHECKING:
from django.db.models.manager import ManyToManyRelatedManager, RelatedManager
from django.db.models.manager import (
BaseManager,
ManyToManyRelatedManager,
RelatedManager,
)
from strawberry.types.info import Info


Expand Down Expand Up @@ -88,6 +93,7 @@
value: Any,
*,
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
):
obj, data = _parse_pk(value, model, key_attr=key_attr)
parsed_data = {}
Expand All @@ -97,10 +103,21 @@
continue

if isinstance(v, ParsedObject):
if v.pk is None:
v = create(info, model, v.data or {}) # noqa: PLW2901
if v.pk in {None, UNSET}:
related_field = cast("Field", get_model_fields(model).get(k))
related_model = related_field.related_model
v = create( # noqa: PLW2901
info,
cast("type[Model]", related_model),
v.data or {},
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=[related_field.name],
)
elif isinstance(v.pk, models.Model) and v.data:
v = update(info, v.pk, v.data, key_attr=key_attr) # noqa: PLW2901
v = update( # noqa: PLW2901
info, v.pk, v.data, key_attr=key_attr, full_clean=full_clean
)
else:
v = v.pk # noqa: PLW2901

Expand Down Expand Up @@ -222,6 +239,7 @@
data: dict[str, Any],
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
exclude_m2m: list[str] | None = None,
) -> tuple[
Model,
dict[str, object],
Expand All @@ -237,6 +255,7 @@
fields = get_model_fields(model)
m2m: list[tuple[ManyToManyField | ForeignObjectRel, Any]] = []
direct_field_values: dict[str, object] = {}
exclude_m2m = exclude_m2m or []

if dataclasses.is_dataclass(data):
data = vars(data)
Expand All @@ -256,6 +275,8 @@
# (but only if the instance is already saved and we are updating it)
value = False # noqa: PLW2901
elif isinstance(field, (ManyToManyField, ForeignObjectRel)):
if name in exclude_m2m:
continue
# m2m will be processed later
m2m.append((field, value))
direct_field_value = False
Expand All @@ -269,14 +290,19 @@
cast("type[Model]", field.related_model),
value,
key_attr=key_attr,
full_clean=full_clean,
)
if value is None and not value_data:
value = None # noqa: PLW2901

# If foreign object is not found, then create it
elif value is None:
value = field.related_model._default_manager.create( # noqa: PLW2901
**value_data,
elif value in {None, UNSET}:
value = create( # noqa: PLW2901
info,
field.related_model,
value_data,
key_attr=key_attr,
full_clean=full_clean,
)

# If foreign object does not need updating, then skip it
Expand All @@ -284,8 +310,8 @@
pass

else:
update(

Check failure on line 313 in strawberry_django/mutations/resolvers.py

View workflow job for this annotation

GitHub Actions / Typing

No overloads for "update" match the provided arguments (reportCallIssue)
info, value, value_data, full_clean=full_clean, key_attr=key_attr

Check failure on line 314 in strawberry_django/mutations/resolvers.py

View workflow job for this annotation

GitHub Actions / Typing

Argument of type "Any | None" cannot be assigned to parameter "instance" of type "Iterable[_M@update]" in function "update"   Type "Any | None" is not assignable to type "Iterable[_M@update]"     "None" is incompatible with protocol "Iterable[_M@update]"       "__iter__" is not present (reportArgumentType)
)

if direct_field_value:
Expand All @@ -309,6 +335,7 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> _M: ...


Expand All @@ -321,10 +348,10 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M]: ...


@transaction.atomic
def create(
info: Info,
model: type[_M],
Expand All @@ -333,12 +360,43 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M] | _M:
return _create(
info,
model._default_manager,
data,
key_attr=key_attr,
full_clean=full_clean,
pre_save_hook=pre_save_hook,
exclude_m2m=exclude_m2m,
)


@transaction.atomic
def _create(
info: Info,
manager: BaseManager,
data: dict[str, Any] | list[dict[str, Any]],
*,
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M] | _M:
model = manager.model
# Before creating your instance, verify this is not a bulk create
# if so, add them one by one. Otherwise, get to work.
if isinstance(data, list):
return [
create(info, model, d, key_attr=key_attr, full_clean=full_clean)
create(
info,
model,
d,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)
for d in data
]

Expand All @@ -365,6 +423,7 @@
data=data,
full_clean=full_clean,
key_attr=key_attr,
exclude_m2m=exclude_m2m,
)

# Creating the instance directly via create() without full-clean will
Expand All @@ -372,11 +431,11 @@
# full-clean() to trigger form-validation style error messages.
full_clean_options = full_clean if isinstance(full_clean, dict) else {}
if full_clean:
dummy_instance.full_clean(**full_clean_options) # type: ignore

Check warning on line 434 in strawberry_django/mutations/resolvers.py

View workflow job for this annotation

GitHub Actions / Typing

Unnecessary "# type: ignore" comment (reportUnnecessaryTypeIgnoreComment)

# Create the instance using the manager create method to respect
# manager create overrides. This also ensures support for proxy-models.
instance = model._default_manager.create(**create_kwargs)
instance = manager.create(**create_kwargs)

for field, value in m2m:
update_m2m(info, instance, field, value, key_attr)
Expand All @@ -393,6 +452,7 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> _M: ...


Expand All @@ -405,6 +465,7 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M]: ...


Expand All @@ -417,6 +478,7 @@
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> _M | list[_M]:
# Unwrap lazy objects since they have a proxy __iter__ method that will make
# them iterables even if the wrapped object isn't
Expand All @@ -433,6 +495,7 @@
key_attr=key_attr,
full_clean=full_clean,
pre_save_hook=pre_save_hook,
exclude_m2m=exclude_m2m,
)
for instance in instances
]
Expand All @@ -443,6 +506,7 @@
data=data,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)

if pre_save_hook is not None:
Expand Down Expand Up @@ -554,15 +618,22 @@
use_remove = True
if isinstance(field, ManyToManyField):
manager = cast("RelatedManager", getattr(instance, field.attname))
reverse_field_name = field.remote_field.related_name # type: ignore
else:
assert isinstance(field, (ManyToManyRel, ManyToOneRel))
accessor_name = field.get_accessor_name()
reverse_field_name = field.field.name
assert accessor_name
manager = cast("RelatedManager", getattr(instance, accessor_name))
if field.one_to_many:
# remove if field is nullable, otherwise delete
use_remove = field.remote_field.null is True

# Create a data dict containing the reference to the instance and exclude it from
# nested m2m creation (to break circular references)
ref_instance_data = {reverse_field_name: instance}
exclude_m2m = [reverse_field_name]

to_add = []
to_remove = []
to_delete = []
Expand All @@ -581,7 +652,11 @@
need_remove_cache = need_remove_cache or bool(values)
for v in values:
obj, data = _parse_data(
info, cast("type[Model]", manager.model), v, key_attr=key_attr
info,
cast("type[Model]", manager.model),
v,
key_attr=key_attr,
full_clean=full_clean,
)
if obj:
data.pop(key_attr, None)
Expand Down Expand Up @@ -621,14 +696,17 @@

existing.discard(obj)
else:
if key_attr not in data: # we have a Input Type
obj, _ = manager.get_or_create(**data)
else:
data.pop(key_attr)
obj = manager.create(**data)

if full_clean:
obj.full_clean(**full_clean_options)
# If we've reached here, the key_attr should be UNSET or missing. So
# let's remove it if it is there.
data.pop(key_attr, None)
obj = _create(
info,
manager,
data | ref_instance_data,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)
existing.discard(obj)

for remaining in existing:
Expand All @@ -645,6 +723,7 @@
cast("type[Model]", manager.model),
v,
key_attr=key_attr,
full_clean=full_clean,
)
if obj and data:
data.pop(key_attr, None)
Expand All @@ -656,18 +735,28 @@
data.pop(key_attr, None)
to_add.append(obj)
elif data:
if key_attr not in data:
manager.get_or_create(**data)
else:
data.pop(key_attr)
manager.create(**data)
# If we've reached here, the key_attr should be UNSET or missing. So
# let's remove it if it is there.
bellini666 marked this conversation as resolved.
Show resolved Hide resolved
data.pop(key_attr, None)
_create(
info,
manager,
data | ref_instance_data,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)
else:
raise AssertionError

need_remove_cache = need_remove_cache or bool(value.remove)
for v in value.remove or []:
obj, data = _parse_data(
info, cast("type[Model]", manager.model), v, key_attr=key_attr
info,
cast("type[Model]", manager.model),
v,
key_attr=key_attr,
full_clean=full_clean,
)
data.pop(key_attr, None)
assert not data
Expand Down
13 changes: 13 additions & 0 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,12 @@ class MilestoneIssueInput:
name: strawberry.auto


@strawberry_django.partial(Issue)
class MilestoneIssueInputPartial:
name: strawberry.auto
tags: Optional[list[TagInputPartial]]


@strawberry_django.partial(Project)
class ProjectInputPartial(NodeInputPartial):
name: strawberry.auto
Expand All @@ -351,6 +357,8 @@ class MilestoneInput:
@strawberry_django.partial(Milestone)
class MilestoneInputPartial(NodeInputPartial):
name: strawberry.auto
issues: Optional[list[MilestoneIssueInputPartial]]
project: Optional[ProjectInputPartial]


@strawberry.type
Expand Down Expand Up @@ -519,6 +527,11 @@ class Mutation:
argument_name="input",
key_attr="name",
)
create_project_with_milestones: ProjectType = mutations.create(
ProjectInputPartial,
handle_django_errors=True,
argument_name="input",
)
update_project: ProjectType = mutations.update(
ProjectInputPartial,
handle_django_errors=True,
Expand Down
10 changes: 10 additions & 0 deletions tests/projects/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ input CreateProjectInput {

union CreateProjectPayload = ProjectType | OperationInfo

union CreateProjectWithMilestonesPayload = ProjectType | OperationInfo

input CreateQuizInput {
title: String!
fullCleanOptions: Boolean! = false
Expand Down Expand Up @@ -365,12 +367,19 @@ input MilestoneInput {
input MilestoneInputPartial {
id: GlobalID
name: String
issues: [MilestoneIssueInputPartial!]
project: ProjectInputPartial
}

input MilestoneIssueInput {
name: String!
}

input MilestoneIssueInputPartial {
name: String
tags: [TagInputPartial!]
}

input MilestoneOrder {
name: Ordering
project: ProjectOrder
Expand Down Expand Up @@ -431,6 +440,7 @@ type Mutation {
updateIssueWithKeyAttr(input: IssueInputPartialWithoutId!): UpdateIssueWithKeyAttrPayload!
deleteIssue(input: NodeInput!): DeleteIssuePayload!
deleteIssueWithKeyAttr(input: MilestoneIssueInput!): DeleteIssueWithKeyAttrPayload!
createProjectWithMilestones(input: ProjectInputPartial!): CreateProjectWithMilestonesPayload!
updateProject(input: ProjectInputPartial!): UpdateProjectPayload!
createMilestone(input: MilestoneInput!): CreateMilestonePayload!
createProject(
Expand Down
Loading
Loading