
Commit abd04e4

added unit tests for implicit passing of num_logits_to_keep
Signed-off-by: eplatero <quic_eplatero@quicinc.com>
1 parent d483356 commit abd04e4
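
Alongside the new tests, the diff below renames the user-facing kwarg from num_logits_to_keep to num_speculative_tokens; num_logits_to_keep (num_speculative_tokens + 1) is now derived implicitly at export time rather than passed by callers. A hedged usage sketch, assuming the QEFFAutoModelForCausalLM entry points behave as shown in the hunks below; the model card and the value 3 are illustrative only:

from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

# Illustrative model card; any supported causal LM would do.
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    num_speculative_tokens=3,  # stored on the instance; no NUM_LOGITS_TO_KEEP constant involved
)

# transform() applies SpDTransform when num_speculative_tokens is set (and rejects
# combining it with is_dlm); export() later sets model.num_logits_to_keep to
# num_speculative_tokens + 1 without the caller ever passing it explicitly.
qeff_model.transform(num_speculative_tokens=3)
onnx_path = qeff_model.export()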

10 files changed: +80 -114 lines changed


QEfficient/compile/compile_helper.py

Lines changed: 3 additions & 4 deletions
@@ -12,7 +12,6 @@
 from typing import List, Optional, Tuple
 
 from QEfficient.utils.logging_utils import logger
-from QEfficient.utils.constants import NUM_LOGITS_TO_KEEP
 
 
 def create_and_dump_specializations(
@@ -22,10 +21,10 @@ def create_and_dump_specializations(
     path: str,
     is_dlm: bool,
     full_batch_size: Optional[int] = None,
-    num_logits_to_keep: Optional[int] = NUM_LOGITS_TO_KEEP,
+    num_speculative_tokens: Optional[int] = None,
 ):
     # Create specialization cfgs
-    decode_seq_len = 1 if num_logits_to_keep is None else num_logits_to_keep+1
+    decode_seq_len = 1 if num_speculative_tokens is None else num_speculative_tokens+1
     specialization_cfgs = [
         dict(batch_size=str(batch_size), seq_len=str(prompt_len), ctx_len=str(ctx_len)), # prefill
         dict(batch_size=str(batch_size), seq_len=str(decode_seq_len), ctx_len=str(ctx_len)) # decode
@@ -171,7 +170,7 @@ def compile(
         path=specialization_json_path,
         full_batch_size=full_batch_size,
         is_dlm=kwargs.get("is_dlm", False),
-        num_logits_to_keep=kwargs.get("num_logits_to_keep", NUM_LOGITS_TO_KEEP),
+        num_speculative_tokens=kwargs.get("num_speculative_tokens", None),
     )
 
     # Select the customIO config based on the mx flag.
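
For context, a minimal sketch of the specialization shapes this change produces, assuming a TLM exported with num_speculative_tokens=3 (the decode specialization must then admit num_speculative_tokens + 1 = 4 input tokens per step); the helper name and the concrete batch/prompt/ctx values below are illustrative, not part of the commit:

from typing import Optional

def build_specializations(batch_size: int, prompt_len: int, ctx_len: int,
                          num_speculative_tokens: Optional[int] = None):
    # Mirrors create_and_dump_specializations: decode seq_len is 1 for a plain
    # causal LM and num_speculative_tokens + 1 for a TLM SpD model.
    decode_seq_len = 1 if num_speculative_tokens is None else num_speculative_tokens + 1
    return [
        dict(batch_size=str(batch_size), seq_len=str(prompt_len), ctx_len=str(ctx_len)),     # prefill
        dict(batch_size=str(batch_size), seq_len=str(decode_seq_len), ctx_len=str(ctx_len)), # decode
    ]

# Illustrative values only.
print(build_specializations(batch_size=1, prompt_len=32, ctx_len=128, num_speculative_tokens=3))
# [{'batch_size': '1', 'seq_len': '32', 'ctx_len': '128'},
#  {'batch_size': '1', 'seq_len': '4', 'ctx_len': '128'}]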

QEfficient/exporter/export_hf_to_cloud_ai_100.py

Lines changed: 16 additions & 14 deletions
@@ -21,7 +21,7 @@
 from QEfficient.transformers.modeling_utils import get_lists_of_cb_qeff_models
 from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
 from QEfficient.utils import load_hf_tokenizer
-from QEfficient.utils.constants import QEFF_MODELS_DIR, NUM_LOGITS_TO_KEEP, Constants
+from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants
 from QEfficient.utils.generate_inputs import InputHandler
 from QEfficient.utils.logging_utils import logger
 
@@ -196,7 +196,7 @@ def export_kvstyle_transformed_model_to_onnx(
     onnx_dir_path: str,
     seq_len: int,
     full_batch_size: Optional[int] = None,
-    num_logits_to_keep: Optional[int] = NUM_LOGITS_TO_KEEP,
+    num_speculative_tokens: Optional[int] = None,
 ) -> str:
     # Disabling requires_grad on all parameters
     for _, p in enumerate(transformed_model.parameters()):
@@ -205,13 +205,15 @@ def export_kvstyle_transformed_model_to_onnx(
     if seq_len <= 0:
         raise ValueError(f"Need seq_len to be greater than zero, got seq_len={seq_len}")
 
-    # Implicitly pass "num_logits_to_keep" if defined and \
-    # assert prompt_len >= num_logits_to_keep
+    # Implicitly pass "num_speculative_tokens" if defined and \
+    # assert prompt_len >= num_speculative_tokens
     prompt_len = Constants.PROMPT_LEN
-    if num_logits_to_keep is not None:
-        setattr(transformed_model, "num_logits_to_keep", num_logits_to_keep+1)
-        if prompt_len < num_logits_to_keep+1:
-            prompt_len *= math.ceil((num_logits_to_keep+1) / prompt_len)
+    num_logits_to_keep = None
+    if num_speculative_tokens is not None:
+        num_logits_to_keep = num_speculative_tokens+1
+        setattr(transformed_model, "num_logits_to_keep", num_logits_to_keep)
+        if prompt_len < num_logits_to_keep:
+            prompt_len *= math.ceil((num_logits_to_keep) / prompt_len)
 
     # Preprocess inputs
     # Build inputs for prefill
@@ -331,7 +333,7 @@ def export_for_cloud(
     onnx_dir_path: str,
     seq_length: int = Constants.SEQ_LEN,
     full_batch_size: Optional[int] = None,
-    num_logits_to_keep: Optional[int] = NUM_LOGITS_TO_KEEP,
+    num_speculative_tokens: Optional[int] = None,
 ) -> str:
     # Check if model architecture is supported for continuous batching.
     if full_batch_size and qeff_model.model.config.architectures[0] not in get_lists_of_cb_qeff_models.architectures:
@@ -348,7 +350,7 @@ def export_for_cloud(
            onnx_dir_path=onnx_dir_path,
            seq_length=seq_length,
            full_batch_size=full_batch_size,
-           num_logits_to_keep=num_logits_to_keep
+           num_speculative_tokens=num_speculative_tokens
        )
    else:
        raise NotImplementedError(
@@ -363,7 +365,7 @@ def export_lm_model_for_cloud(
     onnx_dir_path: str,
     seq_length: int,
     full_batch_size: Optional[int] = None,
-    num_logits_to_keep: Optional[int] = NUM_LOGITS_TO_KEEP,
+    num_speculative_tokens: Optional[int] = None,
 ) -> str:
     if os.path.exists(onnx_dir_path):
         logger.warning(f"Overriding {onnx_dir_path}")
@@ -377,7 +379,7 @@ def export_lm_model_for_cloud(
            onnx_dir_path=onnx_dir_path,
           seq_len=seq_length,
           full_batch_size=full_batch_size,
-          num_logits_to_keep=num_logits_to_keep,
+          num_speculative_tokens=num_speculative_tokens,
       ) # type: ignore
 
   else:
@@ -403,7 +405,7 @@ def qualcomm_efficient_converter(
     kv: bool = True,
     form_factor: str = "cloud",
     full_batch_size: Optional[int] = None,
-    num_logits_to_keep: Optional[int] = NUM_LOGITS_TO_KEEP,
+    num_speculative_tokens: Optional[int] = None,
 ) -> Tuple[str, str]:
     """
     This method is an alias for ``QEfficient.export``.
@@ -484,7 +486,7 @@ def qualcomm_efficient_converter(
            onnx_dir_path=onnx_dir_path,
            seq_length=seq_length,
            full_batch_size=full_batch_size,
-           num_logits_to_keep=num_logits_to_keep,
+           num_speculative_tokens=num_speculative_tokens,
        )
        return onnx_dir_path, generated_onnx_model_path
    else:
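
The export hunk above only stretches prompt_len when the number of logits to keep (num_speculative_tokens + 1) does not fit into the default prefill length; a small worked sketch of that rounding, where the value 32 standing in for Constants.PROMPT_LEN is an assumption for illustration:

import math

def adjust_prompt_len(prompt_len, num_speculative_tokens):
    # Mirrors export_kvstyle_transformed_model_to_onnx: keep
    # num_speculative_tokens + 1 logits and grow prompt_len in whole
    # multiples of itself until those logits fit.
    if num_speculative_tokens is None:
        return prompt_len
    num_logits_to_keep = num_speculative_tokens + 1
    if prompt_len < num_logits_to_keep:
        prompt_len *= math.ceil(num_logits_to_keep / prompt_len)
    return prompt_len

print(adjust_prompt_len(32, num_speculative_tokens=3))   # 32 (5 logits already fit)
print(adjust_prompt_len(32, num_speculative_tokens=63))  # 64 (needs 64 logits -> 2 * 32)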

QEfficient/exporter/export_utils.py

Lines changed: 0 additions & 3 deletions
@@ -281,9 +281,6 @@ def generate_input_files(
     # inputFiles
     os.makedirs(input_files_path, exist_ok=True)
     filenames = []
-    if "num_logits_to_keep" in input_names:
-        idx = input_names.index("num_logits_to_keep")
-        del input_names[idx]
 
     for name in input_names:
         # We can't directly iterate with inputs.items() because

QEfficient/transformers/modeling_spd_utils.py

Lines changed: 0 additions & 24 deletions
@@ -33,30 +33,6 @@ def filter_hidden_states(
     if num_logits_to_keep is None:
         # return the last logit
         return hidden_states[batch_indices.view(-1, 1), logit_index]
-    # last valid `num_logits_to_keep` need to be computed
-
-    #upper_idx = torch.max(logit_index[0]+1, torch.tensor([num_logits], dtype=torch.int32))
-    #upper_idx = logit_index[0]+1
-    #lower_idx = upper_idx - num_logits
-    #return hidden_states[:, lower_idx:upper_idx] # fails
-    #return hidden_states[:, lower_idx.item():upper_idx.item()] # works
-    #return hidden_states[:, lower_idx:upper_idx]
-    #return hidden_states[batch_indices.view(-1,1), lower_idx:upper_idx]
-    #return hidden_states[batch_indices.view(-1), lower_idx:upper_idx] # fails: Slice
-    #return hidden_states[:, lower_idx.unsqueeze(0):upper_idx.unsqueeze(0)] # fails
-
-    # range operator approach (onnx pass, compile fail)
-    #indices = torch.arange(lower_idx[0], upper_idx[0])
-    #return hidden_states[batch_indices.view(-1,1), indices] # onnx pass, compile fail with: [Operator-'/Range_1'] : Range: Non-constant start tensor not supported.
-
-    # range operators approach v2 (onnx pass, compile fails)
-    #indices = torch.arange(lower_idx[0], upper_idx[0]).repeat(batch_size,1)
-    #return hidden_states[batch_indices.view(-1,1), indices] # onnx pass, compile fail with: Error message: [Operator-'/Range_1'] : Range: Non-constant start tensor not supported.
-
-    # what if we repeat batch_indices to have 1-1 dimensions? (onnx pass, compile fail)
-    #indices = torch.arange(lower_idx[0], upper_idx[0]).repeat(batch_size,1)
-    #return hidden_states[batch_indices.view(-1,1).repeat(1,num_logits), indices] # onnx pass, compile fail with: [Operator-'/Range_1'] : Range: Non-constant start tensor not supported
-
     # topk approach
     topk_indices = torch.topk(position_ids, k=num_logits_to_keep, dim=1).indices.to(torch.int32)
     topk_indices = torch.flip(topk_indices, dims=[1]) # "left" padded input in case num_non_padded_tokens < num_logits_to_keep
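
This hunk drops the dead experiments (the slice- and Range-based variants that failed AIC compilation) and keeps only the topk approach: take the k largest position_ids, which are the last valid tokens regardless of padding, and flip them back into ascending order. A minimal sketch on dummy tensors; the final indexing line is an assumption about how filter_hidden_states continues, since its tail is outside this hunk:

import torch

def gather_last_k_hidden_states(hidden_states, position_ids, num_logits_to_keep):
    # hidden_states: (batch, seq_len, hidden), position_ids: (batch, seq_len)
    batch_indices = torch.arange(position_ids.shape[0])
    # Indices of the k largest position ids == the last valid (non-padded) tokens.
    topk_indices = torch.topk(position_ids, k=num_logits_to_keep, dim=1).indices
    # topk returns descending order; flip so the kept tokens read left to right.
    topk_indices = torch.flip(topk_indices, dims=[1])
    # Assumed continuation: per-batch advanced indexing over the kept positions.
    return hidden_states[batch_indices.view(-1, 1), topk_indices]

hs = torch.randn(2, 8, 16)
pos = torch.arange(8).repeat(2, 1)  # positions 0..7, no padding
print(gather_last_k_hidden_states(hs, pos, num_logits_to_keep=3).shape)  # torch.Size([2, 3, 16])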

QEfficient/transformers/models/llama/modeling_llama.py

Lines changed: 2 additions & 3 deletions
@@ -31,7 +31,6 @@
 )
 
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
-from QEfficient.transformers.modeling_spd_utils import filter_hidden_states
 
 
 class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
@@ -241,7 +240,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: Optional[int] = 0,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -290,7 +288,8 @@ def forward(
         )
 
         # Cast to INT32 to avoid issue while running in ONNXRT
-        hidden_states = filter_hidden_states(outputs[0], position_ids, num_logits_to_keep)
+        logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
+        hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
         if self.config.pretraining_tp > 1:
             lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
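
With the SpD-specific filter_hidden_states call removed, the base Llama forward reverts to gathering only the hidden state at each sequence's last valid position (the argmax of position_ids); the TLM path keeps the multi-logit filtering via SpDTransform instead. A tiny sketch of that gather on dummy tensors:

import torch

hidden = torch.randn(2, 8, 16)                        # (batch, seq_len, hidden)
position_ids = torch.tensor([[0, 1, 2, 3, 4, 0, 0, 0],
                             [0, 1, 2, 3, 4, 5, 6, 7]])

# Index of the last valid (largest position id) token per sequence.
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)  # (batch, 1)
last_hidden = hidden[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
print(last_hidden.shape)  # torch.Size([2, 1, 16])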

QEfficient/transformers/models/modeling_auto.py

Lines changed: 15 additions & 15 deletions
@@ -20,7 +20,7 @@
 from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig
 from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig
 from QEfficient.utils import get_qpc_dir_path, load_hf_tokenizer
-from QEfficient.utils.constants import QEFF_MODELS_DIR, NUM_LOGITS_TO_KEEP
+from QEfficient.utils.constants import QEFF_MODELS_DIR
 from QEfficient.utils.logging_utils import logger
 
 # Dictionary that defines the interface from transformers to be used underneath the QEFF interface
@@ -58,7 +58,7 @@ def __init__(self, model: nn.Module, pretrained_model_name_or_path: str, **kwarg
         self.model_card_name = self.pretrained_model_name_or_path
 
         self.full_batch_size = kwargs.get("full_batch_size", None)
-        self.num_logits_to_keep = kwargs.get("num_logits_to_keep", NUM_LOGITS_TO_KEEP)
+        self.num_speculative_tokens = kwargs.get("num_speculative_tokens", None)
         self.is_dlm = kwargs.get("is_dlm", False)
         self.kwargs = kwargs
         self._tokenizer = None
@@ -105,7 +105,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
 
         full_batch_size = kwargs.pop("full_batch_size", None)
 
-        num_logits_to_keep = kwargs.pop("num_logits_to_keep", NUM_LOGITS_TO_KEEP)
+        num_speculative_tokens = kwargs.pop("num_speculative_tokens", None)
         is_dlm = kwargs.pop("is_dlm", False)
 
         attn_implementation = kwargs.get("attn_implementation", None)
@@ -125,7 +125,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             model_card_name=model_card_name,
             full_batch_size=full_batch_size,
-            num_logits_to_keep=num_logits_to_keep,
+            num_speculative_tokens=num_speculative_tokens,
             is_dlm=is_dlm,
             **kwargs,
         )
@@ -167,15 +167,15 @@ class QEFFAutoModelForCausalLM(QEFFTransformersBase):
 
     def transform(
         self,
-        num_logits_to_keep: Optional[int] = NUM_LOGITS_TO_KEEP,
+        num_speculative_tokens: Optional[int] = None,
         is_dlm: bool = False,
         **kwargs):
        """
        This method applies all relevant optimization transforms on the model and toggles the ``self.is_transformed`` attribute to True. If the model is already transformed, the method will simply return.
        Please note that this method does not require any input arguments."

        ``Optional`` Args:
-            :num_logits_to_keep (int, optional): Number of speculative tokens, specified only for TLM SpD model.
+            :num_speculative_tokens (int, optional): Number of speculative tokens, specified only for TLM SpD model.
            :is_dlm (bool): True if this is a DLM SpD model.

        Returns:
@@ -202,11 +202,11 @@ def transform(
        if isinstance(self.model.config.quantization_config, QEffGPTQConfig):
            self._pytorch_transforms.insert(0, GPTQToMatmulNbitsTransform)

-        if num_logits_to_keep is not None:
-            if not isinstance(num_logits_to_keep, int) or num_logits_to_keep<2:
-                ValueError("`num_logits_to_keep` arg should be an integer greater than 1.")
+        if num_speculative_tokens is not None:
+            if not isinstance(num_speculative_tokens, int) or num_speculative_tokens<2:
+                ValueError("`num_speculative_tokens` arg should be an integer greater than 1.")
            if is_dlm:
-                raise ValueError("`num_logits_to_keep` arg and `is_dlm` flag are mutually exclusive.")
+                raise ValueError("`num_speculative_tokens` arg and `is_dlm` flag are mutually exclusive.")
            self._pytorch_transforms.append(SpDTransform)

        for transform in self._pytorch_transforms:
@@ -239,7 +239,7 @@ def export(self) -> str:
            model_kv=self,
            tokenizer=self.tokenizer,
            full_batch_size=self.full_batch_size,
-            num_logits_to_keep=self.num_logits_to_keep,
+            num_speculative_tokens=self.num_speculative_tokens,
        )
        self.onnx_path = onnx_model_path

@@ -311,8 +311,8 @@ def compile(
            mxfp6=mxfp6,
            mxint8=mxint8,
            full_batch_size=self.full_batch_size,
-            num_logits_to_keep=self.num_logits_to_keep,
-            is_dlm=getattr(self.model, "is_dlm", False),
+            num_speculative_tokens=self.num_speculative_tokens,
+            is_dlm=self.is_dlm,
        )
        self.qpc_path = qpc_dir_path
        return self.qpc_path
@@ -375,8 +375,8 @@ def export_and_compile(
            mxfp6=mxfp6,
            mxint8=mxint8,
            full_batch_size=full_batch_size,
-            num_logits_to_keep=self.num_logits_to_keep,
-            is_dlm=getattr(self.model, "is_dlm", False),
+            num_speculative_tokens=self.num_speculative_tokens,
+            is_dlm=self.is_dlm,
        )
        return self.qpc_path
QEfficient/transformers/models/spd/modeling_tlm.py

Lines changed: 1 addition & 2 deletions
@@ -12,7 +12,6 @@
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-from QEfficient.utils.constants import NUM_LOGITS_TO_KEEP
 from QEfficient.transformers.modeling_spd_utils import filter_hidden_states
 
 def tlm_forward(
@@ -29,7 +28,7 @@ def tlm_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    #num_logits_to_keep: Optional[torch.LongTensor] = None,
+    #num_logits_to_keep: Optional[torch.LongTensor] = None, # explicit passing is not currently supported
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:

QEfficient/utils/constants.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 
 import os
 
-NUM_LOGITS_TO_KEEP = None
 UTILS_DIR = os.path.dirname(os.path.abspath(__file__))
 QEFF_DIR = os.path.dirname(UTILS_DIR)
 ROOT_DIR = os.path.dirname(QEFF_DIR)
