diff --git a/benchmarks/mixtral_offline.sh b/benchmarks/mixtral_offline.sh
new file mode 100644
index 00000000..ea64195f
--- /dev/null
+++ b/benchmarks/mixtral_offline.sh
@@ -0,0 +1,20 @@
+CACHE_LENGTH=1024
+INPUT_SIZE=512
+OUTPUT_SIZE=1024
+BATCH_SIZE=512
+CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
+
+pushd ..
+python -m benchmarks.run_offline \
+  --model_name=mixtral \
+  --batch_size=$BATCH_SIZE \
+  --max_cache_length=$CACHE_LENGTH \
+  --max_decode_length=$OUTPUT_SIZE \
+  --context_length=$INPUT_SIZE \
+  --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
+  --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
+  --quantize_weights=1 \
+  --quantize_type=int8_per_channel \
+  --quantize_kv_cache=1 \
+  --profiling_output=/mnt/disks/hanq/mixtral-profiles
+popd
\ No newline at end of file
diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py
index daeafac7..2abda049 100644
--- a/benchmarks/run_offline.py
+++ b/benchmarks/run_offline.py
@@ -120,10 +120,20 @@ def main(argv):
   jax.profiler.stop_trace()

   print("prefill ", prefill_times)
-  print("decode", sum(dec_times) / 10)
+  avg_decode_times = sum(dec_times[2:]) / len(dec_times[2:])
+  print("decode", avg_decode_times)

   prefill_times_ms = {k: v * 1000 for k, v in prefill_times.items()}
-  decode_time_ms = sum(dec_times) * 1000 / 10 / FLAGS.batch_size
+  decode_time_ms = avg_decode_times * 1000
+
+  largest_prefill = max(prefill_times.items())
+  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+
+  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
+      FLAGS.batch_size * largest_prefill[1]
+      + FLAGS.max_decode_length * avg_decode_times
+  )
+  print("MAX tokens 2:", time2)

   sharegpt_path = FLAGS.sharegpt_path
   if sharegpt_path:
diff --git a/mlperf/README.md b/mlperf/README.md
new file mode 100644
index 00000000..b3322c1c
--- /dev/null
+++ b/mlperf/README.md
@@ -0,0 +1,31 @@
+# Run MLPerf tests
+
+NOTE: currently only tested with Mixtral,
+and only with the offline benchmark.
+
+# How to run
+
+### 1. Install
+
+```
+./install.sh
+```
+
+### 2. Start server
+
+```
+./start_server.sh
+```
+
+### 3. Warm up the server
+
+```
+python warmup.py
+```
+
+### 4. Run the benchmark (currently offline mode only)
+
+```
+./benchmark_run.sh
+```
+
diff --git a/mlperf/backend.py b/mlperf/backend.py
new file mode 100644
index 00000000..806eb727
--- /dev/null
+++ b/mlperf/backend.py
@@ -0,0 +1,352 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MLPerf LoadGen interface for JetStream LLM inference."""
+import math
+import array
+import concurrent.futures
+import dataclasses
+import json
+import logging
+from operator import itemgetter  # pylint: disable=g-importing-member
+import time
+from typing import List, Optional, Any
+
+import numpy as np
+
+from . 
import dataset + +import mlperf_loadgen as lg + +import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc + +from transformers import AutoTokenizer + + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("backend.py") + + +@dataclasses.dataclass +class WarmupSample: + id: int + index: int + + +@dataclasses.dataclass +class StreamResponse: + result: str = None + + +def _find_interesting_samples(max_length, dataset, encoder): + start_len = 1 + lengths = [len(encoder.encode(data)) for data in dataset] + min_length = min(lengths) + max_length = min(max(lengths), max_length) + start_len = 2 ** int(math.log2(min_length)) + while start_len * 2 < max_length: + for i, data in enumerate(dataset): + length = len(encoder.encode(data)) + if start_len < length <= start_len * 2: + log.info(f"Warmup sample: id={i} of length={length}") + yield i + break # for + else: + log.info( + f"DISCARD Warmup sample: id={i} of length={length} for {start_len}" + ) + start_len *= 2 + + +class ThreadedLMClient: + """Holds a thread pool and a loadgen client for LM inference.""" + + _thread_pool: concurrent.futures.ThreadPoolExecutor + _dataset: dataset.Dataset + _futures = List[concurrent.futures.Future] + + def __init__( + self, + is_stream: bool, + num_threads: int, + api_url: str, + dataset_object: dataset.Dataset, + input_mode: str, + output_mode: str, + tokenizer: Optional[AutoTokenizer] = None, + max_output_len: int = 1024, + log_interval: int = 1000, + ): + log.info(f"Initiating {self.__class__.__name__} ...") + self._is_stream = is_stream + self._input_mode = dataset.validate_sample_mode(input_mode) + self._output_mode = dataset.validate_sample_mode(output_mode) + if self._input_mode == "text" or self._output_mode == "text": + assert tokenizer is not None + self._tokenizer = tokenizer + self._max_output_len = max_output_len + + self._log_interval = log_interval + + self._thread_pool = concurrent.futures.ThreadPoolExecutor(num_threads) + self._api_url = api_url + self._dataset = dataset_object + self._futures = [] + self.pred_outputs = {} + self._resp_cnt = 0 + + # Post processing stop sequence for Mixtral MXBP dataset + self._stop_seq: List[int] = [13, 13940, 28832, 13] + self._stop_seq_len = len(self._stop_seq) + + log.info("Creating grpc channel with api_url {}".format(api_url)) + options = [("grpc.keepalive_timeout_ms", 10000)] + self._grpc_channel = grpc.insecure_channel(api_url, options=options) + + @property + def tokenizer(self): + return self._tokenizer + + def _log_resp_cnt(self): + self._resp_cnt += 1 + if self._resp_cnt % self._log_interval == 0: + log.info("Completed %d queries", self._resp_cnt) + + def post_process_response(self, response_tokens): + for i in range(self._stop_seq_len, len(response_tokens)): + if response_tokens[i - self._stop_seq_len : i] == self._stop_seq: + # log.info(f"Post process found stop seq: {response_tokens}") + return response_tokens[:i] + + # log.info(f"Post process no-op for {response_tokens}") + return response_tokens + + def process_single_sample_async(self, query_sample, warmup): + """Executes a single query and marks responses complete asynchronously. + + Args: + query_sample: Single prompt + warmup: Indicates that this is a warmup request. 
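+        Responses to warmup requests are not reported back to LoadGen.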
+ """ + future = self._thread_pool.submit( + self._process_sample, query_sample, warmup + ) + self._futures.append(future) + + def flush(self): + concurrent.futures.wait(self._futures) + self._futures = [] + + def _grpc_request(self, request, sample, warmup): + """Send grpc synchronous request since the current grpc server is sync.""" + stub = jetstream_pb2_grpc.OrchestratorStub(self._grpc_channel) + token_list = [] + ttft = 0 + start_time = time.perf_counter() + response = stub.Decode(request) + for resp in response: + if not warmup and self._is_stream and ttft == 0: + # TTFT for online mode + ttft = time.perf_counter() - start_time + log.info("TTFT {}ms".format(ttft * 1000)) + response_token_ids = resp.stream_content.samples[0].token_ids + assert len(response_token_ids) == 1 + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + first_token_response = lg.QuerySampleResponse( + sample.id, response_info[0], response_info[1] + ) + lg.FirstTokenComplete([first_token_response]) + log.info("mark first token complete") + token_list.extend(resp.stream_content.samples[0].token_ids) + return token_list + + def _process_sample(self, sample, warmup): + """Processes a single sample.""" + sample_data = self._dataset.inputs[sample.index] + if self._input_mode == "text": + token_ids = self._tokenizer.encode(sample_data) + else: + assert self._input_mode == "tokenized" + token_ids = [int(token_id_str) for token_id_str in sample_data.split(",")] + + request = jetstream_pb2.DecodeRequest( + session_cache="", + token_content=jetstream_pb2.DecodeRequest.TokenContent( + token_ids=token_ids + ), + priority=0, + max_tokens=self._max_output_len, + ) + generated_token_list = self._grpc_request(request, sample, warmup) + if not warmup: + try: + dataset_name = self._dataset.input_datasets[sample.index] + if dataset_name == "MBXP": + response_token_ids = self.post_process_response(generated_token_list) + else: + response_token_ids = generated_token_list + except Exception as e: + log.info(f"Error - {e}") + response_token_ids = generated_token_list + n_tokens = len(response_token_ids) + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + response_data = response_info[0] + response_size = response_info[1] * response_array.itemsize + query_sample_response = lg.QuerySampleResponse( + sample.id, response_data, response_size, n_tokens + ) + lg.QuerySamplesComplete([query_sample_response]) + # log.info(f"mark query as complete for - {dataset_name}") + log.info(f"mark query as complete") + pred_output = self._tokenizer.decode(response_token_ids) + self.pred_outputs[sample.index] = pred_output + self._log_resp_cnt() + + +class SUT: + """SUT.""" + + def __init__( + self, + scenario, + api_url, + is_stream, + input_mode, + output_mode, + max_output_len, + dataset_path, + total_sample_count, + tokenizer_path=None, + perf_count_override=None, + num_client_threads=200, + log_interval=1000, + batch_size_exp=5, + pred_outputs_log_path=None, + ): + log.info(f"Starting {scenario} SUT with {api_url}.") + self._is_stream = is_stream + self._input_mode = dataset.validate_sample_mode(input_mode) + self._output_mode = dataset.validate_sample_mode(output_mode) + assert tokenizer_path is not None + self._tokenizer = self.load_tokenizer(tokenizer_path) + self._max_output_len = 
max_output_len + self._api_url = api_url + self._dataset_path = dataset_path + self._total_sample_count = total_sample_count + self._perf_count_override = perf_count_override + self._num_client_threads = num_client_threads + self._log_interval = log_interval + self._batch_size_exp = batch_size_exp + self._pred_outputs_log_path = pred_outputs_log_path + + log.info("Loading Dataset ... ") + self.dataset = dataset.Dataset( + dataset_path=self._dataset_path, + input_mode=self._input_mode, + total_sample_count=self._total_sample_count, + perf_count_override=self._perf_count_override, + ) + + client_cls = ThreadedLMClient + self._client = client_cls( + is_stream=self._is_stream, + num_threads=self._num_client_threads, + api_url=self._api_url, + dataset_object=self.dataset, + input_mode=self._input_mode, + output_mode=self._output_mode, + tokenizer=self._tokenizer, + max_output_len=self._max_output_len, + log_interval=self._log_interval, + ) + + self.qsl = lg.ConstructQSL( + self.dataset.total_sample_count, + self.dataset.perf_count, + self.dataset.LoadSamplesToRam, + self.dataset.UnloadSamplesFromRam, + ) + self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries) + + def load_tokenizer( + self, tokenizer_path: Optional[str] = None + ) -> Optional[AutoTokenizer]: + """Returns tokenizer""" + if tokenizer_path is not None: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + model_max_length=1024, + padding_side="left", + use_fast=True, + ) + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def _sort_issue_queries(self, query_samples): + """Issue queries.""" + query_samples_with_length = [] + for query_sample in query_samples: + query_sample_token_length = self.dataset.inputs_with_token_lengths[ + query_sample.index + ][1] + query_samples_with_length.append( + (query_sample_token_length, query_sample) + ) + sorted_query_samples_with_length = sorted( + query_samples_with_length, key=itemgetter(0) + ) + sorted_query_samples = [x[1] for x in sorted_query_samples_with_length] + return sorted_query_samples + + def issue_queries(self, query_samples): + """Issue queries.""" + num_query_samples = len(query_samples) + if num_query_samples > 1: + log.info(f"Issuing {num_query_samples} queries. ") + query_samples = self._sort_issue_queries(query_samples) + for query_sample in query_samples: + self._client.process_single_sample_async(query_sample, False) + + def flush_queries(self): + """Flush queries.""" + log.info("Loadgen has completed issuing queries... ") + self._client.flush() + + if self._pred_outputs_log_path is not None: + + pred_outputs = [] + for idx, x in self._client.pred_outputs.items(): + pred_output = { + "qsl_idx": idx, + "intput": self._client._dataset.inputs[idx], + "data": x, + } + pred_outputs.append(pred_output) + log.info(f"Generated {len(pred_outputs)} prediction outputs") + + if pred_outputs: + self.accuracy_log = open(self._pred_outputs_log_path, "w") + self.accuracy_log.write(json.dumps(pred_outputs)) + self.accuracy_log.flush() + self.accuracy_log.close() + log.info("Dumpped prediction outputs to accuracy log... 
") + + def __del__(self): + print("Finished destroying SUT.") diff --git a/mlperf/benchmark_run.sh b/mlperf/benchmark_run.sh new file mode 100755 index 00000000..946c301a --- /dev/null +++ b/mlperf/benchmark_run.sh @@ -0,0 +1,32 @@ +BASEDIR=mlperf +API_URL=0.0.0.0:9000 +USER_CONFIG=$BASEDIR/user.conf +DATA_DISK_DIR=$BASEDIR/data +TOTAL_SAMPLE_COUNT=1000 +DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl + +# HF model id +TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1" + +LOADGEN_RUN_TYPE=offline-performance +OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} +OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} + +mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} + +pushd .. +python -m mlperf.main \ + --api-url ${API_URL} \ + --scenario Offline \ + --input-mode tokenized \ + --output-mode tokenized \ + --log-pred-outputs \ + --mlperf-conf $BASEDIR/mlperf.conf \ + --user-conf ${USER_CONFIG} \ + --audit-conf no-audit \ + --total-sample-count ${TOTAL_SAMPLE_COUNT} \ + --dataset-path ${DATASET_PATH} \ + --tokenizer-path ${TOKENIZER_PATH} \ + --log-interval 1000 \ + --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log +popd \ No newline at end of file diff --git a/mlperf/dataset.py b/mlperf/dataset.py new file mode 100644 index 00000000..373bbc49 --- /dev/null +++ b/mlperf/dataset.py @@ -0,0 +1,128 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import pandas as pd + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("dataset.py") + + +class Dataset: + + def __init__( + self, + dataset_path: str, + input_mode: str, + total_sample_count: int = 15000, + perf_count_override: int = None, + ): + if not os.path.isfile(dataset_path): + log.warn( + "Processed pickle file {} not found. 
Please check that the path is correct".format( + dataset_path + ) + ) + self.dataset_path = dataset_path + + self._input_mode = validate_sample_mode(input_mode) + self.load_processed_dataset() + + self.total_sample_count = min(len(self.input_ids_strs), total_sample_count) + self.perf_count = perf_count_override or self.total_sample_count + + @property + def input_ids_strs(self): + return self._input_ids_strs + + @property + def input_texts(self): + return self._input_texts + + @property + def input_token_lengths(self): + return self._input_token_lengths + + @property + def inputs(self): + return self._inputs + + @property + def inputs_with_token_lengths(self): + return self._inputs_with_token_lengths + + @property + def input_datasets(self): + return self._input_datasets + + def load_processed_dataset(self): + processed_data = pd.read_pickle(self.dataset_path) + # processed_data = processed_data[processed_data["dataset"] == "MBXP"] + # processed_data = processed_data.reset_index(drop=True) + + self._input_ids_strs = [] + for input_ids in processed_data["tok_input"]: + input_ids_str = ",".join([str(input_id) for input_id in input_ids]) + self._input_ids_strs.append(input_ids_str) + + self._input_texts = [] + for input_text in processed_data["input"]: + self._input_texts.append(input_text) + + self._input_token_lengths = [] + for token_length in processed_data["tok_input_len"]: + self._input_token_lengths.append(token_length) + + log.info(f"input_mode is {self._input_mode}") + self._inputs = ( + self._input_ids_strs + if self._input_mode == "tokenized" + else self._input_texts + ) + log.info(f"example sample input is {self._inputs[0]}") + self._inputs_with_token_lengths = [ + (input_ids_str_or_input_text, token_length) + for input_ids_str_or_input_text, token_length in zip( + self._inputs, self._input_token_lengths + ) + ] + + self._input_datasets = [] + for dataset in processed_data["dataset"]: + self._input_datasets.append(dataset) + log.info( + f"example sample input dataset is {self._input_datasets[0]} and total {len(self._input_datasets)}" + ) + + def LoadSamplesToRam(self, sample_list): + pass + + def UnloadSamplesFromRam(self, sample_list): + pass + + def __del__(self): + pass + + +SAMPLE_MODE_CHOICES = ["tokenized", "text"] + + +def validate_sample_mode(sample_mode: str) -> str: + if sample_mode not in SAMPLE_MODE_CHOICES: + raise ValueError( + "The sample_mode should be set to either `tokenized` or `text`." 
+  )
+  return sample_mode
diff --git a/mlperf/install.sh b/mlperf/install.sh
new file mode 100644
index 00000000..3a8f037b
--- /dev/null
+++ b/mlperf/install.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+
+DATA_DISK_DIR=data
+
+mkdir -p $DATA_DISK_DIR
+
+pip install -U "huggingface_hub[cli]"
+pip install \
+  transformers \
+  nltk==3.8.1 \
+  evaluate==0.4.0 \
+  absl-py==1.4.0 \
+  rouge-score==0.1.2 \
+  sentencepiece==0.1.99 \
+  accelerate==0.21.0
+
+# install loadgen
+pip install mlperf-loadgen
+
+
+pushd $DATA_DISK_DIR
+
+# model weights
+gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive
+# NOTE: only download the weights you need so you don't fill up your disk
+# gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive
+
+# Get mixtral data
+wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl
+mv mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl mixtral_15k_data.pkl
+wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl
+mv mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl mixtral_15k_calibration_data.pkl
+
+# Get llama70b data
+gcloud storage cp \
+  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \
+  processed-calibration-data.pkl
+gcloud storage cp \
+  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
+  processed-data.pkl
+popd
diff --git a/mlperf/main.py b/mlperf/main.py
new file mode 100644
index 00000000..ad0fe7e2
--- /dev/null
+++ b/mlperf/main.py
@@ -0,0 +1,212 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import gc
+import logging
+import os
+import sys
+
+from . import backend
+
+import mlperf_loadgen as lg
+
+_MLPERF_ID = "mixtral-8x7b"
+
+sys.path.insert(0, os.getcwd())
+
+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger("main.py")
+
+
+def get_args():
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      "--scenario",
+      type=str,
+      choices=["Offline", "Server"],
+      default="Offline",
+      help="Scenario",
+  )
+  parser.add_argument(
+      "--api-url", type=str, default=None, help="JetStream server address."
+ ) + parser.add_argument("--dataset-path", type=str, default=None, help="") + parser.add_argument("--tokenizer-path", type=str, default=None, help="") + parser.add_argument( + "--accuracy", action="store_true", help="Run accuracy mode" + ) + parser.add_argument("--is-stream", action="store_true", help="") + parser.add_argument( + "--input-mode", + type=str, + choices=["text", "tokenized"], + default="tokenized", + ) + parser.add_argument( + "--output-mode", + type=str, + choices=["text", "tokenized"], + default="tokenized", + ) + parser.add_argument( + "--max-output-len", type=int, default=1024, help="Maximum output len" + ) + parser.add_argument( + "--audit-conf", + type=str, + default="audit.conf", + help="audit config for LoadGen settings during compliance runs", + ) + parser.add_argument( + "--mlperf-conf", + type=str, + default="mlperf.conf", + help="mlperf rules config", + ) + parser.add_argument( + "--user-conf", + type=str, + default="user.conf", + help="user config for user LoadGen settings such as target QPS", + ) + parser.add_argument( + "--total-sample-count", + type=int, + default=15000, + help="Number of samples to use in benchmark.", + ) + parser.add_argument( + "--perf-count-override", + type=int, + default=None, + help="Overwrite number of samples to use in benchmark.", + ) + parser.add_argument( + "--output-log-dir", + type=str, + default="output-logs", + help="Where logs are saved.", + ) + parser.add_argument( + "--enable-log-trace", + action="store_true", + help="Enable log tracing. This file can become quite large", + ) + parser.add_argument( + "--num-client-threads", + type=int, + default=200, + help="Number of client threads to use", + ) + parser.add_argument("--batch-size-exp", type=int, default=6, help="") + parser.add_argument("--log-pred-outputs", action="store_true", help="") + parser.add_argument( + "--log-interval", + type=int, + default=1000, + help="Logging interval in seconds", + ) + parser.add_argument( + "--user-conf-override-path", + type=str, + default="", + help="When given overrides the default user.conf path", + ) + + args = parser.parse_args() + return args + + +scenario_map = { + "offline": lg.TestScenario.Offline, + "server": lg.TestScenario.Server, +} + + +def main(): + args = get_args() + + settings = lg.TestSettings() + settings.scenario = scenario_map[args.scenario.lower()] + if args.user_conf_override_path: + user_conf = args.user_conf_override_path + else: + user_conf = args.user_conf + + settings.FromConfig(args.mlperf_conf, _MLPERF_ID, args.scenario) + settings.FromConfig(user_conf, _MLPERF_ID, args.scenario) + log.info("Mlperf config: %s", args.mlperf_conf) + log.info("User config: %s", user_conf) + + if args.accuracy: + settings.mode = lg.TestMode.AccuracyOnly + log.warning( + "Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet" + ) + else: + settings.mode = lg.TestMode.PerformanceOnly + settings.print_timestamps = True + + settings.use_token_latencies = True + + os.makedirs(args.output_log_dir, exist_ok=True) + log_output_settings = lg.LogOutputSettings() + log_output_settings.outdir = args.output_log_dir + log_output_settings.copy_summary_to_stdout = True + log_settings = lg.LogSettings() + log_settings.log_output = log_output_settings + log_settings.enable_trace = args.enable_log_trace + + sut = backend.SUT( + scenario=args.scenario.lower(), + api_url=args.api_url, + is_stream=args.is_stream, + input_mode=args.input_mode, + output_mode=args.output_mode, + 
max_output_len=args.max_output_len, + dataset_path=args.dataset_path, + total_sample_count=args.total_sample_count, + tokenizer_path=args.tokenizer_path, + perf_count_override=args.perf_count_override, + num_client_threads=args.num_client_threads, + log_interval=args.log_interval, + batch_size_exp=args.batch_size_exp, + pred_outputs_log_path=os.path.join( + args.output_log_dir, "pred_outputs_logger.json" + ) + if args.log_pred_outputs + else None, + ) + + lgSUT = sut.sut # lg.ConstructSUT(sut.issue_queries, sut.flush_queries) + log.info("Starting Benchmark run") + lg.StartTestWithLogSettings( + lgSUT, sut.qsl, settings, log_settings, args.audit_conf + ) + + log.info("Run Completed!") + + log.info("Destroying SUT...") + lg.DestroySUT(lgSUT) + + log.info("Destroying QSL...") + lg.DestroyQSL(sut.qsl) + + +if __name__ == "__main__": + # Disable garbage collection to avoid stalls when running tests. + gc.disable() + main() diff --git a/mlperf/mlperf.conf b/mlperf/mlperf.conf new file mode 100644 index 00000000..9400d0af --- /dev/null +++ b/mlperf/mlperf.conf @@ -0,0 +1,98 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. +# All times are in milli seconds + +# Set performance_sample_count for each model. +# User can optionally set this to higher values in user.conf. +resnet50.*.performance_sample_count_override = 1024 +ssd-mobilenet.*.performance_sample_count_override = 256 +retinanet.*.performance_sample_count_override = 64 +bert.*.performance_sample_count_override = 10833 +dlrm.*.performance_sample_count_override = 204800 +dlrm-v2.*.performance_sample_count_override = 204800 +rnnt.*.performance_sample_count_override = 2513 +gptj.*.performance_sample_count_override = 13368 +llama2-70b.*.performance_sample_count_override = 24576 +stable-diffusion-xl.*.performance_sample_count_override = 5000 +# set to 0 to let entire sample set to be performance sample +3d-unet.*.performance_sample_count_override = 0 + +# Set seeds. The seeds will be distributed two weeks before the submission. +*.*.qsl_rng_seed = 3066443479025735752 +*.*.sample_index_rng_seed = 10688027786191513374 +*.*.schedule_rng_seed = 14962580496156340209 +# Set seeds for TEST_05. The seeds will be distributed two weeks before the submission. 
+*.*.test05_qsl_rng_seed = 16799458546791641818 +*.*.test05_sample_index_rng_seed = 5453809927556429288 +*.*.test05_schedule_rng_seed = 5435552105434836064 + + +*.SingleStream.target_latency_percentile = 90 +*.SingleStream.min_duration = 600000 + +*.MultiStream.target_latency_percentile = 99 +*.MultiStream.samples_per_query = 8 +*.MultiStream.min_duration = 600000 +*.MultiStream.min_query_count = 662 +retinanet.MultiStream.target_latency = 528 + +# 3D-UNet uses equal issue mode because it has non-uniform inputs +3d-unet.*.sample_concatenate_permutation = 1 + +# LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenario +gptj.*.sample_concatenate_permutation = 1 +llama2-70b.*.sample_concatenate_permutation = 1 +mixtral-8x7B.*.sample_concatenate_permutation = 1 + +*.Server.target_latency = 10 +*.Server.target_latency_percentile = 99 +*.Server.target_duration = 0 +*.Server.min_duration = 600000 +resnet50.Server.target_latency = 15 +retinanet.Server.target_latency = 100 +bert.Server.target_latency = 130 +dlrm.Server.target_latency = 60 +dlrm-v2.Server.target_latency = 60 +rnnt.Server.target_latency = 1000 +gptj.Server.target_latency = 20000 +stable-diffusion-xl.Server.target_latency = 20000 +# Llama2-70b benchmarks measures token latencies +llama2-70b.*.use_token_latencies = 1 +mixtral-8x7b.*.use_token_latencies = 1 +# gptj benchmark infers token latencies +gptj.*.infer_token_latencies = 1 +gptj.*.token_latency_scaling_factor = 69 +# Only ttft and tpot are tracked for the llama2-70b & mixtral-8x7B benchmark therefore target_latency = 0 +llama2-70b.Server.target_latency = 0 +llama2-70b.Server.ttft_latency = 2000 +llama2-70b.Server.tpot_latency = 200 + +mixtral-8x7b.Server.target_latency = 0 +mixtral-8x7b.Server.ttft_latency = 2000 +mixtral-8x7b.Server.tpot_latency = 200 + +*.Offline.target_latency_percentile = 90 +*.Offline.min_duration = 600000 + +# In Offline scenario, we always have one query. But LoadGen maps this to +# min_sample_count internally in Offline scenario. If the dataset size is larger +# than 24576 we limit the min_query_count to 24576 and otherwise we use +# the dataset size as the limit + +resnet50.Offline.min_query_count = 24576 +retinanet.Offline.min_query_count = 24576 +dlrm-v2.Offline.min_query_count = 24576 +bert.Offline.min_query_count = 10833 +gptj.Offline.min_query_count = 13368 +rnnt.Offline.min_query_count = 2513 +3d-unet.Offline.min_query_count = 43 +stable-diffusion-xl.Offline.min_query_count = 5000 +llama2-70b.Offline.min_query_count = 1000 +mixtral-8x7b.Offline.min_query_count = 1000 + +# These fields should be defined and overridden by user.conf. +*.SingleStream.target_latency = 10 +*.MultiStream.target_latency = 80 +*.Server.target_qps = 1.0 +*.Offline.target_qps = 4.0 diff --git a/mlperf/start_server.sh b/mlperf/start_server.sh new file mode 100755 index 00000000..74f9d6b3 --- /dev/null +++ b/mlperf/start_server.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +CACHE_LENGTH=3072 +INPUT_SIZE=512 +OUTPUT_SIZE=512 +CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ + +pushd .. 
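+# Run this script from the mlperf/ directory; `pushd ..` moves to the repo
+# root, where run_server.py and the mlperf/data checkpoint directory live.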
+python run_server.py \ + --model_name=mixtral \ + --batch_size=128 \ + --max_cache_length=$CACHE_LENGTH \ + --max_decode_length=$OUTPUT_SIZE \ + --context_length=$INPUT_SIZE \ + --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \ + --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \ + --quantize_weights=1 \ + --quantize_type=int8_per_channel \ + --quantize_kv_cache=1 +popd \ No newline at end of file diff --git a/mlperf/user.conf b/mlperf/user.conf new file mode 100644 index 00000000..2b1fa841 --- /dev/null +++ b/mlperf/user.conf @@ -0,0 +1,3 @@ +mixtral-8x7b.Server.target_qps = 1.8 +mixtral-8x7b.Offline.target_qps = 4.0 + diff --git a/mlperf/warmup.py b/mlperf/warmup.py new file mode 100644 index 00000000..4df66551 --- /dev/null +++ b/mlperf/warmup.py @@ -0,0 +1,192 @@ +import argparse +import asyncio +from dataclasses import dataclass, field +from datetime import datetime +import json +import random +import time +from typing import Any, AsyncGenerator, Optional +import os + + +import grpc +from jetstream.core.proto import jetstream_pb2 +from jetstream.core.proto import jetstream_pb2_grpc +from jetstream.engine.token_utils import load_vocab +from jetstream.third_party.llama3 import llama3_tokenizer +import numpy as np +from tqdm.asyncio import tqdm # pytype: disable=pyi-error +import pandas + + +@dataclass +class InputRequest: + prompt: str = "" + prompt_len: int = 0 + output: str = "" + output_len: int = 0 + sample_idx: int = -1 + + +@dataclass +class RequestFuncOutput: + input_request: Optional[InputRequest] = None + generated_token_list: list[str] = field(default_factory=list) + generated_text: str = "" + success: bool = False + latency: float = 0 + ttft: float = 0 + prompt_len: int = 0 + + # Flatten the structure and return only the necessary results + def to_dict(self): + return { + "prompt": self.input_request.prompt, + "original_output": self.input_request.output, + "generated_text": self.generated_text, + "success": self.success, + "latency": self.latency, + "prompt_len": self.prompt_len, + "sample_idx": self.input_request.sample_idx, + } + + +async def grpc_async_request( + api_url: str, request: Any +) -> tuple[list[str], float, float]: + """Send grpc synchronous request since the current grpc server is sync.""" + options = [("grpc.keepalive_timeout_ms", 10000)] + async with grpc.aio.insecure_channel(api_url, options=options) as channel: + stub = jetstream_pb2_grpc.OrchestratorStub(channel) + print("Making request") + ttft = 0 + token_list = [] + request_start_time = time.perf_counter() + response = stub.Decode(request) + async for resp in response: + if ttft == 0: + ttft = time.perf_counter() - request_start_time + token_list.extend(resp.stream_content.samples[0].token_ids) + latency = time.perf_counter() - request_start_time + print("Done request: ", latency) + return token_list, ttft, latency + + +async def send_request( + api_url: str, + tokenizer: Any, + input_request: InputRequest, + pbar: tqdm, + session_cache: str, + priority: int, +) -> RequestFuncOutput: + """Send the request to JetStream server.""" + # Tokenization on client side following MLPerf standard. 
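+  # NOTE: the warmup client does not tokenize a real prompt; it sends random
+  # token ids of the requested bucket length, which is enough to exercise
+  # (and warm up) each prefill bucket length on the server.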
+ token_ids = np.random.randint(0, 1000, input_request.request_len) + request = jetstream_pb2.DecodeRequest( + session_cache=session_cache, + token_content=jetstream_pb2.DecodeRequest.TokenContent( + token_ids=token_ids + ), + priority=priority, + max_tokens=input_request.output_len, + ) + output = RequestFuncOutput() + output.input_request = input_request + output.prompt_len = input_request.prompt_len + generated_token_list, ttft, latency = await grpc_async_request( + api_url, request + ) + output.ttft = ttft + output.latency = latency + output.generated_token_list = generated_token_list + # generated_token_list is a list of token ids, decode it to generated_text. + output.generated_text = "" + output.success = True + if pbar: + pbar.update(1) + return output + + +async def benchmark( + api_url: str, + max_length: int, + tokenizer: Any = None, + request_rate: float = 0, + disable_tqdm: bool = False, + session_cache: str = "", + priority: int = 100, +): + """Benchmark the online serving performance.""" + + print(f"Traffic request rate: {request_rate}") + + benchmark_start_time = time.perf_counter() + tasks = [] + interesting_buckets = [ + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + ] + + for length in interesting_buckets: + if length > max_length: + break + request = InputRequest() + request.request_len = length + print("send request of length", request.request_len) + tasks.append( + asyncio.create_task( + send_request( + api_url=api_url, + tokenizer=None, + input_request=request, + pbar=None, + session_cache=session_cache, + priority=priority, + ) + ) + ) + outputs = await asyncio.gather(*tasks) + + benchmark_duration = time.perf_counter() - benchmark_start_time + return benchmark_duration, outputs + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + api_url = f"{args.server}:{args.port}" + + benchmark_result, request_outputs = asyncio.run( + benchmark(api_url=api_url, max_length=args.max_length) + ) + print("DURATION:", benchmark_result) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="Benchmark the online serving throughput." + ) + parser.add_argument( + "--server", + type=str, + default="0.0.0.0", + help="Server address.", + ) + parser.add_argument("--seed", type=int, default=0) + + parser.add_argument("--port", type=str, default=9000) + parser.add_argument("--max-length", type=int, default=512) + + parsed_args = parser.parse_args() + main(parsed_args)