diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index 928b7284fe..7be199d52a 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -36,6 +36,8 @@ def constant_fold(
     # The constants are created on CPU to save GPU memory for TensorRT compilation.
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
+        if node.target == torch.ops.aten.embedding.default:
+            continue
         replace_node_with_constant(
             gm, node, torch.nn.Parameter(constant, requires_grad=False)
         )
@@ -103,7 +105,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.quantization_ops: Set[torch._ops.OpOverload] = set()
         try:
             # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
-            import modelopt.torch.quantization as mtq
+            import modelopt.torch.quantization as mtq  # noqa: F401

             assert torch.ops.tensorrt.quantize_op.default
             self.quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
diff --git a/tools/perf/perf_run.py b/tools/perf/perf_run.py
index ca37316ea8..fe2ba3073c 100644
--- a/tools/perf/perf_run.py
+++ b/tools/perf/perf_run.py
@@ -174,8 +174,7 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size):
     compile_settings = {
         "inputs": input_tensors,
         "enabled_precisions": {precision_to_dtype(precision)},
-        "truncate_long_and_double": params.get("truncate", False),
-        "use_python_runtime": params.get("use_python_runtime", False),
+        "truncate_double": params.get("truncate", False),
     }

     if precision == "int8":
@@ -274,8 +273,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
         ir="dynamo",
         enabled_precisions={precision_to_dtype(precision)},
         min_block_size=params.get("min_block_size", 1),
-        debug=False,
-        truncate_long_and_double=params.get("truncate", False),
+        truncate_double=params.get("truncate", False),
         immutable_weights=params.get("immutable_weights", True),
         strip_engine_weights=params.get("strip_engine_weights", False),
         refit_identical_engine_weights=params.get(
@@ -284,6 +282,7 @@
         cache_built_engines=params.get("cache_built_engines", False),
         reuse_cached_engines=params.get("reuse_cached_engines", False),
         use_python_runtime=params.get("use_python_runtime", False),
+        optimization_level=params.get("optimization_level", 5),
     )
     end_compile = timeit.default_timer()
     compile_time_s = end_compile - start_compile
@@ -437,25 +436,30 @@ def run_tensorrt(
     precision,
     batch_size=1,
 ):
-    # Export an ONNX model and convert to TRT
-    torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
     logger = trt.Logger(trt.Logger.WARNING)
-    builder = trt.Builder(logger)
-    network = builder.create_network(
-        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
-    )
-    parser = trt.OnnxParser(network, logger)
-    success = parser.parse_from_file("./tmp.onnx")
-    if not success:
-        raise ValueError("ONNX conversion failed")
-
-    config = builder.create_builder_config()
-    if precision == "fp16":
-        config.set_flag(trt.BuilderFlag.FP16)
-    start_compile = timeit.default_timer()
-    serialized_engine = builder.build_serialized_network(network, config)
-    end_compile = timeit.default_timer()
-    compile_time_s = end_compile - start_compile
+    compile_time_s = 0
+    if params["is_trt_engine"]:
+        serialized_engine = model
+    else:
+        # Export an ONNX model and convert to TRT
+        torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
+        builder = trt.Builder(logger)
+        network = builder.create_network(
+            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+        )
+        parser = trt.OnnxParser(network, logger)
+        success = parser.parse_from_file("./tmp.onnx")
+        if not success:
+            raise ValueError("ONNX conversion failed")
+
+        config = builder.create_builder_config()
+        if precision == "fp16":
+            config.set_flag(trt.BuilderFlag.FP16)
+        config.builder_optimization_level = params.get("optimization_level", 5)
+        start_compile = timeit.default_timer()
+        serialized_engine = builder.build_serialized_network(network, config)
+        end_compile = timeit.default_timer()
+        compile_time_s = end_compile - start_compile
     # Deserialize the TensorRT engine
     with trt.Runtime(logger) as runtime:
         engine = runtime.deserialize_cuda_engine(serialized_engine)
@@ -463,31 +467,66 @@
     print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
     iters = params.get("iterations", 20)

-    # Compiling the bindings
-    bindings = engine.num_bindings * [None]
-    k = 0
-    for idx, _ in enumerate(bindings):
-        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
-        shape = tuple(engine.get_binding_shape(idx))
-        device = torch_device_from_trt(engine.get_location(idx))
-        if not engine.binding_is_input(idx):
-            # Output bindings
-            output = torch.empty(size=shape, dtype=dtype, device=device)
-            bindings[idx] = output.data_ptr()
-        else:
-            # Input bindings
-            bindings[idx] = input_tensors[k].data_ptr()
-            k += 1
+    # Get I/O tensor information using TensorRT 10 API
+    input_names = []
+    output_names = []
+    input_dtypes = []
+    output_dtypes = []
+    input_shapes = []
+    output_shapes = []
+
+    for i in range(engine.num_io_tensors):
+        tensor_name = engine.get_tensor_name(i)
+        tensor_mode = engine.get_tensor_mode(tensor_name)
+        tensor_dtype = engine.get_tensor_dtype(tensor_name)
+        tensor_shape = engine.get_tensor_shape(tensor_name)
+
+        if tensor_mode == trt.TensorIOMode.INPUT:
+            input_names.append(tensor_name)
+            input_dtypes.append(torch_dtype_from_trt(tensor_dtype))
+            input_shapes.append(tuple(tensor_shape))
+        else:  # trt.TensorIOMode.OUTPUT
+            output_names.append(tensor_name)
+            output_dtypes.append(torch_dtype_from_trt(tensor_dtype))
+            output_shapes.append(tuple(tensor_shape))
+
+    # Create output tensors
+    output_tensors = []
+    for i, (shape, dtype) in enumerate(zip(output_shapes, output_dtypes)):
+        output = torch.empty(size=shape, dtype=dtype, device="cuda")
+        output_tensors.append(output)

     timings = []
     with engine.create_execution_context() as context:
+        # Set input tensor addresses
+        for i, (input_name, input_tensor) in enumerate(zip(input_names, input_tensors)):
+            context.set_tensor_address(input_name, input_tensor.data_ptr())
+
+        # Set output tensor addresses
+        for output_name, output_tensor in zip(output_names, output_tensors):
+            context.set_tensor_address(output_name, output_tensor.data_ptr())
+
+        # Create a dedicated stream for TensorRT execution
+        dedicated_stream = torch.cuda.Stream()
+        current_stream = torch.cuda.current_stream()
+
+        # Warm up
         for i in range(WARMUP_ITER):
-            context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
+            # Wait for current stream to finish
+            dedicated_stream.wait_stream(current_stream)
+            context.execute_async_v3(dedicated_stream.cuda_stream)
+            # Wait for TensorRT stream to finish
+            current_stream.wait_stream(dedicated_stream)
         torch.cuda.synchronize()

+        # Performance measurement
         for i in range(iters):
             start_time = timeit.default_timer()
-            context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
+            # Wait for current stream to finish
+            dedicated_stream.wait_stream(current_stream)
+            context.execute_async_v3(dedicated_stream.cuda_stream)
+            # Wait for TensorRT stream to finish
+            current_stream.wait_stream(dedicated_stream)
             torch.cuda.synchronize()
             end_time = timeit.default_timer()
             meas_time = end_time - start_time
@@ -504,7 +543,6 @@ def run(
     params,
     precision,
     batch_size=1,
-    is_trt_engine=False,
     model_torch=None,
 ):
     for backend in backends:
@@ -551,7 +589,6 @@ def run(
                 input_tensors,
                 params,
                 precision,
-                is_trt_engine,
                 batch_size,
             )
             run_dynamo(model_torch, input_tensors, params, precision, batch_size)
@@ -569,7 +606,7 @@
             )
         elif backend == "tensorrt":
             run_tensorrt(
-                model_torch,
+                model,
                 input_tensors,
                 params,
                 precision,
@@ -643,6 +680,12 @@
         action="store_true",
         help="Truncate long and double weights in the network in Torch-TensorRT",
     )
+    arg_parser.add_argument(
+        "--optimization_level",
+        type=int,
+        default=5,
+        help="Builder optimization level for TensorRT",
+    )
     arg_parser.add_argument(
         "--is_trt_engine",
         action="store_true",
@@ -702,8 +745,13 @@

     # Load TorchScript model, if provided
     if os.path.exists(model_name):
-        print("Loading user provided torchscript model: ", model_name)
-        model = torch.jit.load(model_name).cuda().eval()
+        if params["is_trt_engine"]:
+            with open(model_name, "rb") as f:
+                model = f.read()
+            print("Loading user provided trt engine: ", model_name)
+        else:
+            print("Loading user provided torchscript model: ", model_name)
+            model = torch.jit.load(model_name).cuda().eval()

     # Load PyTorch Model, if provided
     if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
@@ -746,7 +794,6 @@
         params,
         precision,
         batch_size,
-        is_trt_engine,
         model_torch=model_torch,
     )

diff --git a/tools/perf/utils.py b/tools/perf/utils.py
index 5dae807892..0fd38e6447 100644
--- a/tools/perf/utils.py
+++ b/tools/perf/utils.py
@@ -176,6 +176,8 @@ def torch_dtype_from_trt(dtype):
         return torch.bool
     elif dtype == trt.int32:
         return torch.int32
+    elif dtype == trt.int64:
+        return torch.int64
     elif dtype == trt.float16:
         return torch.float16
     elif dtype == trt.float32:
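
Note on the execution-path change in run_tensorrt: the bindings list and execute_async_v2 are replaced with the TensorRT 10 name-based tensor API (set_tensor_address plus execute_async_v3 on a dedicated CUDA stream). Below is a minimal standalone sketch of that pattern, not part of the patch: it assumes an already-deserialized engine and CUDA-resident input tensors, the function name run_trt_once is illustrative, and the import path for torch_dtype_from_trt assumes the script runs from tools/perf/.

import tensorrt as trt
import torch

from utils import torch_dtype_from_trt  # helper in tools/perf/utils.py (handles trt.int64 after this patch)


def run_trt_once(engine, input_tensors):
    # Bind every I/O tensor by name (TensorRT 10 style) and run one inference.
    context = engine.create_execution_context()
    outputs = []
    next_input = 0
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            context.set_tensor_address(name, input_tensors[next_input].data_ptr())
            next_input += 1
        else:
            out = torch.empty(
                tuple(engine.get_tensor_shape(name)),
                dtype=torch_dtype_from_trt(engine.get_tensor_dtype(name)),
                device="cuda",
            )
            context.set_tensor_address(name, out.data_ptr())
            outputs.append(out)

    # Execute on a dedicated stream, synchronizing with the current stream on both sides,
    # mirroring how the patch wraps the warm-up and timed iterations.
    dedicated_stream = torch.cuda.Stream()
    current_stream = torch.cuda.current_stream()
    dedicated_stream.wait_stream(current_stream)
    context.execute_async_v3(dedicated_stream.cuda_stream)
    current_stream.wait_stream(dedicated_stream)
    torch.cuda.synchronize()
    return outputs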