Skip to content

Proper handling of custom m2m relationships #1093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion simple_history/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,19 @@ def m2m_changed(self, instance, action, attr, pk_set, reverse, **_):
# It should be safe to ~ this since the row must exist to modify m2m on it
self.create_historical_record(instance, "~")

def _get_through_field_name(self, through_table_fields, model):
"""
Find the name of the field in the through table. This is necessary for the
custom through tables where Django conventions don't apply.
"""
foreign_keys = filter(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lambda model_field: isinstance(model_field, models.ForeignKey),
through_table_fields,
)
for field in foreign_keys:
if field.related_model == model:
return field.name

def create_historical_record_m2ms(self, history_instance, instance):
for field in history_instance._history_m2m_fields:
m2m_history_model = self.m2m_models[field]
Expand All @@ -668,7 +681,10 @@ def create_historical_record_m2ms(self, history_instance, instance):

insert_rows = []

through_field_name = type(original_instance).__name__.lower()
# find the name of the field in custom or default through table
through_field_name = self._get_through_field_name(
through_model._meta.fields, original_instance._meta.model
)

rows = through_model.objects.filter(**{through_field_name: instance})

Expand Down
22 changes: 22 additions & 0 deletions simple_history/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,28 @@ class PollWithSeveralManyToMany(models.Model):
history = HistoricalRecords(m2m_fields=[places, restaurants, books])


class Tag(models.Model):
name = models.CharField(max_length=100)


class PollTags(models.Model):
poll = models.ForeignKey(
"PollParentWithManyToManyCustomThrough", on_delete=models.CASCADE
)
tag = models.ForeignKey(Tag, on_delete=models.CASCADE)


class PollParentWithManyToManyCustomThrough(models.Model):
question = models.CharField(max_length=200)
pub_date = models.DateTimeField("date published")
tags = models.ManyToManyField(Tag, through=PollTags)

history = HistoricalRecords(
m2m_fields=[tags],
inherit=True,
)


class PollParentWithManyToMany(models.Model):
question = models.CharField(max_length=200)
pub_date = models.DateTimeField("date published")
Expand Down
7 changes: 7 additions & 0 deletions simple_history/tests/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
PollChildBookWithManyToMany,
PollChildRestaurantWithManyToMany,
PollInfo,
PollParentWithManyToManyCustomThrough,
PollWithAlternativeManager,
PollWithExcludedFieldsWithDefaults,
PollWithExcludedFKField,
Expand Down Expand Up @@ -781,6 +782,12 @@ def test_history_with_unknown_field(self):
with self.assertNumQueries(0):
new_record.diff_against(old_record, excluded_fields=["unknown_field"])

def test_history_with_custom_through_field(self):
PollParentWithManyToManyCustomThrough.objects.create(
question="what's up?", pub_date=today
)
self.assertEqual(PollParentWithManyToManyCustomThrough.objects.count(), 1)


class GetPrevRecordAndNextRecordTestCase(TestCase):
def assertRecordsMatch(self, record_a, record_b):
Expand Down