feat: add quantize exclude layer flag #194

Merged
merged 2 commits on Oct 17, 2024
5 changes: 2 additions & 3 deletions jetstream_pt/cli.py
@@ -34,11 +34,10 @@
"if set, then save the result to the given file name",
)
flags.DEFINE_bool(
"internal_use_local_tokenizer",
0,
"Use local tokenizer if set to True"
"internal_use_local_tokenizer", 0, "Use local tokenizer if set to True"
)


def shard_weights(env, weights, weight_shardings):
"""Shard weights according to weight_shardings"""
sharded = {}
6 changes: 6 additions & 0 deletions jetstream_pt/config.py
@@ -49,6 +49,11 @@
flags.DEFINE_bool(
"quantize_kv_cache", None, "defaults to the same value as quantize_weights"
)
+ flags.DEFINE_multi_string(
+ "quantize_exclude_layers",
+ None,
+ "List of layer names to exclude from quantization",
+ )

_VALID_QUANTIZATION_TYPE = {
"int8_per_channel",
@@ -178,6 +183,7 @@ def create_quantization_config_from_flags():
config.is_blockwise_weight = "blockwise" in quantize_type

config.enable_activation_quantization = FLAGS.quantize_activation
+ config.exclude_layers = FLAGS.quantize_exclude_layers
config.enable_kv_quantization = (
FLAGS.quantize_kv_cache
if FLAGS.quantize_kv_cache is not None
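Note: since quantize_exclude_layers is a DEFINE_multi_string flag, it is passed once per layer name and FLAGS.quantize_exclude_layers comes back as a list of strings (or None when the flag is never set). A hypothetical invocation — the entry-point script and layer names below are illustrative, not taken from this PR:

python your_entry_point.py --quantize_weights=1 \
    --quantize_exclude_layers=tok_embeddings \
    --quantize_exclude_layers=output

The names have to match the module names that quantize_model sees via named_modules().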
3 changes: 2 additions & 1 deletion jetstream_pt/environment.py
@@ -13,7 +13,7 @@
# limitations under the License.

import dataclasses
- from typing import Tuple
+ from typing import List, Tuple, Union

import jax
import jax.numpy as jnp
@@ -37,6 +37,7 @@ class QuantizationConfig:

enable_activation_quantization: bool = False
enable_kv_quantization: bool = False
+ exclude_layers: Union[None, List[str]] = None


@dataclasses.dataclass
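For code that builds the config directly rather than through flags, a minimal sketch — it assumes, as the visible fields suggest, that every QuantizationConfig field has a default, and "lm_head" is an illustrative layer name:

from jetstream_pt.environment import QuantizationConfig

# None (the default) means no layers are excluded; otherwise a list of
# module names as reported by torch.nn.Module.named_modules().
config = QuantizationConfig(exclude_layers=["lm_head"])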
15 changes: 5 additions & 10 deletions jetstream_pt/quantize_model.py
@@ -1,5 +1,5 @@
import torch
- from absl import flags
+ from .environment import QuantizationConfig
from .layers import (
create_quantized_from_nn_linear,
create_quantized_from_nn_embedding,
@@ -8,24 +8,19 @@
)


- _QUANTIZE_EMBEDDING = flags.DEFINE_bool(
- "internal_quantize_embedding_layer",
- True,
- "Whether to quantize embedding layer or not. Defaults to true",
- )
-
-
- def quantize_model(float_model, config):
+ def quantize_model(float_model, config: QuantizationConfig):
"""Apply quantization to linear layers."""

def quantize_nn_mod(float_model):
for name, mod in float_model.named_modules():
new_mod = None
+ if config.exclude_layers and name in config.exclude_layers:
+ continue
if hasattr(mod, "get_quantized_version"):
new_mod = mod.get_quantized_version()
elif isinstance(mod, torch.nn.Linear):
new_mod = create_quantized_from_nn_linear(mod, config)
- elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
+ elif isinstance(mod, torch.nn.Embedding):
new_mod = create_quantized_from_nn_embedding(mod, config)

if new_mod:
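The name compared against exclude_layers is the key yielded by named_modules(), so excluded entries keep their float nn.Linear / nn.Embedding modules. A standalone sketch of the same skip logic on a toy model (not from this repo):

import torch


class Toy(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.tok_embeddings = torch.nn.Embedding(8, 4)
    self.output = torch.nn.Linear(4, 8)


exclude_layers = ["output"]  # must match named_modules() keys exactly
for name, mod in Toy().named_modules():
  if exclude_layers and name in exclude_layers:
    continue  # "output" is skipped and stays an unquantized nn.Linear
  print("would quantize:", name, type(mod).__name__)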