15
15
16
16
class InputHandler :
17
17
def __init__ (
18
- self , batch_size , tokenizer , config , prompt , prompt_len , ctx_len , full_batch_size , num_logits_to_keep : Optional [int ]
18
+ self ,
19
+ batch_size ,
20
+ tokenizer ,
21
+ config ,
22
+ prompt ,
23
+ prompt_len ,
24
+ ctx_len ,
25
+ full_batch_size ,
26
+ num_logits_to_keep : Optional [int ],
19
27
):
20
28
"""
21
29
Initialization
@@ -28,8 +36,8 @@ def __init__(
28
36
:prompt_len (int): Prompt length for the model to compile.
29
37
:ctx_len (int): Maximum context length to compile the model.
30
38
:full_batch_size (int): Continuous batching batch size
31
- :num_logits_to_keep (Optional[int]):
32
- Calculate logits for the last valid `num_logits_to_keep` tokens.
39
+ :num_logits_to_keep (Optional[int]):
40
+ Calculate logits for the last valid `num_logits_to_keep` tokens.
33
41
Only last token logits are needed for generation, and calculating them only for that
34
42
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
35
43
"""
@@ -116,12 +124,14 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
116
124
if self .full_batch_size :
117
125
# Create CB inputs (make 1 batch index have proper inputs for decode pass)
118
126
batch_index = torch .arange (1 ).view (- 1 , 1 )
119
- batch_idx_input_ids = pt_outputs .logits .detach ().argmax (2 ) # shape: [batch_size, num_logits_to_keep]
127
+ batch_idx_input_ids = pt_outputs .logits .detach ().argmax (2 ) # shape: [batch_size, num_logits_to_keep]
120
128
input_ids = torch .full ((self .full_batch_size , decode_len ), self .tokenizer .pad_token_id )
121
129
input_ids [batch_index .view (- 1 )] = batch_idx_input_ids
122
130
123
131
position_ids = torch .full ((self .full_batch_size , decode_len ), 0 )
124
- batch_idx_position_ids = torch .arange (decode_len ).view (1 ,- 1 ) + (inputs ["position_ids" ].max (1 , keepdim = True ).values + 1 )
132
+ batch_idx_position_ids = torch .arange (decode_len ).view (1 , - 1 ) + (
133
+ inputs ["position_ids" ].max (1 , keepdim = True ).values + 1
134
+ )
125
135
position_ids [batch_index .view (- 1 )] = batch_idx_position_ids
126
136
127
137
updated_inputs ["input_ids" ] = input_ids
@@ -130,11 +140,13 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
130
140
131
141
else :
132
142
if self .num_logits_to_keep is not None :
133
- input_ids = pt_outputs ["logits" ].argmax (- 1 ) # shape: [batch_size, num_logits_to_keep]
143
+ input_ids = pt_outputs ["logits" ].argmax (- 1 ) # shape: [batch_size, num_logits_to_keep]
134
144
batch_size = input_ids .size (0 )
135
- position_ids = torch .arange (self .num_logits_to_keep ).view (1 , self .num_logits_to_keep ).repeat (batch_size , 1 )
145
+ position_ids = (
146
+ torch .arange (self .num_logits_to_keep ).view (1 , self .num_logits_to_keep ).repeat (batch_size , 1 )
147
+ )
136
148
else :
137
- input_ids = pt_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ) # shape: [batch_size, 1]
149
+ input_ids = pt_outputs ["logits" ].argmax (- 1 ).reshape (- 1 , 1 ) # shape: [batch_size, 1]
138
150
position_ids = inputs ["position_ids" ].max (1 , keepdim = True ).values + 1
139
151
updated_inputs ["input_ids" ] = input_ids
140
152
updated_inputs ["position_ids" ] = position_ids
0 commit comments