     _determine_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
-    _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
 )
@@ -94,6 +93,31 @@ def __exit__(self, *args, **kwargs):
 
 _REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
 
+
+def _get_load_device_from_device_map(device_map):
+    """
+    Determine the device to load weights directly to, if possible.
+
+    For simple device maps where all components go to the same device,
+    we can load directly to that device to avoid CPU memory usage.
+    """
+    if device_map is None:
+        return "cpu"
+
+    if isinstance(device_map, dict):
+        # Simple case: everything goes to one device
+        if "" in device_map:
+            return device_map[""]
+
+        # Check if all values map to the same device
+        unique_devices = set(device_map.values())
+        if len(unique_devices) == 1:
+            return next(iter(unique_devices))
+
+    # For complex device maps or string strategies, load to CPU first
+    return "cpu"
+
+
 TORCH_INIT_FUNCTIONS = {
     "uniform_": nn.init.uniform_,
     "normal_": nn.init.normal_,
@@ -873,9 +897,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
-        device_map = kwargs.pop("device_map", "auto")
-        if device_map == "auto":
-            logger.info("Using automatic device mapping (device_map='auto') for memory-efficient loading")
+        device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
         offload_folder = kwargs.pop("offload_folder", None)
         offload_state_dict = kwargs.pop("offload_state_dict", None)
@@ -902,7 +924,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 "Memory-efficient loading requires `accelerate`. Please install accelerate with: \n```\npip"
                 " install accelerate\n```\n."
             )
-
+
         if not is_torch_version(">=", "1.9.0"):
             raise NotImplementedError(
                 "Memory-efficient loading requires PyTorch >= 1.9.0. Please update your PyTorch version."
@@ -1133,7 +1155,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         state_dict = None
         if not is_sharded:
             # Time to load the checkpoint
-            state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
+            # Determine the device to load weights to based on device_map
+            load_device = _get_load_device_from_device_map(device_map)
+            state_dict = load_state_dict(
+                resolved_model_file[0],
+                disable_mmap=disable_mmap,
+                dduf_entries=dduf_entries,
+                map_location=load_device
+            )
             # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
             model._fix_state_dict_keys_on_load(state_dict)
 
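
This assumes `load_state_dict` forwards `map_location` to the underlying checkpoint readers. A rough sketch of what that forwarding usually looks like (an assumption about the helper's internals, not the actual diffusers code):

```python
import torch
from safetensors.torch import load_file

def _read_checkpoint(checkpoint_file: str, map_location="cpu"):
    # Hypothetical reader: safetensors can materialize tensors directly on the
    # target device via `device`, while torch.load uses `map_location`.
    if checkpoint_file.endswith(".safetensors"):
        return load_file(checkpoint_file, device=str(map_location))
    return torch.load(checkpoint_file, map_location=map_location)
```
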
@@ -1191,7 +1220,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 "offload_index": offload_index,
             }
             dispatch_model(model, **device_map_kwargs)
-            logger.info(f"Model loaded with device_map: {device_map}")
+            # Format device map for concise logging
+            if isinstance(device_map, dict):
+                device_summary = ", ".join([f"{k or 'model'}: {v}" for k, v in device_map.items()])
+                logger.info(f"Model loaded with device_map: {{{device_summary}}}")
 
         if hf_quantizer is not None:
             hf_quantizer.postprocess_model(model)
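
With a single-device map the new summary renders like this (the device string is just an example):

```python
device_map = {"": "cuda:0"}
device_summary = ", ".join([f"{k or 'model'}: {v}" for k, v in device_map.items()])
print(f"Model loaded with device_map: {{{device_summary}}}")
# Model loaded with device_map: {model: cuda:0}
```
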
@@ -1352,7 +1384,6 @@ def _load_pretrained_model(
 
         mismatched_keys = []
 
-        assign_to_params_buffers = None
         error_msgs = []
 
         # Deal with offload
@@ -1385,7 +1416,9 @@ def _load_pretrained_model(
             resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
 
         for shard_file in resolved_model_file:
-            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+            # Determine the device to load weights to based on device_map
+            load_device = _get_load_device_from_device_map(device_map)
+            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=load_device)
 
             def _find_mismatched_keys(
                 state_dict,
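
Taken together, a caller who passes an explicit single-device map should now see shard tensors materialize directly on that device instead of passing through CPU first. A hypothetical call (the checkpoint id and model class are placeholders):

```python
import torch
from diffusers import UNet2DConditionModel

# Placeholder repo id; any diffusers model checkpoint works the same way.
model = UNet2DConditionModel.from_pretrained(
    "some-org/some-checkpoint",
    subfolder="unet",
    device_map={"": "cuda:0"},  # single-device map -> weights load straight to cuda:0
    torch_dtype=torch.float16,
)
```
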