diff --git a/jetstream_pt/cli.py b/jetstream_pt/cli.py
index b63fe6a..110d32c 100644
--- a/jetstream_pt/cli.py
+++ b/jetstream_pt/cli.py
@@ -34,11 +34,10 @@
     "if set, then save the result to the given file name",
 )
 flags.DEFINE_bool(
-    "internal_use_local_tokenizer",
-    0,
-    "Use local tokenizer if set to True"
+    "internal_use_local_tokenizer", 0, "Use local tokenizer if set to True"
 )
 
+
 def shard_weights(env, weights, weight_shardings):
   """Shard weights according to weight_shardings"""
   sharded = {}
diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py
index a8a833a..d1e21d9 100644
--- a/jetstream_pt/config.py
+++ b/jetstream_pt/config.py
@@ -49,6 +49,11 @@
 flags.DEFINE_bool(
     "quantize_kv_cache", None, "defaults to the same value as quantize_weights"
 )
+flags.DEFINE_multi_string(
+    "quantize_exclude_layers",
+    None,
+    "List of layer names to exclude from quantization",
+)
 
 _VALID_QUANTIZATION_TYPE = {
     "int8_per_channel",
@@ -178,6 +183,7 @@ def create_quantization_config_from_flags():
   config.is_blockwise_weight = "blockwise" in quantize_type
 
   config.enable_activation_quantization = FLAGS.quantize_activation
+  config.exclude_layers = FLAGS.quantize_exclude_layers
   config.enable_kv_quantization = (
       FLAGS.quantize_kv_cache
       if FLAGS.quantize_kv_cache is not None
diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index 4917705..6124a76 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import dataclasses
-from typing import Tuple
+from typing import List, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -37,6 +37,7 @@ class QuantizationConfig:
 
   enable_activation_quantization: bool = False
   enable_kv_quantization: bool = False
+  exclude_layers: Union[None, List[str]] = None
 
 
 @dataclasses.dataclass
diff --git a/jetstream_pt/quantize_model.py b/jetstream_pt/quantize_model.py
index 0f37486..79ff278 100644
--- a/jetstream_pt/quantize_model.py
+++ b/jetstream_pt/quantize_model.py
@@ -1,5 +1,5 @@
 import torch
-from absl import flags
+from .environment import QuantizationConfig
 from .layers import (
     create_quantized_from_nn_linear,
     create_quantized_from_nn_embedding,
@@ -8,24 +8,19 @@
 )
 
 
-_QUANTIZE_EMBEDDING = flags.DEFINE_bool(
-    "internal_quantize_embedding_layer",
-    True,
-    "Whether to quantize embedding layer or not. Defaults to true",
-)
-
-
-def quantize_model(float_model, config):
+def quantize_model(float_model, config: QuantizationConfig):
   """Apply quantization to linear layers."""
 
   def quantize_nn_mod(float_model):
     for name, mod in float_model.named_modules():
       new_mod = None
+      if config.exclude_layers and name in config.exclude_layers:
+        continue
       if hasattr(mod, "get_quantized_version"):
         new_mod = mod.get_quantized_version()
       elif isinstance(mod, torch.nn.Linear):
         new_mod = create_quantized_from_nn_linear(mod, config)
-      elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
+      elif isinstance(mod, torch.nn.Embedding):
         new_mod = create_quantized_from_nn_embedding(mod, config)
 
       if new_mod:
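
Reviewer note: the removed internal_quantize_embedding_layer flag is subsumed by the new exclusion list, so embeddings are now quantized unconditionally unless their qualified module name appears in --quantize_exclude_layers. Because the flag is defined with flags.DEFINE_multi_string, it can be repeated on the command line to exclude several layers. Below is a minimal standalone sketch of the skip logic quantize_model now applies; it uses plain PyTorch only, and the module names "tok_embeddings" and "proj" are illustrative stand-ins, not names from this repository; in practice the names are whatever model.named_modules() reports.

import torch

# Toy model standing in for the real network; the qualified module names
# are exactly what named_modules() yields during quantization.
model = torch.nn.ModuleDict({
    "tok_embeddings": torch.nn.Embedding(32, 8),
    "proj": torch.nn.Linear(8, 8),
})

exclude_layers = ["tok_embeddings"]  # mirrors config.exclude_layers

for name, mod in model.named_modules():
  if exclude_layers and name in exclude_layers:
    continue  # excluded module keeps its float weights
  if isinstance(mod, (torch.nn.Linear, torch.nn.Embedding)):
    print(f"would quantize: {name}")  # prints only "would quantize: proj"

A hypothetical invocation excluding two layers would pass the flag once per name, e.g. --quantize_exclude_layers=tok_embeddings --quantize_exclude_layers=output.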