
Add model warmup and jax compilation cache flags #187


Merged

Conversation

vivianrwu (Collaborator)

Adds the flags for model warmup support in jetstream-pytorch, and adds JAX compilation cache flags.

Pass --enable_model_warmup=True in the server args:

        args:
        - --size=8b
        - --model_name=llama-3
        - --batch_size=128
        - --max_cache_length=2048
        - --quantize_weights=False
        - --quantize_kv_cache=False
        - --tokenizer_path=/models/llama3-8b/final/bf16/tokenizer.model
        - --checkpoint_path=/models/llama3-8b/final/bf16/model.safetensors
        - --enable_model_warmup=True

Logs that indicate model warmup is occurring:

2024-09-27 19:13:55,447 - root - INFO - ---------Prefill engine 0 compiled for prefill length 256.---------
I0927 19:13:55.447131 134077799315008 warmup_utils.py:108] ---------Prefill engine 0 compiled for prefill length 256.---------
I0927 19:13:55.618359 134077990557248 warmup_utils.py:108] ---------Prefill engine 0 compiled for prefill length 64.---------
2024-09-27 19:13:55,618 - root - INFO - ---------Prefill engine 0 compiled for prefill length 64.---------
2024-09-27 19:13:56,050 - root - INFO - ---------Prefill engine 0 compiled for prefill length 128.---------
I0927 19:13:56.050476 134077955900992 warmup_utils.py:108] ---------Prefill engine 0 compiled for prefill length 128.---------
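The per-length log lines above suggest the warmup step precompiles the prefill function once per bucketed prefill length at server start, so the first real request does not pay compile latency. A minimal illustrative sketch of that pattern (not jetstream-pytorch's actual warmup_utils.py; the bucket sizes 64/128/256 are taken from the logs, and lru_cache stands in for JAX's jit compile cache):

```python
import functools

PREFILL_BUCKETS = [64, 128, 256]  # assumed bucket sizes, per the logs above

@functools.lru_cache(maxsize=None)
def compile_prefill(prefill_length: int) -> str:
    # Stands in for an expensive, shape-specialized jit trace/compile.
    return f"compiled-prefill-{prefill_length}"

def warmup() -> list:
    # Compile every bucket up front, mirroring the per-length log lines.
    return [compile_prefill(n) for n in PREFILL_BUCKETS]

print(warmup())
```

After warmup, a request of any bucketed length hits the already-compiled entry instead of triggering a fresh compile.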

After SSH-ing into the container, the JAX compilation cache directory is populated:

root@jetstream-pytorch-server-db6f74545-m8gxd:~/jax_cache# ls
jit_generate_impl-HASH
jit_insert-HASH
jit_insert-HASH
...
jit_prefill-HASH
jit_prefill-HASH
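The listing shows one on-disk cache entry per jitted function and compiled variant, named <jitted_fn>-<hash>. A hypothetical sketch of that layout (the key scheme here is an assumption for illustration; JAX derives its real cache keys from the full computation, compile options, and backend, not just the function name and shape):

```python
import hashlib
import pathlib
import tempfile

def cache_key(fn_name, shape):
    # Hypothetical key: hash the function name plus input shape, so each
    # shape-specialized compilation gets its own file, as in the listing.
    digest = hashlib.sha256(repr((fn_name, shape)).encode()).hexdigest()[:16]
    return f"jit_{fn_name}-{digest}"

cache_dir = pathlib.Path(tempfile.mkdtemp())
for fn, shape in [("prefill", (1, 64)), ("prefill", (1, 128)), ("insert", (1,))]:
    # Each entry would hold the serialized compiled executable.
    (cache_dir / cache_key(fn, shape)).write_bytes(b"compiled-executable-bytes")

print(sorted(p.name for p in cache_dir.iterdir()))
```

Because entries are keyed files on disk, a restarted server pointed at the same cache directory can reload compilations instead of recompiling.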

@FanhaiLu1 (Collaborator) left a comment


Thanks for adding it! Looks good to me.

@qihqi qihqi merged commit f2e5181 into AI-Hypercomputer:main Oct 2, 2024
4 of 5 checks passed
3 participants