4
4
import warnings
5
5
from dataclasses import dataclass
6
6
from functools import partial
7
- from typing import Any , Dict , Iterable , List , Sequence , Union
7
+ from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Sequence , Union
8
8
9
9
import django
10
10
from django .apps import apps
31
31
from django .utils .text import format_lazy
32
32
from django .utils .translation import gettext_lazy as _
33
33
34
- from simple_history import utils
35
-
36
- from . import exceptions
34
+ from . import exceptions , utils
37
35
from .manager import (
38
36
SIMPLE_HISTORY_REVERSE_ATTR_NAME ,
39
37
HistoricalQuerySet ,
46
44
pre_create_historical_m2m_records ,
47
45
pre_create_historical_record ,
48
46
)
49
- from .utils import get_change_reason_from_object
50
47
51
48
try :
52
49
from asgiref .local import Local as LocalContext
53
50
except ImportError :
54
51
from threading import local as LocalContext
55
52
53
+ if TYPE_CHECKING :
54
+ ModelTypeHint = models .Model
55
+ else :
56
+ ModelTypeHint = object
57
+
56
58
registered_models = {}
57
59
58
60
@@ -668,7 +670,7 @@ def get_change_reason_for_object(self, instance, history_type, using):
668
670
Get change reason for object.
669
671
Customize this method to automatically fill change reason from context.
670
672
"""
671
- return get_change_reason_from_object (instance )
673
+ return utils . get_change_reason_from_object (instance )
672
674
673
675
def m2m_changed (self , instance , action , attr , pk_set , reverse , ** _ ):
674
676
if hasattr (instance , "skip_history_when_saving" ):
@@ -941,8 +943,28 @@ def __get__(self, instance, owner):
941
943
return self .model (** values )
942
944
943
945
944
- class HistoricalChanges :
945
- def diff_against (self , old_history , excluded_fields = None , included_fields = None ):
946
+ class HistoricalChanges (ModelTypeHint ):
947
+ def diff_against (
948
+ self ,
949
+ old_history : "HistoricalChanges" ,
950
+ excluded_fields : Iterable [str ] = None ,
951
+ included_fields : Iterable [str ] = None ,
952
+ * ,
953
+ foreign_keys_are_objs = False ,
954
+ ) -> "ModelDelta" :
955
+ """
956
+ :param old_history:
957
+ :param excluded_fields: The names of fields to exclude from diffing.
958
+ This takes precedence over ``included_fields``.
959
+ :param included_fields: The names of the only fields to include when diffing.
960
+ If not provided, all history-tracked fields will be included.
961
+ :param foreign_keys_are_objs: If ``False``, the returned diff will only contain
962
+ the raw PKs of any ``ForeignKey`` fields.
963
+ If ``True``, the diff will contain the actual related model objects
964
+ instead of just the PKs.
965
+ The latter case will necessarily query the database if the related
966
+ objects have not been prefetched (using e.g. ``select_related()``).
967
+ """
946
968
if not isinstance (old_history , type (self )):
947
969
raise TypeError (
948
970
"unsupported type(s) for diffing:"
@@ -965,16 +987,23 @@ def diff_against(self, old_history, excluded_fields=None, included_fields=None):
965
987
m2m_fields = set (included_m2m_fields ).difference (excluded_fields )
966
988
967
989
changes = [
968
- * self ._get_field_changes_for_diff (old_history , fields ),
969
- * self ._get_m2m_field_changes_for_diff (old_history , m2m_fields ),
990
+ * self ._get_field_changes_for_diff (
991
+ old_history , fields , foreign_keys_are_objs
992
+ ),
993
+ * self ._get_m2m_field_changes_for_diff (
994
+ old_history , m2m_fields , foreign_keys_are_objs
995
+ ),
970
996
]
997
+ # Sort by field (attribute) name, to ensure a consistent order
998
+ changes .sort (key = lambda change : change .field )
971
999
changed_fields = [change .field for change in changes ]
972
1000
return ModelDelta (changes , changed_fields , old_history , self )
973
1001
974
1002
def _get_field_changes_for_diff (
975
1003
self ,
976
1004
old_history : "HistoricalChanges" ,
977
1005
fields : Iterable [str ],
1006
+ foreign_keys_are_objs : bool ,
978
1007
) -> List ["ModelChange" ]:
979
1008
"""Helper method for ``diff_against()``."""
980
1009
changes = []
@@ -987,6 +1016,14 @@ def _get_field_changes_for_diff(
987
1016
new_value = new_values [field ]
988
1017
989
1018
if old_value != new_value :
1019
+ if foreign_keys_are_objs and isinstance (
1020
+ self ._meta .get_field (field ), ForeignKey
1021
+ ):
1022
+ # Set the fields to their related model objects instead of
1023
+ # the raw PKs from `model_to_dict()`
1024
+ old_value = getattr (old_history , field )
1025
+ new_value = getattr (self , field )
1026
+
990
1027
change = ModelChange (field , old_value , new_value )
991
1028
changes .append (change )
992
1029
@@ -996,14 +1033,18 @@ def _get_m2m_field_changes_for_diff(
996
1033
self ,
997
1034
old_history : "HistoricalChanges" ,
998
1035
m2m_fields : Iterable [str ],
1036
+ foreign_keys_are_objs : bool ,
999
1037
) -> List ["ModelChange" ]:
1000
1038
"""Helper method for ``diff_against()``."""
1001
1039
changes = []
1002
1040
1003
1041
for field in m2m_fields :
1004
- old_m2m_manager = getattr (old_history , field )
1005
- new_m2m_manager = getattr (self , field )
1006
- m2m_through_model_opts = new_m2m_manager .model ._meta
1042
+ original_field_meta = self .instance_type ._meta .get_field (field )
1043
+ reverse_field_name = utils .get_m2m_reverse_field_name (original_field_meta )
1044
+ # Sort the M2M rows by the related object, to ensure a consistent order
1045
+ old_m2m_qs = getattr (old_history , field ).order_by (reverse_field_name )
1046
+ new_m2m_qs = getattr (self , field ).order_by (reverse_field_name )
1047
+ m2m_through_model_opts = new_m2m_qs .model ._meta
1007
1048
1008
1049
# Create a list of field names to compare against.
1009
1050
# The list is generated without the PK of the intermediate (through)
@@ -1014,10 +1055,32 @@ def _get_m2m_field_changes_for_diff(
1014
1055
for f in m2m_through_model_opts .fields
1015
1056
if f .editable and f .name not in ["id" , "m2m_history_id" , "history" ]
1016
1057
]
1017
- old_rows = list (old_m2m_manager .values (* through_model_fields ))
1018
- new_rows = list (new_m2m_manager .values (* through_model_fields ))
1058
+ old_rows = list (old_m2m_qs .values (* through_model_fields ))
1059
+ new_rows = list (new_m2m_qs .values (* through_model_fields ))
1019
1060
1020
1061
if old_rows != new_rows :
1062
+ if foreign_keys_are_objs :
1063
+ fk_fields = [
1064
+ f
1065
+ for f in through_model_fields
1066
+ if isinstance (m2m_through_model_opts .get_field (f ), ForeignKey )
1067
+ ]
1068
+
1069
+ # Set the through fields to their related model objects instead of
1070
+ # the raw PKs from `values()`
1071
+ def rows_with_foreign_key_objs (m2m_qs ):
1072
+ # Replicate the format of the return value of QuerySet.values()
1073
+ return [
1074
+ {
1075
+ through_field : getattr (through_obj , through_field )
1076
+ for through_field in through_model_fields
1077
+ }
1078
+ for through_obj in m2m_qs .select_related (* fk_fields )
1079
+ ]
1080
+
1081
+ old_rows = rows_with_foreign_key_objs (old_m2m_qs )
1082
+ new_rows = rows_with_foreign_key_objs (new_m2m_qs )
1083
+
1021
1084
change = ModelChange (field , old_rows , new_rows )
1022
1085
changes .append (change )
1023
1086
0 commit comments