From 9896652198cb587210a4e8075a95fffeff7855e1 Mon Sep 17 00:00:00 2001 From: Marco Sirabella Date: Fri, 9 Jun 2023 00:38:10 -0700 Subject: [PATCH] Make historical records' m2m fields type-compatible with non-historical Fixes #1186 --- simple_history/manager.py | 29 +++++++++++++++++++++++ simple_history/models.py | 8 +++++-- simple_history/tests/tests/test_models.py | 26 ++++++++++---------- 3 files changed, 48 insertions(+), 15 deletions(-) diff --git a/simple_history/manager.py b/simple_history/manager.py index e91b2491b..48549561d 100644 --- a/simple_history/manager.py +++ b/simple_history/manager.py @@ -127,6 +127,35 @@ def __get__(self, instance, owner): return HistoryManager.from_queryset(HistoricalQuerySet)(self.model, instance) +class HistoryManyToManyDescriptor: + def __init__(self, model, rel): + self.rel = rel + self.model = model + + def __get__(self, instance, owner): + return HistoryManyRelatedManager.from_queryset(QuerySet)( + self.model, self.rel, instance + ) + + +class HistoryManyRelatedManager(models.Manager): + def __init__(self, through, rel, instance=None): + super().__init__() + self.model = rel.model + self.through = through + self.instance = instance + self._m2m_through_field_name = rel.field.m2m_reverse_field_name() + + def get_queryset(self): + qs = super().get_queryset() + through_qs = HistoryManager.from_queryset(HistoricalQuerySet)( + self.through, self.instance + ) + return qs.filter( + pk__in=through_qs.all().values_list(self._m2m_through_field_name, flat=True) + ) + + class HistoryManager(models.Manager): def __init__(self, model, instance=None): super().__init__() diff --git a/simple_history/models.py b/simple_history/models.py index db19c66fa..6900e3c1c 100644 --- a/simple_history/models.py +++ b/simple_history/models.py @@ -31,7 +31,11 @@ from simple_history import utils from . import exceptions -from .manager import SIMPLE_HISTORY_REVERSE_ATTR_NAME, HistoryDescriptor +from .manager import ( + SIMPLE_HISTORY_REVERSE_ATTR_NAME, + HistoryDescriptor, + HistoryManyToManyDescriptor, +) from .signals import ( post_create_historical_m2m_records, post_create_historical_record, @@ -227,7 +231,7 @@ def finalize(self, sender, **kwargs): setattr(module, m2m_model.__name__, m2m_model) - m2m_descriptor = HistoryDescriptor(m2m_model) + m2m_descriptor = HistoryManyToManyDescriptor(m2m_model, field.remote_field) setattr(history_model, field.name, m2m_descriptor) def get_history_model_name(self, model): diff --git a/simple_history/tests/tests/test_models.py b/simple_history/tests/tests/test_models.py index 5dd87aeb8..7290818ab 100644 --- a/simple_history/tests/tests/test_models.py +++ b/simple_history/tests/tests/test_models.py @@ -1788,12 +1788,12 @@ def test_separation(self): self.assertEqual(book.restaurants.all().count(), 0) self.assertEqual(book.books.all().count(), 1) self.assertEqual(book.places.all().count(), 1) - self.assertEqual(book.books.first().book, self.book) + self.assertEqual(book.books.first(), self.book) self.assertEqual(place.restaurants.all().count(), 0) self.assertEqual(place.books.all().count(), 0) self.assertEqual(place.places.all().count(), 1) - self.assertEqual(place.places.first().place, self.place) + self.assertEqual(place.places.first(), self.place) self.assertEqual(add.restaurants.all().count(), 0) self.assertEqual(add.books.all().count(), 0) @@ -1829,11 +1829,11 @@ def test_separation(self): self.assertEqual(book.books.all().count(), 1) self.assertEqual(book.places.all().count(), 1) - self.assertEqual(book.books.first().book, self.book) + self.assertEqual(book.books.first(), self.book) self.assertEqual(place.books.all().count(), 0) self.assertEqual(place.places.all().count(), 1) - self.assertEqual(place.places.first().place, self.place) + self.assertEqual(place.places.first(), self.place) self.assertEqual(add.books.all().count(), 0) self.assertEqual(add.places.all().count(), 0) @@ -1842,11 +1842,11 @@ def test_separation(self): self.assertEqual(restaurant.restaurants.all().count(), 1) self.assertEqual(restaurant.places.all().count(), 1) - self.assertEqual(restaurant.restaurants.first().restaurant, self.restaurant) + self.assertEqual(restaurant.restaurants.first(), self.restaurant) self.assertEqual(place.restaurants.all().count(), 0) self.assertEqual(place.places.all().count(), 1) - self.assertEqual(place.places.first().place, self.place) + self.assertEqual(place.places.first(), self.place) self.assertEqual(add.restaurants.all().count(), 0) self.assertEqual(add.places.all().count(), 0) @@ -1964,7 +1964,7 @@ def test_create(self): # And the historical place is the correct one historical_place = m2m_record.places.first() - self.assertEqual(historical_place.place, self.place) + self.assertEqual(historical_place, self.place) def test_remove(self): # Add and remove a many-to-many child @@ -1984,7 +1984,7 @@ def test_remove(self): # And the previous row still has the correct one historical_place = previous_m2m_record.places.first() - self.assertEqual(historical_place.place, self.place) + self.assertEqual(historical_place, self.place) def test_clear(self): # Add some places @@ -2036,7 +2036,7 @@ def test_delete_child(self): # Place instance cannot be created... historical_place = m2m_record.places.first() with self.assertRaises(ObjectDoesNotExist): - historical_place.place.id + historical_place.id # But the values persist historical_place_values = m2m_record.places.all().values()[0] @@ -2066,7 +2066,7 @@ def test_delete_parent(self): # And it is the correct one historical_place = prev_record.places.first() - self.assertEqual(historical_place.place, self.place) + self.assertEqual(historical_place, self.place) def test_update_child(self): self.poll.places.add(self.place) @@ -2084,7 +2084,7 @@ def test_update_child(self): m2m_record = self.poll.history.all()[0] self.assertEqual(m2m_record.places.count(), 1) historical_place = m2m_record.places.first() - self.assertEqual(historical_place.place.name, "Updated") + self.assertEqual(historical_place.name, "Updated") def test_update_parent(self): self.poll.places.add(self.place) @@ -2102,7 +2102,7 @@ def test_update_parent(self): m2m_record = self.poll.history.all()[0] self.assertEqual(m2m_record.places.count(), 1) historical_place = m2m_record.places.first() - self.assertEqual(historical_place.place, self.place) + self.assertEqual(historical_place, self.place) def test_bulk_add_remove(self): # Add some places @@ -2134,7 +2134,7 @@ def test_bulk_add_remove(self): self.assertEqual(m2m_record.places.all().count(), 1) historical_place = m2m_record.places.first() - self.assertEqual(historical_place.place, self.place) + self.assertEqual(historical_place, self.place) def test_m2m_relation(self): # Ensure only the correct M2Ms are saved and returned for history objects