Skip to content

Commit 32ce801

Browse files
committed
lint fix
1 parent fa058b7 commit 32ce801

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

QEfficient/utils/generate_inputs.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,15 @@
1515

1616
class InputHandler:
1717
def __init__(
18-
self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_logits_to_keep: Optional[int]
18+
self,
19+
batch_size,
20+
tokenizer,
21+
config,
22+
prompt,
23+
prompt_len,
24+
ctx_len,
25+
full_batch_size,
26+
num_logits_to_keep: Optional[int],
1927
):
2028
"""
2129
Initialization
@@ -28,8 +36,8 @@ def __init__(
2836
:prompt_len (int): Prompt length for the model to compile.
2937
:ctx_len (int): Maximum context length to compile the model.
3038
:full_batch_size (int): Continuous batching batch size
31-
:num_logits_to_keep (Optional[int]):
32-
Calculate logits for the last valid `num_logits_to_keep` tokens.
39+
:num_logits_to_keep (Optional[int]):
40+
Calculate logits for the last valid `num_logits_to_keep` tokens.
3341
Only last token logits are needed for generation, and calculating them only for that
3442
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
3543
"""
@@ -116,12 +124,14 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
116124
if self.full_batch_size:
117125
# Create CB inputs (make 1 batch index have proper inputs for decode pass)
118126
batch_index = torch.arange(1).view(-1, 1)
119-
batch_idx_input_ids = pt_outputs.logits.detach().argmax(2) # shape: [batch_size, num_logits_to_keep]
127+
batch_idx_input_ids = pt_outputs.logits.detach().argmax(2) # shape: [batch_size, num_logits_to_keep]
120128
input_ids = torch.full((self.full_batch_size, decode_len), self.tokenizer.pad_token_id)
121129
input_ids[batch_index.view(-1)] = batch_idx_input_ids
122130

123131
position_ids = torch.full((self.full_batch_size, decode_len), 0)
124-
batch_idx_position_ids = torch.arange(decode_len).view(1,-1) + (inputs["position_ids"].max(1, keepdim=True).values + 1)
132+
batch_idx_position_ids = torch.arange(decode_len).view(1, -1) + (
133+
inputs["position_ids"].max(1, keepdim=True).values + 1
134+
)
125135
position_ids[batch_index.view(-1)] = batch_idx_position_ids
126136

127137
updated_inputs["input_ids"] = input_ids
@@ -130,11 +140,13 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
130140

131141
else:
132142
if self.num_logits_to_keep is not None:
133-
input_ids = pt_outputs["logits"].argmax(-1) # shape: [batch_size, num_logits_to_keep]
143+
input_ids = pt_outputs["logits"].argmax(-1) # shape: [batch_size, num_logits_to_keep]
134144
batch_size = input_ids.size(0)
135-
position_ids = torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1)
145+
position_ids = (
146+
torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1)
147+
)
136148
else:
137-
input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1) # shape: [batch_size, 1]
149+
input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1) # shape: [batch_size, 1]
138150
position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1
139151
updated_inputs["input_ids"] = input_ids
140152
updated_inputs["position_ids"] = position_ids

0 commit comments

Comments
 (0)