
Commit f32f9c6

Make jpt the default cli - remove other entry point scripts (#188)
Make the cli module the default CLI.
1 parent d84ae15 commit f32f9c6
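
This commit consolidates everything onto jetstream_pt/cli.py: the benchmarks/ scripts below are deleted and the CI job that exercised run_interactive.py is dropped. A minimal sketch of driving the consolidated CLI, assuming the console script is named jpt as the title says, that cli.main() wraps main_real (the hunk header below shows a def main() exists), and that the subcommand names match what is visible in this diff ('list' and 'serve' from the error message; interactive and benchmark_offline are inferred from the functions shown):

    # Hypothetical invocation through the module; the jpt console script
    # would do the equivalent. Subcommand names are inferred from this diff.
    import sys
    from jetstream_pt import cli

    sys.argv = ["jpt", "list"]  # or "serve", "interactive", "benchmark_offline"
    cli.main()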

File tree

13 files changed: 68 additions, 785 deletions


.github/workflows/unit_tests.yaml

Lines changed: 1 addition & 25 deletions
@@ -79,28 +79,4 @@ jobs:
         JAX_PLATFORMS=cpu coverage run -m unittest -v
     - name: Create test coverage report
       run: |
-        coverage report -m
-
-  interactive:
-    name: "jetstream_pt run interactive"
-    strategy:
-      matrix:
-        os: [ubuntu-20.04]
-        python-version: ['3.10']
-    runs-on: ${{ matrix.os }}
-    steps:
-    - name: Checkout
-      uses: actions/checkout@v4
-    - name: Setup Python
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install Dependencies
-      run: |
-        source install_everything.sh
-    - name: Run interactive (bf16)
-      run: |
-        JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=0 --quantize_kv_cache=0
-    - name: Run interactive (int8)
-      run: |
-        JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=1 --quantize_type="int8_per_channel" --quantize_kv_cache=1
+        coverage report -m

benchmarks/prefill_offline.py

Lines changed: 0 additions & 138 deletions
This file was deleted.

benchmarks/run_offline.py

Lines changed: 0 additions & 157 deletions
This file was deleted.

jetstream_pt/cli.py

Lines changed: 10 additions & 10 deletions
@@ -1,3 +1,4 @@
+import os
 from typing import List
 import random
 import sys
@@ -39,10 +40,8 @@ def shard_weights(env, weights, weight_shardings):
   sharded = {}
   for key, val in weights.items():
     sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
-    print("SHARDING", key, sharding)
     with jax.default_device(jax.devices("cpu")[0]):
       arr = torch_xla2.tensor.t2j(val)
-
     arr = jax.device_put(arr, sharding)
     sharded[key] = torchjax.to_torch(arr)
   return sharded
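
The pattern in shard_weights is: convert each torch tensor to a JAX array while the default device is CPU (presumably to avoid materializing full weights on a single accelerator), then let jax.device_put lay it out across devices according to the per-tensor sharding. A standalone sketch of the same placement pattern using only public JAX APIs; sharding_by_axis here is a hypothetical stand-in for the project's env.sharding_by_axis, which this diff does not show:

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # One-dimensional device mesh; by convention here, axis -1 means "replicate".
    mesh = Mesh(jax.devices(), ("x",))

    def sharding_by_axis(axis, ndim):
      spec = [None] * ndim
      if axis >= 0:
        spec[axis] = "x"  # split this tensor axis across the mesh
      return NamedSharding(mesh, PartitionSpec(*spec))

    w = jnp.ones((8, 1024))
    w_sharded = jax.device_put(w, sharding_by_axis(0, w.ndim))  # rows split across devices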
@@ -57,17 +56,16 @@ def create_engine(devices):
       FLAGS.override_batch_size,
       FLAGS.max_input_length,
       FLAGS.max_output_length,
-      quant_config.enable_weight_quantization,
   )
   tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
   env = environment.JetEngineEnvironment(env_data)
   env.hf_tokenizer = tokenizer
   model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
+  # NOTE: this is assigned later because the model should be constructed
+  # as a float model first, then quantized
+  env.quant_config = quant_config
   if quant_config.enable_weight_quantization:
     quantize_model.quantize_model(model, quant_config)
-  print("====== model =======")
-  print(model)
-
   weight_shardings = model.get_sharding_annotations()
   sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)
   env_data.quant_config = quant_config
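
The new NOTE records an ordering constraint: weight quantization rewrites a model that already exists in float form, so quant_config is attached to the env only after construction. A generic illustration of why the float weights must exist first (plain PyTorch, not the project's quantize_model):

    import torch

    # A layer is constructed with float weights first...
    layer = torch.nn.Linear(16, 16)

    # ...and only then can a (toy) int8 per-tensor quantization round them.
    with torch.no_grad():
      scale = layer.weight.abs().max() / 127.0
      q_weight = torch.round(layer.weight / scale).to(torch.int8)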
@@ -202,7 +200,7 @@ def interactive():
       "<s>[INST] <<SYS>>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<</SYS>>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
   ]
   for prompt in prompts:
-    slot = random.randint(0, FLAGS.batch_size - 1)
+    slot = random.randint(0, FLAGS.override_batch_size - 1)
     tokens, true_length = tokenizer.encode(prompt)

     print(f"---- Input prompts are: {prompt}")
@@ -330,10 +328,10 @@ def benchmark_offline():
   decode_time_ms = sum(dec_times[2:]) * 1000 / 8

   largest_prefill = max(prefill_times.items())
-  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+  print("MAX tokens:", FLAGS.override_batch_size / avg_decode_times)

-  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
-      FLAGS.batch_size * largest_prefill[1]
+  time2 = (FLAGS.override_batch_size * FLAGS.max_decode_length) / (
+      FLAGS.override_batch_size * largest_prefill[1]
       + FLAGS.max_decode_length * avg_decode_times
   )
   print("MAX tokens 2:", time2)
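
The "MAX tokens 2" figure reads as overall throughput: the numerator is the total tokens a batch generates (override_batch_size * max_decode_length), and the denominator estimates total time as one largest-bucket prefill per slot plus max_decode_length decode steps. A worked reading with illustrative numbers (not measured values; largest_prefill[1] and avg_decode_times are taken to be seconds):

    # Illustrative numbers only.
    override_batch_size = 96
    max_decode_length = 1024
    largest_prefill_s = 0.25   # stands in for largest_prefill[1]
    avg_decode_times = 0.015   # seconds per decode step

    total_tokens = override_batch_size * max_decode_length
    total_time_s = (override_batch_size * largest_prefill_s
                    + max_decode_length * avg_decode_times)
    print("MAX tokens 2:", total_tokens / total_time_s)  # ~2500 tokens/sec here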
@@ -351,6 +349,8 @@ def main():

 def main_real(argv):
   """Entry point"""
+  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
   if len(argv) < 2:
     print("Invalid arguments. please specify 'list' or 'serve'")
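
Both added lines are global runtime knobs rather than CLI behavior. Setting jax_default_prng_impl to "unsafe_rbg" switches JAX's default PRNG from threefry to the faster RBG implementation, a common choice in serving where speed matters more than strict key-splitting guarantees, and TF_CPP_MIN_LOG_LEVEL="0" keeps TensorFlow's C++ logging fully verbose. Shown standalone:

    import os
    import jax

    # Faster PRNG than the default threefry; weaker splitting guarantees.
    jax.config.update("jax_default_prng_impl", "unsafe_rbg")
    # 0 shows all TF C++ logs; 1 filters INFO, 2 adds WARNING, 3 adds ERROR.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"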

0 commit comments
