Commit fa058b7

added back mxint8 export and compilation
Signed-off-by: eplatero <quic_eplatero@quicinc.com>
1 parent abd04e4 commit fa058b7

2 files changed: 20 additions & 18 deletions

QEfficient/utils/generate_inputs.py

Lines changed: 4 additions & 2 deletions
@@ -116,12 +116,14 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
         if self.full_batch_size:
             # Create CB inputs (make 1 batch index have proper inputs for decode pass)
             batch_index = torch.arange(1).view(-1, 1)
-            batch_idx_input_ids = pt_outputs.logits.detach().argmax(2)
+            batch_idx_input_ids = pt_outputs.logits.detach().argmax(2)  # shape: [batch_size, num_logits_to_keep]
             input_ids = torch.full((self.full_batch_size, decode_len), self.tokenizer.pad_token_id)
             input_ids[batch_index.view(-1)] = batch_idx_input_ids
+
             position_ids = torch.full((self.full_batch_size, decode_len), 0)
             batch_idx_position_ids = torch.arange(decode_len).view(1,-1) + (inputs["position_ids"].max(1, keepdim=True).values + 1)
             position_ids[batch_index.view(-1)] = batch_idx_position_ids
+
             updated_inputs["input_ids"] = input_ids
             updated_inputs["position_ids"] = position_ids
             updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)
@@ -132,7 +134,7 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
             batch_size = input_ids.size(0)
             position_ids = torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1)
         else:
-            input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
+            input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)  # shape: [batch_size, 1]
             position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1
         updated_inputs["input_ids"] = input_ids
         updated_inputs["position_ids"] = position_ids
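For orientation, here is a minimal standalone sketch of the shape bookkeeping the new comments describe in the continuous-batching (CB) branch. The sizes (full_batch_size=8, decode_len=2, a toy vocab, pad_token_id=0) are illustrative assumptions, not values taken from this commit:

```python
import torch

# Illustrative values only; the real ones come from the model config and tokenizer.
full_batch_size = 8   # number of CB slots
decode_len = 2        # stands in for num_logits_to_keep
vocab_size = 32000
pad_token_id = 0

# Stand-in for pt_outputs.logits from prefill: [batch_size, num_logits_to_keep, vocab_size]
logits = torch.randn(1, decode_len, vocab_size)
prefill_position_ids = torch.arange(32).view(1, -1)  # stand-in for inputs["position_ids"]

batch_index = torch.arange(1).view(-1, 1)
batch_idx_input_ids = logits.detach().argmax(2)  # shape: [batch_size, num_logits_to_keep]

# Pad every CB slot, then fill the one active slot with the sampled tokens.
input_ids = torch.full((full_batch_size, decode_len), pad_token_id)
input_ids[batch_index.view(-1)] = batch_idx_input_ids

position_ids = torch.full((full_batch_size, decode_len), 0)
batch_idx_position_ids = torch.arange(decode_len).view(1, -1) + (
    prefill_position_ids.max(1, keepdim=True).values + 1
)
position_ids[batch_index.view(-1)] = batch_idx_position_ids

print(input_ids.shape, position_ids.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```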

tests/spd/test_tlm_dlm_export_and_compile.py

Lines changed: 16 additions & 16 deletions
@@ -16,23 +16,23 @@

 configs = [
     pytest.param(
-        [0], # device_group
-        2, # num_speculative_tokens
-        32, # prompt_len
-        128, # ctx_len
-        1, # prefill_bsz
-        8, # full_batch_size
-        "JackFram/llama-68m", # model_name
+        [0],  # device_group
+        2,  # num_speculative_tokens
+        32,  # prompt_len
+        128,  # ctx_len
+        1,  # prefill_bsz
+        8,  # full_batch_size
+        "JackFram/llama-68m",  # model_name
         id="CB llama",
     ),
     pytest.param(
-        [0], # device_group
-        2, # num_speculative_tokens
-        32, # prompt_len
-        128, # ctx_len
-        1, # prefill_bsz
-        None, # full_batch_size
-        "JackFram/llama-68m", # model_name
+        [0],  # device_group
+        2,  # num_speculative_tokens
+        32,  # prompt_len
+        128,  # ctx_len
+        1,  # prefill_bsz
+        None,  # full_batch_size
+        "JackFram/llama-68m",  # model_name
         id="non-CB llama",
     ),
 ]
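As context, a hedged sketch of how a parametrized test could consume tuples like the ones above; the field order follows the inline comments, while the decorator and test signature actually used in test_tlm_dlm_export_and_compile.py are not shown in this diff and may differ:

```python
import pytest

# Mirrors the configs list above (CB case with full_batch_size=8, non-CB case with None).
configs = [
    pytest.param([0], 2, 32, 128, 1, 8, "JackFram/llama-68m", id="CB llama"),
    pytest.param([0], 2, 32, 128, 1, None, "JackFram/llama-68m", id="non-CB llama"),
]

@pytest.mark.parametrize(
    "device_group,num_speculative_tokens,prompt_len,ctx_len,prefill_bsz,full_batch_size,model_name",
    configs,
)
def test_param_wiring(device_group, num_speculative_tokens, prompt_len, ctx_len,
                      prefill_bsz, full_batch_size, model_name):
    # full_batch_size=8 exercises the continuous-batching (CB) path; None the regular path.
    assert full_batch_size is None or full_batch_size >= prefill_bsz
```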
@@ -63,7 +63,7 @@ def test_llama_tlm_logit_dims(
         prompt_len=prompt_len,
         ctx_len=ctx_len,
         mxfp6=True,
-        # mxint8=True,
+        mxint8=True,
         full_batch_size=full_batch_size,
     )

@@ -126,7 +126,7 @@ def test_llama_dlm_logit_dims(
         prompt_len=prompt_len,
         ctx_len=ctx_len,
         mxfp6=True,
-        # mxint8=True,
+        mxint8=True,
         full_batch_size=full_batch_size,
     )

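These two hunks re-enable mxint8 alongside mxfp6 in the TLM and DLM export-and-compile tests. The surrounding call is truncated in the diff, so the snippet below is only a hedged sketch: the helper name is a made-up stand-in, and only the keyword names visible above (prompt_len, ctx_len, mxfp6, mxint8, full_batch_size) plus the values from the configs list come from this commit.

```python
def export_and_compile_stub(**kwargs):
    """Stand-in for the test's real export/compile helper (name assumed, not from the diff)."""
    print("would export ONNX and compile a QPC with:", kwargs)

# Values below are taken from the "CB llama" config above; use full_batch_size=None
# for the non-CB case.
export_and_compile_stub(
    model_name="JackFram/llama-68m",
    device_group=[0],
    num_speculative_tokens=2,
    prompt_len=32,
    ctx_len=128,
    prefill_bsz=1,
    full_batch_size=8,
    mxfp6=True,   # MXFP6 compression, already enabled before this commit
    mxint8=True,  # re-enabled by this commit
)
```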