diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index a1eb22eba6987..195a58432737b 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -40,7 +40,7 @@ void addTosaToLinalgPasses( // Note: Default to 'none' level unless otherwise specified. std::optional validationOptions = tosa::TosaValidationOptions{ - {"none"}, false, tosa::TosaLevelEnum::None}); + {"none"}, {"none"}, false, tosa::TosaLevelEnum::None}); /// Populates TOSA to linalg pipelines /// Currently, this includes only the "tosa-to-linalg-pipeline". diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt index cc8d5ed9b0044..0a855d701d7b8 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -12,3 +12,14 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen) set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td) mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa") add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen) + +set(LLVM_TARGET_DEFINITIONS TosaOpBase.td) +mlir_tablegen(TosaEnums.h.inc -gen-enum-decls) +mlir_tablegen(TosaEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRTosaEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaAvailability.h.inc -gen-avail-interface-decls) +mlir_tablegen(TosaAvailability.cpp.inc -gen-avail-interface-defs) +mlir_tablegen(TosaOpAvailabilityImpl.inc -gen-tosa-avail-impls) +add_public_tablegen_target(MLIRTosaAvailabilityIncGen) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h new file mode 100644 index 0000000000000..86fb4077b9207 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -0,0 +1,84 @@ +//===- TargetEnv.h - Tosa target environment utilities ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares utilities for Tosa target environment (implementation). +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TARGETENV_H +#define MLIR_DIALECT_TOSA_IR_TARGETENV_H + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallSet.h" + +namespace mlir { +namespace tosa { + +/// This class represents the capability enabled in the target implementation +/// such as profile, extension, and level. +class TargetEnv { +public: + TargetEnv() {} + explicit TargetEnv(const SmallVectorImpl &profiles, + const SmallVectorImpl &extensions) { + for (Profile prof : profiles) + enabledProfiles.insert(prof); + + for (Extension ext : extensions) + enabledExtensions.insert(ext); + } + + void addProfile(Profile p) { enabledProfiles.insert(p); } + void addExtension(Extension e) { enabledExtensions.insert(e); } + + // TODO implement the following utilities. + // Version getSpecVersion() const; + // TosaLevel getLevel() const; + + // Returns true if the given profile is allowed. + bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; } + + bool allowsAnyOf(ArrayRef profs) const { + const auto *chosen = llvm::find_if( + profs, [this](tosa::Profile prof) { return allows(prof); }); + return chosen != profs.end() ? true : false; + } + + bool allowsAllOf(ArrayRef profs) const { + bool is_allowed = true; + llvm::for_each(profs, + [&](tosa::Profile prof) { is_allowed &= allows(prof); }); + return is_allowed; + } + + // Returns true if the given extension is allowed. + bool allows(Extension ext) const { return enabledExtensions.count(ext) != 0; } + + bool allowsAnyOf(ArrayRef exts) const { + const auto *chosen = llvm::find_if( + exts, [this](tosa::Extension ext) { return allows(ext); }); + return chosen != exts.end() ? true : false; + } + + bool allowsAllOf(ArrayRef exts) const { + bool is_allowed = true; + llvm::for_each(exts, + [&](tosa::Extension ext) { is_allowed &= allows(ext); }); + return is_allowed; + } + +private: + llvm::SmallSet enabledProfiles; + llvm::SmallSet enabledExtensions; +}; + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TARGETENV_H diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h new file mode 100644 index 0000000000000..2617a902c3a0d --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h @@ -0,0 +1,403 @@ +// The profile-based compliance content below is auto-generated by the script +// `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git +profileComplianceMap = { + {"tosa.argmax", + {{{Profile::pro_int}, {{i8T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}}, + {"tosa.avg_pool2d", + {{{Profile::pro_int}, {{i8T, i32T, i8T}}}, + {{Profile::pro_fp}, + {{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.conv2d", + {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}}, + {{Profile::pro_fp}, + {{fp16T, fp16T, fp16T, fp16T, fp16T}, + {fp16T, fp16T, fp16T, fp32T, fp16T}, + {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {"tosa.conv3d", + {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}}, + {{Profile::pro_fp}, + {{fp16T, fp16T, fp16T, fp16T, fp16T}, + {fp16T, fp16T, fp16T, fp32T, fp16T}, + {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {"tosa.depthwise_conv2d", + {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}}, + {{Profile::pro_fp}, + {{fp16T, fp16T, fp16T, fp16T, fp16T}, + {fp16T, fp16T, fp16T, fp32T, fp16T}, + {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {"tosa.fully_connected", + {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T}}}, + {{Profile::pro_fp}, + {{fp16T, fp16T, fp16T, fp16T}, + {fp16T, fp16T, fp32T, fp32T}, + {fp32T, fp32T, fp32T, fp32T}}}}}, + {"tosa.matmul", + {{{Profile::pro_int}, {{i8T, i8T, i32T}}}, + {{Profile::pro_fp}, + {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.max_pool2d", + {{{Profile::pro_int}, {{i8T, i8T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.transpose_conv2d", + {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}}, + {{Profile::pro_fp}, + {{fp16T, fp16T, fp16T, fp16T, fp16T}, + {fp16T, fp16T, fp16T, fp32T, fp16T}, + {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {"tosa.clamp", + {{{Profile::pro_int}, {{i8T, i8T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.add", + {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.arithmetic_right_shift", + {{{Profile::pro_int}, + {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {"tosa.bitwise_and", + {{{Profile::pro_int}, + {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {"tosa.bitwise_or", + {{{Profile::pro_int}, + {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {"tosa.bitwise_xor", + {{{Profile::pro_int}, + {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {"tosa.intdiv", + {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}}}}, + {"tosa.logical_and", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}}, + {"tosa.logical_left_shift", + {{{Profile::pro_int, Profile::pro_fp}, + {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {"tosa.logical_right_shift", + {{{Profile::pro_int, Profile::pro_fp}, + {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {"tosa.logical_or", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}}, + {"tosa.logical_xor", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}}}, + {"tosa.maximum", + {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.minimum", + {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.mul", + {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}}, + {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.pow", + {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.sub", + {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}}, + {"tosa.abs", + {{{Profile::pro_int}, {{i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.bitwise_not", + {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}}, + {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}}, + {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.logical_not", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}}, + {"tosa.negate", + {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.reciprocal", + {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.select", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}}, + {{Profile::pro_int}, + {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.equal", + {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {"tosa.greater", + {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {"tosa.greater_equal", + {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {"tosa.reduce_all", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}}, + {"tosa.reduce_any", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}}, + {"tosa.reduce_max", + {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.reduce_min", + {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.reduce_product", + {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.reduce_sum", + {{{Profile::pro_int}, {{i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.concat", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}, + {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.pad", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}, + {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.reshape", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}, + {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.reverse", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}, + {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.slice", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}, + {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.tile", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}, + {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.transpose", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}, + {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.gather", + {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.scatter", + {{{Profile::pro_int}, + {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {"tosa.resize", + {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.cast", + {{{Profile::pro_int}, + {{boolT, i8T}, + {boolT, i16T}, + {boolT, i32T}, + {i8T, boolT}, + {i8T, i16T}, + {i8T, i32T}, + {i16T, boolT}, + {i16T, i8T}, + {i16T, i32T}, + {i32T, boolT}, + {i32T, i8T}, + {i32T, i16T}}}, + {{Profile::pro_fp}, + {{i8T, fp16T}, + {i8T, fp32T}, + {i16T, fp16T}, + {i16T, fp32T}, + {i32T, fp16T}, + {i32T, fp32T}, + {fp16T, i8T}, + {fp16T, i16T}, + {fp16T, i32T}, + {fp16T, fp32T}, + {fp32T, i8T}, + {fp32T, i16T}, + {fp32T, i32T}, + {fp32T, fp16T}}}}}, + {"tosa.rescale", + {{{Profile::pro_int}, + {{i8T, i8T}, + {i8T, i16T}, + {i8T, i32T}, + {i16T, i8T}, + {i16T, i16T}, + {i16T, i32T}, + {i32T, i8T}, + {i32T, i16T}, + {i32T, i32T}}}}}, + {"tosa.const", + {{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}}, + {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {"tosa.identity", + {{{Profile::pro_int}, + {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, + {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {"tosa.dim", + {{{Profile::pro_int, Profile::pro_fp}, {{boolT}}}, + {{Profile::pro_int}, {{i8T}, {i16T}, {i32T}}}, + {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, +}; + +extensionComplianceMap = { + {"tosa.argmax", + {{{Extension::int16}, {{i16T, i32T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}}, + {{Extension::bf16}, {{bf16T, i32T}}}}}, + {"tosa.avg_pool2d", + {{{Extension::int16}, {{i16T, i32T, i16T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}}, + {"tosa.conv2d", + {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}}, + {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}}, + {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {"tosa.conv3d", + {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}}, + {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}}, + {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {"tosa.depthwise_conv2d", + {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}}, + {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}}, + {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}}, + {"tosa.fully_connected", + {{{Extension::int4}, {{i8T, i4T, i32T, i32T}}}, + {{Extension::int16}, {{i16T, i8T, i48T, i48T}}}, + {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}}, + {"tosa.matmul", + {{{Extension::int16}, {{i16T, i16T, i48T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}}, + {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}}, + {"tosa.max_pool2d", + {{{Extension::int16}, {{i16T, i16T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}}, + {"tosa.transpose_conv2d", + {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}}, + {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T, fp16T, fp16T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T, fp16T, fp16T}}}, + {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {"tosa.clamp", + {{{Extension::int16}, {{i16T, i16T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}}, + {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, + {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, + {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, + {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.concat", + {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.pad", + {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.reshape", + {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.reverse", + {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.slice", + {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.tile", + {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.transpose", + {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.gather", + {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.scatter", + {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {"tosa.resize", + {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.cast", + {{{Extension::bf16}, + {{i8T, bf16T}, + {i16T, bf16T}, + {i32T, bf16T}, + {bf16T, i8T}, + {bf16T, i16T}, + {bf16T, i32T}, + {bf16T, fp32T}, + {fp32T, bf16T}}}, + {{Extension::bf16, Extension::fp8e4m3}, + {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}}}, + {{Extension::bf16, Extension::fp8e5m2}, + {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}}}, + {{Extension::fp8e4m3}, + {{fp8e4m3T, fp16T}, + {fp8e4m3T, fp32T}, + {fp16T, fp8e4m3T}, + {fp32T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, + {{fp8e5m2T, fp16T}, + {fp8e5m2T, fp32T}, + {fp16T, fp8e5m2T}, + {fp32T, fp8e5m2T}}}}}, + {"tosa.rescale", + {{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}}, + {"tosa.const", + {{{Extension::int4}, {{i4T}}}, + {{Extension::int16}, {{i48T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T}}}}}, + {"tosa.identity", + {{{Extension::int4}, {{i4T, i4T}}}, + {{Extension::int16}, {{i48T, i48T}}}, + {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {"tosa.dim", + {{{Extension::fp8e4m3}, {{fp8e4m3T}}}, + {{Extension::fp8e5m2}, {{fp8e5m2T}}}, + {{Extension::bf16}, {{bf16T}}}}}, +}; +// End of auto-generated metadata diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 862d98ad436a6..13bbba2b492fa 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -14,8 +14,15 @@ #define TOSA_OP_BASE include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" + +include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" + //===----------------------------------------------------------------------===// // The TOSA Dialect. //===----------------------------------------------------------------------===// @@ -200,6 +207,190 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder< input, paddings, pad_value); }]>; +// Wrapper over base I32EnumAttr to set common fields. +class Tosa_I32Enum cases> + : I32EnumAttr { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tosa"; +} + +class Tosa_I32EnumAttr cases> + : EnumAttr, mnemonic> { + let assemblyFormat = "`<` $value `>`"; +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 1.5. +// +// Profile: +// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values. +// FP : Floating-Point Inference. Primarily FP16 and FP32 operations. +// +// Extension: +// INT16 : 16-bit integer operations. +// INT4 : 4-bit integer weights. +// BF16 : BFloat16 operations. +// FP8 : 8-bit floating-point operations E4M3. +// FP8 : 8-bit floating-point operations E5M2. +// FFT : Fast Fourier Transform operations. +// VARIABLE : Stateful variable operations. +//===----------------------------------------------------------------------===// + +def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>; +def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>; +def Tosa_NONE : I32EnumAttrCase<"none", 3>; + +def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>; +def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>; +def Tosa_EXT_BF16 : I32EnumAttrCase<"bf16", 3>; +def Tosa_EXT_FP8E4M3 : I32EnumAttrCase<"fp8e4m3", 4>; +def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>; +def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>; +def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>; +def Tosa_EXT_NONE : I32EnumAttrCase<"none", 8>; + +def Tosa_ExtensionAttr + : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [ + Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3, + Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_NONE + ]>; + +def Tosa_ExtensionArrayAttr + : TypedArrayAttrBase; + +def Tosa_ProfileAttr + : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof", + [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>; + +def Tosa_ProfileArrayAttr + : TypedArrayAttrBase; + +// The base class for defining op availability dimensions. +class Availability { + // The following are fields for controlling the generated C++ OpInterface. + + // The namespace for the generated C++ OpInterface subclass. + string cppNamespace = "::mlir::tosa"; + + // The name for the generated C++ OpInterface subclass. + string interfaceName = ?; + + // The description for the generated C++ OpInterface subclass. + string interfaceDescription = ""; + + // The query function's return type in the generated C++ OpInterface subclass. + string queryFnRetType = ?; + + // The query function's name in the generated C++ OpInterface subclass. + string queryFnName = ?; + + // The logic for merging two availability requirements. + code mergeAction = ?; + + // The initializer for the final availability requirement. + string initializer = ?; + + // An availability instance's type. + string instanceType = ?; + + // The following are fields for a concrete availability instance. + + // The code for preparing a concrete instance. This should be C++ statements + // and will be generated before the `mergeAction` logic. + code instancePreparation = ""; + + // The availability requirement carried by a concrete instance. + string instance = ?; +} + + +class Profile profiles> : Availability { + let interfaceName = "QueryProfileInterface"; + let interfaceDescription = [{ + Querying interface for the supported set of Tosa profile. + + This interface provides a `getProfiles()` method to query + the supported set of Tosa profile. The returned value is a + list of `mlir::Tosa::Profile` enum number. + }]; + + let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<" + "::mlir::tosa::Profile>, 1>"; + let queryFnName = "getProfiles"; + + let mergeAction = !if( + !empty(profiles), "", "$overall.emplace_back($instance)"); + + let initializer = "{}"; + + let instanceType = "::llvm::ArrayRef<::mlir::tosa::Profile>"; + + // Pack all profiles as a static array and get its reference. + let instancePreparation = !if(!empty(profiles), "", + "static const ::mlir::tosa::Profile profs[] = {" # + !interleave(!foreach(prof, profiles, + "::mlir::tosa::Profile::" # prof.symbol), ", ") # + "}; " # + "ArrayRef<::mlir::tosa::Profile> " # + "ref(profs, std::size(profs));"); + + let instance = "ref"; +} + +class Extension extensions> : Availability { + let interfaceName = "QueryExtensionInterface"; + let interfaceDescription = [{ + Querying interface for the supported set of TOSA extension. + + This interface provides a `getExtensions()` method to query + the supported set of Tosa extension. The returned value is a + list of `mlir::Tosa::Extension` enum number. + }]; + + let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<" + "::mlir::tosa::Extension>, 1>"; + let queryFnName = "getExtensions"; + + let mergeAction = !if( + !empty(extensions), "", "$overall.emplace_back($instance)"); + + let initializer = "{}"; + + let instanceType = "::llvm::ArrayRef<::mlir::tosa::Extension>"; + + // Pack all extensions as a static array and get its reference. + let instancePreparation = !if(!empty(extensions), "", + "static const ::mlir::tosa::Extension exts[] = {" # + !interleave(!foreach(ext, extensions, + "::mlir::tosa::Extension::" # ext.symbol), ", ") # + "}; " # + "ArrayRef<::mlir::tosa::Extension> " # + "ref(exts, std::size(exts));"); + + let instance = "ref"; +} + +//===----------------------------------------------------------------------===// +// TOSA Interfaces. +//===----------------------------------------------------------------------===// + +def QueryProfileInterface : OpInterface<"QueryProfileInterface"> { + let cppNamespace = "::mlir::tosa"; + let methods = [InterfaceMethod< + "get supported profiles", + "::llvm::SmallVector<::llvm::ArrayRef<::mlir::tosa::Profile>, 1>", + "getProfiles">]; +} + +def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> { + let cppNamespace = "::mlir::tosa"; + let methods = [InterfaceMethod< + "get supported extensions", + "::llvm::SmallVector<::llvm::ArrayRef<::mlir::tosa::Extension>, 1>", + "getExtensions">]; +} + //===----------------------------------------------------------------------===// // TOSA Operator Trait. //===----------------------------------------------------------------------===// @@ -223,7 +414,17 @@ def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> { class Tosa_Op traits = []> : Op, + DeclareOpInterfaceMethods, TosaResolvableShapeOperands])> { + + // Default availability specification. + list availability = [ + Profile<[]>, + Extension<[]>]; + + // When not set, manual implementation of these methods is required. + bit autogenAvailability = 1; } class Tosa_ElementwiseOp traits = []> : diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index 069073bc2d164..358e5dabfeb62 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -29,9 +29,16 @@ // TOSA dialect and structs includes. //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tosa/IR/TosaEnums.h.inc" #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc" #include "mlir/Transforms/DialectConversion.h" +//===----------------------------------------------------------------------===// +// TOSA operation validation includes. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaAvailability.h.inc" + namespace mlir { class PatternRewriter; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 7cdf79f4dc59d..3de1c21f40b43 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -50,6 +50,11 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> { Tosa_Tensor: $output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; let hasVerifier = 1; } @@ -86,6 +91,11 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> { Tosa_Tensor4D:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; let hasVerifier = 1; } @@ -118,6 +128,11 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> { Tosa_Tensor4D:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let builders = [Tosa_ConvOpQuantInfoBuilder]; let hasVerifier = 1; } @@ -149,6 +164,11 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> { Tosa_Tensor5D:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let builders = [Tosa_ConvOpQuantInfoBuilder]; let hasVerifier = 1; } @@ -181,6 +201,11 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> { Tosa_Tensor4D:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let builders = [Tosa_ConvOpQuantInfoBuilder]; let hasVerifier = 1; } @@ -218,6 +243,11 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> { Tosa_Tensor3D:$output_imag ); + list availability = [ + Profile<[]>, + Extension<[Tosa_EXT_FFT]>, + ]; + let assemblyFormat = [{ $input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,` type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)` @@ -247,6 +277,11 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> { Tosa_Tensor3D:$c ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let builders = [Tosa_MatMulOpQuantInfoBuilder]; } @@ -276,6 +311,11 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> { Tosa_Tensor4D:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let hasCanonicalizer = 1; } @@ -310,6 +350,11 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> { Tosa_Tensor3D:$output_imag ); + list availability = [ + Profile<[]>, + Extension<[Tosa_EXT_FFT]>, + ]; + let assemblyFormat = [{ $input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)` }]; @@ -343,6 +388,11 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> { Tosa_Tensor4D:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let builders = [Tosa_TransConvOpQuantInfoBuilder]; let hasVerifier = 1; } @@ -377,6 +427,11 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT16, Tosa_EXT_BF16]>, + ]; + let hasCanonicalizer = 1; let hasVerifier = 1; } @@ -402,6 +457,11 @@ def Tosa_SigmoidOp : Tosa_ElementwiseUnaryOp<"sigmoid"> { let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -424,6 +484,11 @@ def Tosa_TanhOp : Tosa_ElementwiseUnaryOp<"tanh"> { let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -447,6 +512,11 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } @@ -488,6 +558,11 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [ Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; } @@ -512,6 +587,11 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -535,6 +615,11 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [ let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -558,6 +643,11 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [ let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -581,6 +671,11 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [ let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -603,6 +698,11 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementT Tosa_Int32Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[]>, + ]; + let hasFolder = 1; } @@ -627,6 +727,11 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [ let results = (outs Tosa_I1Tensor:$z ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -649,6 +754,11 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -671,6 +781,11 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -694,6 +809,11 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [ let results = (outs Tosa_I1Tensor:$z ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -717,6 +837,11 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [ let results = (outs Tosa_I1Tensor:$z ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -741,6 +866,11 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [ let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -765,6 +895,11 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [ let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } def MulOperandsAndResultElementType : @@ -799,6 +934,11 @@ def Tosa_MulOp : Tosa_Op<"mul", [ Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; let hasVerifier = 1; @@ -825,6 +965,11 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> { let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -847,6 +992,11 @@ def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; } @@ -882,6 +1032,11 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let assemblyFormat = [{ $input1 `,` $table attr-dict `:` `(` type($input1) `,` type($table) `)` `->` type($output) }]; @@ -919,6 +1074,11 @@ def Tosa_AbsOp : Tosa_ElementwiseUnaryOp<"abs"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; } @@ -939,6 +1099,11 @@ def Tosa_BitwiseNotOp : Tosa_ElementwiseUnaryOp<"bitwise_not"> { let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -958,6 +1123,11 @@ def Tosa_CeilOp : Tosa_ElementwiseUnaryOp<"ceil"> { let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -977,6 +1147,11 @@ def Tosa_ClzOp : Tosa_ElementwiseUnaryOp<"clz"> { let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -996,6 +1171,11 @@ def Tosa_CosOp : Tosa_ElementwiseUnaryOp<"cos"> { let results = (outs Tosa_FloatTensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -1016,6 +1196,11 @@ def Tosa_ExpOp : Tosa_ElementwiseUnaryOp<"exp"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; } @@ -1036,6 +1221,11 @@ def Tosa_FloorOp : Tosa_ElementwiseUnaryOp<"floor"> { let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -1056,6 +1246,11 @@ def Tosa_LogOp : Tosa_ElementwiseUnaryOp<"log"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; } @@ -1076,6 +1271,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> { let results = (outs Tosa_I1Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[]>, + ]; } //===----------------------------------------------------------------------===// @@ -1098,6 +1298,11 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let builders = [Tosa_UnaryOpQuantInfoBuilder]; let hasFolder = 1; @@ -1122,6 +1327,11 @@ def Tosa_ReciprocalOp : Tosa_ElementwiseUnaryOp<"reciprocal"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let extraClassDeclaration = [{ /// Return the reciprocal result on the operand. static inline APFloat calcOneElement(const APFloat &operand) { @@ -1152,6 +1362,11 @@ def Tosa_RsqrtOp : Tosa_ElementwiseUnaryOp<"rsqrt"> { let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -1171,6 +1386,11 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> { let results = (outs Tosa_FloatTensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -1198,6 +1418,12 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> { let results = (outs Tosa_Tensor:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasCanonicalizeMethod = 1; let hasFolder = 1; @@ -1234,6 +1460,11 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [ Tosa_I1Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let extraClassDeclaration = [{ /// Returns when two result types are compatible for this op; method used by /// InferTypeOpInterface. @@ -1262,6 +1493,11 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> { Tosa_I1Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; } @@ -1285,6 +1521,11 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", Tosa_I1Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; } @@ -1312,6 +1553,11 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[]>, + ]; + let hasFolder = 1; let hasVerifier = 1; @@ -1346,6 +1592,11 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[]>, + ]; + let hasFolder = 1; let hasVerifier = 1; @@ -1381,6 +1632,11 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; let hasVerifier = 1; @@ -1417,6 +1673,11 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; let hasVerifier = 1; @@ -1452,6 +1713,11 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; let hasVerifier = 1; @@ -1486,6 +1752,11 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; let hasVerifier = 1; @@ -1526,6 +1797,11 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let hasCanonicalizer = 1; let hasFolder = 1; @@ -1573,6 +1849,11 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> { Tosa_RankedTensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let builders = [Tosa_PadOpQuantInfoBuilder, Tosa_ExplicitValuePadOpQuantInfoBuilder]; @@ -1605,6 +1886,11 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> { Tosa_RankedTensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let extraClassDeclaration = [{ /// Returns true when two result types are compatible for this op; /// Method used by InferTypeOpInterface. @@ -1637,6 +1923,11 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [ Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; let hasVerifier = 1; @@ -1665,6 +1956,11 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; @@ -1688,6 +1984,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> { Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let extraClassDeclaration = [{ LogicalResult getConstantMultiples(llvm::SmallVector &multiples); }]; @@ -1717,6 +2018,11 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose", outs Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let extraClassDeclaration = [{ LogicalResult getConstantPerms(llvm::SmallVector &perms); }]; @@ -1750,6 +2056,11 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> { let results = (outs Tosa_Tensor3D:$output ); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -1772,6 +2083,11 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> { let results = (outs Tosa_Tensor3D:$values_out ); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; } //===----------------------------------------------------------------------===// @@ -1806,6 +2122,11 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> { Tosa_Tensor4D:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT16, Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; let hasVerifier = 1; } @@ -1857,6 +2178,11 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; let hasFolder = 1; @@ -1908,6 +2234,11 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure, Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT]>, + Extension<[Tosa_EXT_INT16]>, + ]; + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } @@ -1944,6 +2275,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure, TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let hasFolder = 1; let hasVerifier = 1; } @@ -1968,6 +2304,11 @@ def Tosa_IdentityOp: Tosa_Op<"identity", [Pure, Tosa_Tensor:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>, + ]; + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } @@ -2023,6 +2364,11 @@ def Tosa_CustomOp : Tosa_Op<"custom"> { Variadic:$output_list ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[]>, + ]; + let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } @@ -2057,6 +2403,11 @@ def Tosa_IfOp : Tosa_Op<"cond_if", Variadic:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[]>, + ]; + let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch @@ -2093,6 +2444,11 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [ Variadic:$output ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[]>, + ]; + let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h new file mode 100644 index 0000000000000..a831bae12f3c1 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -0,0 +1,163 @@ +//===- TosaProfileCompliance.h - Tosa Profile-based Compliance Validation -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H +#define MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H + +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Support/TypeID.h" + +using namespace mlir; +using namespace mlir::tosa; + +//===----------------------------------------------------------------------===// +// Type Compilance Definition +//===----------------------------------------------------------------------===// + +typedef struct { + mlir::TypeID typeID; + uint32_t bitWidth; +} TypeInfo; + +enum CheckCondition { + // Valid when any of the profile (extension) requirement is meet. + anyOf, + // Valid when all of the profile (extension) requirement are meet. + allOf, + invalid +}; + +template +struct OpComplianceInfo { + // Certain operations require multiple modes enabled. + // e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3. + SmallVector mode; + SmallVector> operandTypeInfoSet; + CheckCondition condition = CheckCondition::anyOf; +}; + +using OperationProfileComplianceMap = + std::unordered_map>>; +using OperationExtensionComplianceMap = + std::unordered_map>>; + +//===----------------------------------------------------------------------===// +// Tosa Profile And Extension Information Depot +//===----------------------------------------------------------------------===// + +class ProfileInfoDepot { +public: + ProfileInfoDepot(Operation *op) { + if (failed(populatationDispatch(op))) + op->emitOpError() << "fail to populate the profile info\n"; + } + + void addType(Type t) { tyInfo.push_back(convertTypeToInfo(t)); } + void addValue(Value v) { tyInfo.push_back(convertValueToInfo(v)); } + SmallVector getInfo() { return tyInfo; } + +private: + TypeInfo convertTypeToInfo(Type type) { + return {type.getTypeID(), type.getIntOrFloatBitWidth()}; + } + + TypeInfo convertValueToInfo(Value value) { + return convertTypeToInfo(getElementTypeOrSelf(value.getType())); + } + + LogicalResult populatationDispatch(Operation *op); + + void populateProfileInfo(ValueRange operands, Value output); + + // Base + template + void populateProfileInfo(T op) { + op->emitOpError() << "profile requirement for this op has not been defined"; + } + // For conv2d, conv3d, transpose_conv2d, and depthwise_conv2d. + template + void populateProfileInfoConv(T op); + + // For pad, reshape, slice, tile, and transpose. + template + void populateProfileInfoDataLayout(T op); + +private: + SmallVector tyInfo; +}; + +//===----------------------------------------------------------------------===// +// Tosa Profile And Extension Compliance Checker +//===----------------------------------------------------------------------===// + +class TosaProfileCompliance { +public: + explicit TosaProfileCompliance(); + + // Accessor of the compliance info map. + template + std::unordered_map>> + getProfileComplianceMap() { + // Only profile and extension compliance info are provided. + return {}; + } + + // Verify if the operation is allowed to be executed in the given target + // environment. + LogicalResult checkProfile(Operation *op, const tosa::TargetEnv &targetEnv); + LogicalResult checkExtension(Operation *op, const tosa::TargetEnv &targetEnv); + + template + LogicalResult checkProfileOrExtension( + Operation *op, const tosa::TargetEnv &targetEnv, + const SmallVector> &specDefinedProfileSet); + + bool isSameTypeInfo(TypeInfo a, TypeInfo b) { + return a.typeID == b.typeID && a.bitWidth == b.bitWidth; + } + + // Find the required profiles or extensions from the compliance info according + // to the operand type combination. + template + SmallVector findMatchedProfile(Operation *op, + SmallVector> compInfo, + CheckCondition &condition); + + SmallVector getCooperativeProfiles(Extension ext) { + switch (ext) { + case Extension::int16: + case Extension::int4: + return {Profile::pro_int}; + case Extension::bf16: + case Extension::fp8e4m3: + case Extension::fp8e5m2: + case Extension::fft: + return {Profile::pro_fp}; + case Extension::variable: + return {Profile::pro_fp, Profile::pro_int}; + case Extension::none: + return {}; + }; + } + + // Debug utilites. + template + SmallVector stringifyProfile(ArrayRef profiles); + + template + SmallVector + stringifyProfile(const SmallVector> &profileSet); + +private: + OperationProfileComplianceMap profileComplianceMap; + OperationExtensionComplianceMap extensionComplianceMap; +}; + +#endif // MLIR_DIALECT_TOSA_TRANSFORMS_TOSAPROFILECOMPILANCE_H diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td index 597dc32e84402..82cfe01865853 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td @@ -30,6 +30,10 @@ def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> { class Tosa_ShapeOp traits = []> : Tosa_Op { + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[]>, + ]; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td index f9f25da1b649d..8756cb9e5de3a 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td @@ -96,6 +96,11 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> { OptionalAttr:$initial_value ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_VARIABLE]>, + ]; + let assemblyFormat = [{ $name attr-dict @@ -118,6 +123,11 @@ def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> { AnyType:$value ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_VARIABLE]>, + ]; + let assemblyFormat = [{ $name attr-dict `,` $value `:` type($value) }]; @@ -141,6 +151,11 @@ def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> { AnyType:$value ); + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_VARIABLE]>, + ]; + let assemblyFormat = [{ $name attr-dict `:` type($value) }]; diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 565970367e5dc..33bbc069c521d 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc" #include "mlir/Pass/Pass.h" @@ -48,28 +49,6 @@ std::unique_ptr createTosaMakeBroadcastablePass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); std::unique_ptr createTosaOptionalDecompositions(); -struct ValidationOptions { - /// Validate if operations match for the given profile. - TosaProfileEnum profile = TosaProfileEnum::Undefined; - ValidationOptions &setProfile(TosaProfileEnum profile) { - this->profile = profile; - return *this; - } - /// Verify if the properties of certain operations align the spec requirement. - bool strictOperationSpecAlignment = false; - ValidationOptions &enableStrictOperationSpecAlignment(bool enable = true) { - strictOperationSpecAlignment = enable; - return *this; - } - /// Validate if operator parameters are within specfication for the given - /// level. - TosaLevelEnum level = TosaLevelEnum::EightK; - ValidationOptions &setLevel(TosaLevelEnum level) { - this->level = level; - return *this; - } -}; - #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index dac67633769c7..f6ead2b6ba3dd 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -71,16 +71,6 @@ def TosaOptionalDecompositions let constructor = "tosa::createTosaOptionalDecompositions()"; } -def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile", - [ - I32EnumAttrCase<"BaseInference", 0, "bi">, - I32EnumAttrCase<"MainInference", 1, "mi">, - I32EnumAttrCase<"MainTraining", 2, "mt">, - I32EnumAttrCase<"Undefined", 3, "none"> - ]>{ - let cppNamespace = "mlir::tosa"; -} - def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level", [ I32EnumAttrCase<"None", 0, "none">, @@ -99,7 +89,9 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { let options = [ ListOption<"profile", "profile", "std::string", "Validate if operations match for the given profile set">, - Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool", + ListOption<"extension", "extension", "std::string", + "Validate if operations match for the given extension set">, + Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool", /*default=*/"false", "Verify if the properties of certain operations align the spec requirement">, Option<"level", "level", "mlir::tosa::TosaLevelEnum", diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index 8dfa55bef74fc..bfadebba12708 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -117,7 +117,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() { TosaToLinalgNamedOptions tosaToLinalgNamedOptions; TosaValidationOptions validationOptions; validationOptions.profile = {"none"}; - validationOptions.StrictOperationSpecAlignment = true; + validationOptions.extension = {"none"}; + validationOptions.strictOpSpecAlignment = false; validationOptions.level = tosa::TosaLevelEnum::EightK; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, tosaToLinalgNamedOptions, diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index e6999f6fa0d85..b1fac8c85a204 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -12,6 +12,8 @@ add_mlir_dialect_library(MLIRTosaDialect MLIRTosaDialectBytecodeIncGen MLIRTosaOpsIncGen MLIRTosaInterfacesIncGen + MLIRTosaEnumsIncGen + MLIRTosaAvailabilityIncGen MLIRShardingInterfaceIncGen LINK_LIBS PUBLIC diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index d21e218308df7..e9c33e1b1bf10 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -42,7 +42,10 @@ using namespace mlir::tosa; // Tosa dialect interface includes. //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc" +#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc" #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" +#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc" namespace { #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc" diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 9c3345b617cc5..bbf079faea3d0 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaOptionalDecompositions.cpp TosaReduceTransposes.cpp TosaTypeConverters.cpp + TosaProfileCompliance.cpp TosaValidation.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp new file mode 100644 index 0000000000000..26960f0be6637 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -0,0 +1,476 @@ +//===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" +#include "llvm/ADT/StringExtras.h" + +using namespace mlir; +using namespace mlir::tosa; + +TosaProfileCompliance::TosaProfileCompliance() { + const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1}; + const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4}; + const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8}; + const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16}; + const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32}; + const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48}; + const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16}; + const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16}; + const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32}; + const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8}; + const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8}; + +// The profile-based compliance content below is auto-generated by a script +// in https://git.mlplatform.org/tosa/specification.git +#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h" + // End of auto-generated metadata +} + +template <> +OperationProfileComplianceMap TosaProfileCompliance::getProfileComplianceMap() { + return profileComplianceMap; +} + +template <> +OperationExtensionComplianceMap +TosaProfileCompliance::getProfileComplianceMap() { + return extensionComplianceMap; +} + +// Base populating function +void ProfileInfoDepot::populateProfileInfo(ValueRange operands, Value output) { + for (auto operand : operands) + addValue(operand); + addValue(output); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) { + addValue(op.getInput1().front()); + addValue(op.getOutput()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) { + addValue(op.getInput()); + addType(op.getAccType()); + addValue(op.getOutput()); +} + +template +void ProfileInfoDepot::populateProfileInfoConv(T op) { + addValue(op.getInput()); + addValue(op.getWeight()); + addValue(op.getBias()); + addType(op.getAccType()); + addValue(op.getOutput()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) { + populateProfileInfoConv(op); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) { + populateProfileInfoConv(op); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) { + populateProfileInfoConv(op); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) { + populateProfileInfoConv(op); +} + +template +void ProfileInfoDepot::populateProfileInfoDataLayout(T op) { + addValue(op.getInput1()); + addValue(op.getOutput()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) { + populateProfileInfoDataLayout(op); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) { + populateProfileInfoDataLayout(op); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) { + populateProfileInfoDataLayout(op); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) { + populateProfileInfoDataLayout(op); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) { + populateProfileInfoDataLayout(op); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) { + addValue(op.getValues()); + addValue(op.getOutput()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) { + addValue(op.getValuesIn()); + addValue(op.getInput()); + addValue(op.getValuesOut()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) { + addValue(op.getInput1()); + addValue(op.getInput2()); + addValue(op.getOutput()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) { + addValue(op.getInput()); + addValue(op.getOutput()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) { + addValue(op.getInputReal()); + addValue(op.getInputImag()); + addValue(op.getOutputReal()); + addValue(op.getOutputImag()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) { + addValue(op.getInput()); + addValue(op.getOutputReal()); + addValue(op.getOutputImag()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) { + addValue(op.getInput2()); + addValue(op.getInput3()); + addValue(op.getOutput()); +} + +template <> +void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) { + addValue(op.getInput()); + addValue(op.getOutput()); +} + +LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { +// This helper function only populates the info for the customised operands. +#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \ + if (isa(op)) { \ + populateProfileInfo(cast(op)); \ + return success(); \ + } + +#define POPULATE_PROFILE_INFO_SKIP(tosaOp) \ + if (isa(op)) \ + return success(); + +// This helper function populates the info for all operands. +#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \ + if (isa(op)) { \ + populateProfileInfo(op->getOperands(), op->getResult(0)); \ + return success(); \ + } + + // Skip irrelevant operands when they are independent and not tied to any + // specific profile/extension. + POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d) + POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D) + POPULATE_PROFILE_INFO_CUSTOM(Conv2D) + POPULATE_PROFILE_INFO_CUSTOM(Conv3D) + POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D) + POPULATE_PROFILE_INFO_CUSTOM(Mul) + POPULATE_PROFILE_INFO_CUSTOM(FFT2d) + POPULATE_PROFILE_INFO_CUSTOM(RFFT2d) + POPULATE_PROFILE_INFO_CUSTOM(Concat) + POPULATE_PROFILE_INFO_CUSTOM(Pad) + POPULATE_PROFILE_INFO_CUSTOM(Reshape) + POPULATE_PROFILE_INFO_CUSTOM(Slice) + POPULATE_PROFILE_INFO_CUSTOM(Tile) + POPULATE_PROFILE_INFO_CUSTOM(Transpose) + POPULATE_PROFILE_INFO_CUSTOM(Gather) + POPULATE_PROFILE_INFO_CUSTOM(Scatter) + POPULATE_PROFILE_INFO_CUSTOM(Resize) + POPULATE_PROFILE_INFO_CUSTOM(Select) + POPULATE_PROFILE_INFO_CUSTOM(Rescale) + + // Type Invariant Extension, a capability extension that is independent + // of the data type, meaning any compatible type can be used. No type + // constraint for those operations. + POPULATE_PROFILE_INFO_SKIP(ConstShape) + POPULATE_PROFILE_INFO_SKIP(Variable) + POPULATE_PROFILE_INFO_SKIP(VariableRead) + POPULATE_PROFILE_INFO_SKIP(VariableWrite) + POPULATE_PROFILE_INFO_SKIP(If) + POPULATE_PROFILE_INFO_SKIP(While) + POPULATE_PROFILE_INFO_SKIP(Yield) + + // For the most of tosa operators, all operands are profile/extension related + // and hence are all considered in this profile-based compilance check. + POPULATE_PROFILE_INFO_COMMON(Cast) + POPULATE_PROFILE_INFO_COMMON(Const) + POPULATE_PROFILE_INFO_COMMON(ArgMax) + POPULATE_PROFILE_INFO_COMMON(MatMul) + POPULATE_PROFILE_INFO_COMMON(Sub) + POPULATE_PROFILE_INFO_COMMON(Maximum) + POPULATE_PROFILE_INFO_COMMON(Minimum) + POPULATE_PROFILE_INFO_COMMON(MaxPool2d) + POPULATE_PROFILE_INFO_COMMON(Clamp) + POPULATE_PROFILE_INFO_COMMON(Erf) + POPULATE_PROFILE_INFO_COMMON(Sigmoid) + POPULATE_PROFILE_INFO_COMMON(Tanh) + POPULATE_PROFILE_INFO_COMMON(Add) + POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift) + POPULATE_PROFILE_INFO_COMMON(BitwiseAnd) + POPULATE_PROFILE_INFO_COMMON(BitwiseNot) + POPULATE_PROFILE_INFO_COMMON(BitwiseOr) + POPULATE_PROFILE_INFO_COMMON(BitwiseXor) + POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift) + POPULATE_PROFILE_INFO_COMMON(LogicalRightShift) + POPULATE_PROFILE_INFO_COMMON(LogicalAnd) + POPULATE_PROFILE_INFO_COMMON(LogicalNot) + POPULATE_PROFILE_INFO_COMMON(LogicalOr) + POPULATE_PROFILE_INFO_COMMON(LogicalXor) + POPULATE_PROFILE_INFO_COMMON(IntDiv) + POPULATE_PROFILE_INFO_COMMON(Pow) + POPULATE_PROFILE_INFO_COMMON(Table) + POPULATE_PROFILE_INFO_COMMON(Abs) + POPULATE_PROFILE_INFO_COMMON(Ceil) + POPULATE_PROFILE_INFO_COMMON(Clz) + POPULATE_PROFILE_INFO_COMMON(Sin) + POPULATE_PROFILE_INFO_COMMON(Cos) + POPULATE_PROFILE_INFO_COMMON(Exp) + POPULATE_PROFILE_INFO_COMMON(Floor) + POPULATE_PROFILE_INFO_COMMON(Log) + POPULATE_PROFILE_INFO_COMMON(Negate) + POPULATE_PROFILE_INFO_COMMON(Reciprocal) + POPULATE_PROFILE_INFO_COMMON(Rsqrt) + POPULATE_PROFILE_INFO_COMMON(ReduceAll) + POPULATE_PROFILE_INFO_COMMON(ReduceAny) + POPULATE_PROFILE_INFO_COMMON(ReduceMax) + POPULATE_PROFILE_INFO_COMMON(ReduceMin) + POPULATE_PROFILE_INFO_COMMON(ReduceProd) + POPULATE_PROFILE_INFO_COMMON(ReduceSum) + POPULATE_PROFILE_INFO_COMMON(Equal) + POPULATE_PROFILE_INFO_COMMON(GreaterEqual) + POPULATE_PROFILE_INFO_COMMON(Greater) + POPULATE_PROFILE_INFO_COMMON(Reverse) + POPULATE_PROFILE_INFO_COMMON(Identity) + + return failure(); +} + +//===----------------------------------------------------------------------===// +// Tosa Profile And Extension Compliance Checker +//===----------------------------------------------------------------------===// + +template +LogicalResult TosaProfileCompliance::checkProfileOrExtension( + Operation *op, const tosa::TargetEnv &targetEnv, + const SmallVector> &specRequiredModeSet) { + + // None of profile requirement is set in the specification. + if (specRequiredModeSet.size() == 0) + return success(); + + auto opName = op->getName().getStringRef().str(); + auto compMap = getProfileComplianceMap(); + auto it = compMap.find(opName); + + if (it == compMap.end()) { + // Operators such as variable and shape ops do not have an operand type + // restriction. When the profile compliance information of operation is not + // found, confirm if the target have enabled the profile required from the + // specification. + int mode_count = 0; + for (const auto &cands : specRequiredModeSet) { + if (targetEnv.allowsAnyOf(cands)) + return success(); + mode_count += cands.size(); + } + + op->emitOpError() << "illegal: requires" + << (mode_count > 1 ? " any of " : " ") << "[" + << llvm::join(stringifyProfile(specRequiredModeSet), + ", ") + << "] but not enabled in target\n"; + + return failure(); + } + + CheckCondition condition = CheckCondition::invalid; + // Find the profiles or extensions requirement according to the signature of + // type of the operand list. + SmallVector opRequiredMode = + findMatchedProfile(op, it->second, condition); + + if (opRequiredMode.size() == 0) { + // No matched restriction found. + return success(); + } + + if (condition == CheckCondition::allOf && + !targetEnv.allowsAllOf(opRequiredMode)) { + op->emitOpError() << "illegal: requires" + << (opRequiredMode.size() > 1 ? " all of " : " ") << "[" + << llvm::join(stringifyProfile(opRequiredMode), ", ") + << "] but not enabled in target\n"; + return failure(); + } + + if (condition == CheckCondition::anyOf && + !targetEnv.allowsAnyOf(opRequiredMode)) { + op->emitOpError() << "illegal: requires" + << (opRequiredMode.size() > 1 ? " any of " : " ") << "[" + << llvm::join(stringifyProfile(opRequiredMode), ", ") + << "] but not enabled in target\n"; + return failure(); + } + + // Each extension can contain a list of profiles that it works with, usually + // have the same data type. + if constexpr (std::is_same_v) { + for (const auto &mode : opRequiredMode) { + SmallVector coProfs = getCooperativeProfiles(mode); + if (!targetEnv.allowsAnyOf(coProfs)) { + op->emitOpError() << "illegal: requires [" + << llvm::join(stringifyProfile(coProfs), + ", ") + << "] to work with but not enabled in target\n"; + return failure(); + } + } + } + + // Ensure the profile inference match the profile knowledge of the + // specification. + for (const auto &cands : specRequiredModeSet) { + for (size_t i = 0; i < opRequiredMode.size(); i++) { + if (std::find(cands.begin(), cands.end(), opRequiredMode[i]) == + cands.end()) { + op->emitOpError() << "illegal: requires [" + << llvm::join(stringifyProfile(opRequiredMode), + ", ") + << "] but not included in the profile compliance [" + << llvm::join( + stringifyProfile(specRequiredModeSet), ", ") + << "]\n"; + return failure(); + } + } + } + + return success(); +} + +LogicalResult +TosaProfileCompliance::checkProfile(Operation *op, + const tosa::TargetEnv &targetEnv) { + if (auto interface = dyn_cast(op)) + return checkProfileOrExtension(op, targetEnv, + interface.getProfiles()); + + return success(); +} + +LogicalResult +TosaProfileCompliance::checkExtension(Operation *op, + const tosa::TargetEnv &targetEnv) { + if (auto interface = dyn_cast(op)) + return checkProfileOrExtension(op, targetEnv, + interface.getExtensions()); + + return success(); +} + +// Find the profiles or extensions requirement according to the signature of +// type of the operand list. +template +SmallVector TosaProfileCompliance::findMatchedProfile( + Operation *op, SmallVector> compInfo, + CheckCondition &condition) { + assert(compInfo.size() != 0); + + // Populate the type of profile/extension relevant operands. + ProfileInfoDepot depot(op); + SmallVector present = depot.getInfo(); + if (present.size() == 0) + return {}; + + for (size_t i = 0; i < compInfo.size(); i++) { + SmallVector> sets = compInfo[i].operandTypeInfoSet; + + for (SmallVector expected : sets) { + assert(present.size() == expected.size()); + + bool is_found = true; + // Compare the type signature between the given operation and the + // compliance metadata. + for (size_t j = 0; j < expected.size(); j++) { + if (!isSameTypeInfo(present[j], expected[j])) { + // Verify the next mode set from the list. + is_found = false; + break; + } + } + + if (is_found == true) { + condition = compInfo[i].condition; + return compInfo[i].mode; + } + } + } + + return {}; +} + +// Debug utilites. +template +SmallVector +TosaProfileCompliance::stringifyProfile(ArrayRef profiles) { + SmallVector debugStrings; + for (const auto &profile : profiles) { + if constexpr (std::is_same_v) + debugStrings.push_back(tosa::stringifyProfile(profile)); + else + debugStrings.push_back(tosa::stringifyExtension(profile)); + } + return debugStrings; +} + +template +SmallVector TosaProfileCompliance::stringifyProfile( + const SmallVector> &profileSet) { + SmallVector debugStrings; + + for (const auto &profiles : profileSet) { + auto tempStrings = stringifyProfile(profiles); + debugStrings.insert(debugStrings.end(), tempStrings.begin(), + tempStrings.end()); + } + + return debugStrings; +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index f4abe628d37d1..f74a4b4c58b80 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -11,6 +11,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" @@ -25,6 +27,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringExtras.h" namespace mlir { namespace tosa { @@ -86,10 +89,12 @@ static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0}; struct TosaValidation : public tosa::impl::TosaValidationBase { public: explicit TosaValidation() { populateConstantOperandChecks(); } + explicit TosaValidation(const TosaValidationOptions &options) : TosaValidation() { this->profile = options.profile; - this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment; + this->extension = options.extension; + this->strictOpSpecAlignment = options.strictOpSpecAlignment; this->level = options.level; } void runOnOperation() final; @@ -401,9 +406,27 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { if (!profile.empty()) { for (std::string &prof : profile) { - auto profSymbol = symbolizeTosaProfileEnum(prof); + auto profSymbol = symbolizeProfile(prof); if (profSymbol) { - enabled_profiles.push_back(profSymbol.value()); + targetEnv.addProfile(profSymbol.value()); + } else { + llvm::errs() << "unknown TOSA profile name passed in: " << prof + << ", supported profiles are `pro_int` and `pro_fp`\n"; + return signalPassFailure(); + } + } + } + + if (!extension.empty()) { + for (std::string &ext : extension) { + auto extSymbol = symbolizeExtension(ext); + if (extSymbol) { + targetEnv.addExtension(extSymbol.value()); + } else { + llvm::errs() << "unknown TOSA extension name passed in: " << ext + << ", supported extension are int16, int4, bf16, " + << "fp8e4m3, fp8e5m2, fft, and variable\n"; + return signalPassFailure(); } } } @@ -411,17 +434,13 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { bool CheckVariable(Operation *op); bool CheckVariableReadOrWrite(Operation *op); - bool isValidElementType(Type type); - bool isEnabledProfile(TosaProfileEnum prof) { - return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) != - std::end(enabled_profiles); - } SmallVector> constCheckers; - SmallVector enabled_profiles; TosaLevel tosaLevel; DenseMap variablesMap; + TosaProfileCompliance profileComp; + tosa::TargetEnv targetEnv; }; LogicalResult TosaValidation::applyLevelCheck(Operation *op) { @@ -677,8 +696,6 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { bool TosaValidation::isValidElementType(Type type) { if (isa(type)) { - if (!isEnabledProfile(TosaProfileEnum::MainInference)) - return false; return type.isF32() || type.isF16() || type.isBF16(); } else if (auto intTy = dyn_cast(type)) { if (intTy.isSignless()) { @@ -709,6 +726,15 @@ void TosaValidation::runOnOperation() { if (op->getDialect() != tosaDialect) return; + // Profile-Extension based validation should be performed at the beginning. + if (strictOpSpecAlignment && + failed(profileComp.checkProfile(op, targetEnv))) + return signalPassFailure(); + + if (strictOpSpecAlignment && + failed(profileComp.checkExtension(op, targetEnv))) + return signalPassFailure(); + for (Value operand : op->getOperands()) { auto elementTy = getElementTypeOrSelf(operand); if (!isValidElementType(elementTy)) { @@ -728,7 +754,7 @@ void TosaValidation::runOnOperation() { // Some uses of TOSA rely on the constant operands of particular // operations. - if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op))) + if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op))) signalPassFailure(); // do level checks @@ -740,7 +766,7 @@ void TosaValidation::runOnOperation() { signalPassFailure(); // do error if checks - if (StrictOperationSpecAlignment && failed(applyErrorIfCheck(op))) + if (strictOpSpecAlignment && failed(applyErrorIfCheck(op))) signalPassFailure(); }); } diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir new file mode 100644 index 0000000000000..e66ff4cacfd89 --- /dev/null +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -0,0 +1,684 @@ +//-------------------------------------------------------------------------------------------------- +// Test whether the supported profile and extension are attached to the operation properly. +// The data type of arguments of operation are irrelevant in this test. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt -mlir-disable-threading -test-tosa-op-availability %s | FileCheck %s + +// ----- +// CHECK-LABEL: argmax +func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<14x19xf32>) -> tensor<14xi32> + return %0 : tensor<14xi32> +} + +// ----- +// CHECK-LABEL: avg_pool2d +func.func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + return %0 : tensor<1x7x7x9xf32> +} + +// ----- +// CHECK-LABEL: conv2d +func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + return %0 : tensor<1x4x4x8xf32> +} + +// ----- +// CHECK-LABEL: conv3d +func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32> + return %0 : tensor<1x4x8x21x34xf32> +} + +// ----- +// CHECK-LABEL: depthwise_conv2d +func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + return %0 : tensor<1x4x4x8xf32> +} + +// ----- +// CHECK-LABEL: fft2d +func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) { + // CHECK: profiles: [ ] + // CHECK: extensions: [ [fft] ] + %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) + return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32> +} + +// ----- +// CHECK-LABEL: matmul +func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} + +// ----- +// CHECK-LABEL: max_pool2d_f32 +func.func @test_max_pool2d_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.max_pool2d %arg0 {kernel = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- +// CHECK-LABEL: rfft2d +func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) { + // CHECK: profiles: [ ] + // CHECK: extensions: [ [fft] ] + %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) + return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32> +} + +// ----- +// CHECK-LABEL: transpose_conv2d +func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- +// CHECK-LABEL: clamp +func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int16, bf16] ] + %0 = tosa.clamp %arg0 {min_val = -3.40282347E+38 : f32, max_val = 3.40282347E+38 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: sigmoid +func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: tanh +func.func @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.tanh %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} +// ----- +// CHECK-LABEL: erf +func.func @test_erf(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.erf %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: add +func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: arithmetic_right_shift +func.func @test_arithmetic_right_shift(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: bitwise_and +func.func @test_bitwise_and(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.bitwise_and %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: bitwise_or +func.func @test_bitwise_or(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.bitwise_or %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: bitwise_xor +func.func @test_bitwise_xor(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.bitwise_xor %arg0, %arg1 : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: int_div +func.func @test_int_div(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ ] + %0 = tosa.int_div %arg0, %arg1 : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: logical_and +func.func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.logical_and %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: logical_left_shift +func.func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.logical_left_shift %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: logical_right_shift +func.func @test_logical_right_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.logical_right_shift %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: logical_or +func.func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.logical_or %arg0, %arg1 : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: logical_xor +func.func @test_logical_xor(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.logical_xor %arg0, %arg1 : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: maximum +func.func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.maximum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: minimum +func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.minimum %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: mul +func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: pow +func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.pow %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: sub +func.func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: table +func.func @test_table(%arg0: tensor<64xi32>, %arg1: tensor<513x!quant.uniform>) -> tensor<64x!quant.uniform> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.table %arg0, %arg1 : (tensor<64xi32>, tensor<513x!quant.uniform>) -> tensor<64x!quant.uniform> + return %0 : tensor<64x!quant.uniform> +} + +// ----- +// CHECK-LABEL: abs +func.func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.abs %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: bitwise_not +func.func @test_bitwise_not(%arg0: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.bitwise_not %arg0 : (tensor<13x21x1xi32>) -> tensor<13x21x1xi32> + return %0 : tensor<13x21x1xi32> +} + +// ----- +// CHECK-LABEL: ceil +func.func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.ceil %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: clz +func.func @test_clz(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ ] + %0 = tosa.clz %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: cos +func.func @test_cos(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.cos %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: exp +func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.exp %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: floor +func.func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.floor %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: log +func.func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.log %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: logical_not +func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ ] + %0 = tosa.logical_not %arg0 : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1> + return %0 : tensor<1x21x3xi1> +} + +// ----- +// CHECK-LABEL: negate +func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: reciprocal +func.func @test_reciprocal(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.reciprocal %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: rsqrt +func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.rsqrt %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: sin +func.func @test_sin(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.sin %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: select +func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: equal +func.func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.equal %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: greater +func.func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.greater %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: greater_equal +func.func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.greater_equal %arg0, %arg1 : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: reduce_all +func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ ] + %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1> + return %0 : tensor<1x21x3xi1> +} + +// ----- +// CHECK-LABEL: reduce_any +func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ ] + %0 = tosa.reduce_any %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1> + return %0 : tensor<1x21x3xi1> +} + +// ----- +// CHECK-LABEL: reduce_max +func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + return %0 : tensor<1x21x3xf32> +} + +// ----- +// CHECK-LABEL: reduce_min +func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + return %0 : tensor<1x21x3xf32> +} + +// ----- +// CHECK-LABEL: reduce_product +func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> { + // CHECK: profiles: [ [pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.reduce_prod %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + return %0 : tensor<1x21x3xf32> +} + +// ----- +// CHECK-LABEL: reduce_sum +func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<1x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + return %0 : tensor<1x21x3xf32> +} + +// ----- +// CHECK-LABEL: concat +func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32> + return %0 : tensor<26x21x3xf32> +} + +// ----- +// CHECK-LABEL: pad +func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %padding = tosa.const_shape {value = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6> + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.pad %arg0, %padding : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: reshape +func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> { + %1 = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2> + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32> + return %0 : tensor<1x819xf32> +} + +// ----- +// CHECK-LABEL: reverse +func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: slice +func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { + %0 = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %1 = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] + %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32> + return %2 : tensor<4x11x1xf32> +} + +// ----- +// CHECK-LABEL: tile +func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> { + %cst = tosa.const_shape { value = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3> + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32> + return %0 : tensor<39x21x6xf32> +} + +// ----- +// CHECK-LABEL: transpose +func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] + %1 = tosa.transpose %arg0, %0 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32> + return %1 : tensor<3x13x21xf32> +} + +// ----- +// CHECK-LABEL: gather +func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x3xf32> + return %0 : tensor<13x26x3xf32> +} + +// ----- +// CHECK-LABEL: scatter +func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: resize +func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { + %scale = tosa.const_shape { value = dense<[4, 2, 4, 2]> : tensor<4xindex> } : () -> !tosa.shape<4> + %offset = tosa.const_shape { value = dense<[-1, -1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %border = tosa.const_shape { value = dense<[1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2> + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int16, bf16] ] + %1 = tosa.resize %arg0, %scale, %offset, %border {mode = "BILINEAR"} : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x64x64x8xf32> + return %1 : tensor<1x64x64x8xf32> +} + +// ----- +// CHECK-LABEL: cast +func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: rescale +func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> { + // CHECK: profiles: [ [pro_int] ] + // CHECK: extensions: [ [int16] ] + %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: const +func.func @test_const(%arg0 : index) -> tensor<4xi32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- +// CHECK-LABEL: identity +func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ] + %0 = tosa.identity %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: cond_if +func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %0 = tosa.cond_if %arg2 -> (tensor) { + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- +// CHECK-LABEL: while_loop +func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ [bf16] ] + %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor<10xi32>) { + %2 = tosa.greater_equal %arg3, %arg1 : (tensor, tensor) -> tensor + %3 = tosa.logical_not %2 : (tensor) -> tensor + tosa.yield %3 : tensor + } do { + ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor<10xi32>): + %2 = "tosa.const"() {value = dense<1> : tensor} : () -> tensor + %3 = tosa.add %arg3, %2 : (tensor, tensor) -> tensor + %7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1> + %4 = tosa.reshape %2, %7 : (tensor, !tosa.shape<1>) -> tensor<1xi32> + %5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32> + %6 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor + tosa.yield %6, %3, %5 : tensor, tensor, tensor<10xi32> + } + return +} + +// ----- +// CHECK-LABEL: custom +func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ ] + %0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>) + return %0 : tensor<10xi32> +} + +// ----- +// CHECK-LABEL: const_shape +func.func @test_const_shape() -> !tosa.shape<4> { + // CHECK: profiles: [ [pro_int, pro_fp] ] + // CHECK: extensions: [ ] + %cst = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> + return %cst : !tosa.shape<4> +} + diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 1307da88d1e64..1aa8547cb2fdb 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,8 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt strict-op-spec-alignment" - +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment" func.func @test_const() -> tensor<1xf32> { // expected-error@+1{{'tosa.const' op expected same attr/result element types}} diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir new file mode 100644 index 0000000000000..046b9d5615074 --- /dev/null +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -0,0 +1,38 @@ +//-------------------------------------------------------------------------------------------------- +// Enable all supported profiles to focus the verification of expected extension requirement errors. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp,mt strict-op-spec-alignment" + +// ----- +func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) { + // expected-error@+1 {{'tosa.fft2d' op illegal: requires [fft] but not enabled in target}} + %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) + return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32> +} + +// ----- +func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () { + // expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}} + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + // expected-error@+1 {{'tosa.variable.read' op illegal: requires [variable]}} + %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16> + return +} + +// ----- +func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { + // expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}} + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + // expected-error@+1 {{'tosa.variable.write' op illegal: requires [variable]}} + tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16> + return +} + +// ----- +func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32> { + // expected-error@+1 {{'tosa.cast' op illegal: requires [bf16] but not enabled in target}} + %0 = tosa.cast %arg0 : (tensor<13x21x3xbf16>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 6f49195d30e97..90c4551564d1e 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,9 +1,8 @@ //-------------------------------------------------------------------------------------------------- -// Enable all supported profiles to focus the verification of expected level errors. +// Enable all supported profiles and extensions to focus the verification of expected level errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=bi,mi,mt" - +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp,mt extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable" func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir new file mode 100644 index 0000000000000..6dddcf329d110 --- /dev/null +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -0,0 +1,83 @@ +//-------------------------------------------------------------------------------------------------- +// Enable all supported extensions to focus the verification of expected profile requirement errors. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment" + +// ----- +func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () { + // expected-error@+1 {{'tosa.table' op illegal: requires [pro_int] but not enabled in target}} + %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor + return +} + +// ----- +func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { + // expected-error@+1 {{'tosa.conv2d' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + return %0 : tensor<1x4x4x8xf32> +} + +// ----- +func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + return %0 : tensor<1x7x7x9xf32> +} + +// ----- +func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { + // expected-error@+1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} + +// ----- +func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.sigmoid' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- +func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.add' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<1x21x3xi1> { + // expected-error@+1 {{'tosa.reduce_all' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}} + %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1> + return %0 : tensor<1x21x3xi1> +} + +// ----- +func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> { + // expected-error@+1 {{'tosa.concat' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32> + return %0 : tensor<26x21x3xf32> +} + +// ----- +func.func @test_cast_i32_f32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.cast' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> { + // expected-error@+1 {{'tosa.custom' op illegal: requires any of [pro_int, pro_fp] but not enabled in target}} + %0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>) + return %0 : tensor<10xi32> +} + diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir new file mode 100644 index 0000000000000..c46b2543fbed5 --- /dev/null +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -0,0 +1,62 @@ +//-------------------------------------------------------------------------------------------------- +// Enable all supported extensions to focus the verification of expected profile requirement errors. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment" + +// ----- +func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { + // expected-error@+1 {{'tosa.conv2d' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + return %0 : tensor<1x4x4x8xf32> +} + +// ----- +func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + return %0 : tensor<1x7x7x9xf32> +} + +// ----- +func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { + // expected-error@+1 {{'tosa.matmul' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} + +// ----- +func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.sigmoid' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- +func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.add' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +func.func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> { + // expected-error@+1 {{'tosa.concat' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32> + return %0 : tensor<26x21x3xf32> +} + +// ----- +func.func @test_cast_i32_f32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> { + // expected-error@+1 {{'tosa.cast' op illegal: requires [pro_fp] but not enabled in target}} + %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir new file mode 100644 index 0000000000000..479b7569f54ae --- /dev/null +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -0,0 +1,26 @@ +//-------------------------------------------------------------------------------------------------- +// Enable all supported extensions to focus the verification of expected profile requirement errors. +//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment" + +// ----- +func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () { + // expected-error@+1 {{'tosa.table' op illegal: requires [pro_int] but not enabled in target}} + %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor + return +} + +// ----- +func.func @test_reduce_max(%arg0: tensor<13x21x3xi16>) -> tensor<1x21x3xi16> { + // expected-error@+1 {{'tosa.reduce_max' op illegal: requires [pro_int] but not enabled in target}} + %0 = tosa.reduce_max %arg0 {axis = 0 : i32} : (tensor<13x21x3xi16>) -> tensor<1x21x3xi16> + return %0 : tensor<1x21x3xi16> +} + +// ----- +func.func @test_cast_i8_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi8> { + // expected-error@+1 {{'tosa.cast' op illegal: requires [pro_int] but not enabled in target}} + %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8> + return %0 : tensor<13x21x3xi8> +} diff --git a/mlir/test/lib/Dialect/Tosa/CMakeLists.txt b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt index 7d40881ee6ee4..43f0d0d21c1c0 100644 --- a/mlir/test/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTosaTestPasses TosaTestPasses.cpp + TestAvailability.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/Tosa/TestAvailability.cpp b/mlir/test/lib/Dialect/Tosa/TestAvailability.cpp new file mode 100644 index 0000000000000..bec563d1ec747 --- /dev/null +++ b/mlir/test/lib/Dialect/Tosa/TestAvailability.cpp @@ -0,0 +1,78 @@ +//===- TestAvailability.cpp - Pass to test Tosa op availability ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Printing op availability pass +//===----------------------------------------------------------------------===// + +namespace { +/// A pass for testing Tosa op availability. +struct PrintOpAvailability + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability) + + void runOnOperation() override; + StringRef getArgument() const final { return "test-tosa-op-availability"; } + StringRef getDescription() const final { return "Test Tosa op availability"; } +}; +} // namespace + +void PrintOpAvailability::runOnOperation() { + auto f = getOperation(); + llvm::outs() << f.getName() << "\n"; + + Dialect *tosaDialect = getContext().getLoadedDialect("tosa"); + + f->walk([&](Operation *op) { + if (op->getDialect() != tosaDialect) + return WalkResult::advance(); + + auto opName = op->getName(); + auto &os = llvm::outs(); + + if (auto profile = dyn_cast(op)) { + os << opName << " profiles: ["; + for (const auto &profs : profile.getProfiles()) { + os << " ["; + llvm::interleaveComma(profs, os, [&](tosa::Profile prof) { + os << tosa::stringifyProfile(prof); + }); + os << "]"; + } + os << " ]\n"; + } + + if (auto extension = dyn_cast(op)) { + os << opName << " extensions: ["; + for (const auto &exts : extension.getExtensions()) { + os << " ["; + llvm::interleaveComma(exts, os, [&](tosa::Extension ext) { + os << tosa::stringifyExtension(ext); + }); + os << "]"; + } + os << " ]\n"; + } + + os.flush(); + + return WalkResult::advance(); + }); +} + +namespace mlir { +void registerPrintTosaAvailabilityPass() { + PassRegistration(); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 74007d01347ae..f18ad45dfb708 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -39,6 +39,7 @@ void registerLoopLikeInterfaceTestPasses(); void registerPassManagerTestPass(); void registerPrintSpirvAvailabilityPass(); void registerRegionTestPasses(); +void registerPrintTosaAvailabilityPass(); void registerShapeFunctionTestPasses(); void registerSideEffectTestPasses(); void registerSliceAnalysisTestPass(); @@ -175,6 +176,7 @@ void registerTestTransformDialectExtension(DialectRegistry &); void registerTestPasses() { registerCloneTestPasses(); registerConvertToTargetEnvPass(); + registerPrintTosaAvailabilityPass(); registerLazyLoadingTestPasses(); registerLoopLikeInterfaceTestPasses(); registerPassManagerTestPass(); diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index fb507dc7f8c3c..9431c59860522 100644 --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -32,6 +32,7 @@ add_tablegen(mlir-tblgen MLIR PassGen.cpp RewriterGen.cpp SPIRVUtilsGen.cpp + TosaUtilsGen.cpp ) target_link_libraries(mlir-tblgen diff --git a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp new file mode 100644 index 0000000000000..491f9143edb02 --- /dev/null +++ b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp @@ -0,0 +1,226 @@ +//===- TosaUtilsGen.cpp - Tosa utility generator -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TosaUtilsGen generates common utility functions for Tosa validation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/CodeGenHelpers.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Operator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +#include +#include + +using llvm::ArrayRef; +using llvm::formatv; +using llvm::raw_ostream; +using llvm::raw_string_ostream; +using llvm::Record; +using llvm::RecordKeeper; +using llvm::SmallVector; +using llvm::SMLoc; +using llvm::StringMap; +using llvm::StringRef; +using mlir::tblgen::Attribute; +using mlir::tblgen::EnumAttr; +using mlir::tblgen::EnumAttrCase; +using mlir::tblgen::NamedAttribute; +using mlir::tblgen::NamedTypeConstraint; +using mlir::tblgen::NamespaceEmitter; +using mlir::tblgen::Operator; + +//===----------------------------------------------------------------------===// +// Availability Wrapper Class +//===----------------------------------------------------------------------===// + +namespace { +// Wrapper class with helper methods for accessing availability defined in +// TableGen. +class Availability { +public: + explicit Availability(const Record *def); + + // Returns the name of the direct TableGen class for this availability + // instance. + StringRef getClass() const; + + // Returns the name of the query function insided the generated C++ interface. + StringRef getQueryFnName() const; + + // Returns the return type of the query function insided the generated C++ + // interface. + StringRef getQueryFnRetType() const; + + // Returns the code for merging availability requirements. + StringRef getMergeActionCode() const; + + // Returns the initializer expression for initializing the final availability + // requirements. + StringRef getMergeInitializer() const; + + // Returns the C++ statements for preparing availability instance. + StringRef getMergeInstancePreparation() const; + + // Returns the concrete availability instance carried in this case. + StringRef getMergeInstance() const; + + // Returns the underlying LLVM TableGen Record. + const llvm::Record *getDef() const { return def; } + +private: + // The TableGen definition of this availability. + const llvm::Record *def; +}; +} // namespace + +Availability::Availability(const llvm::Record *def) : def(def) { + assert(def->isSubClassOf("Availability") && + "must be subclass of TableGen 'Availability' class"); +} + +StringRef Availability::getClass() const { + SmallVector parentClass; + def->getDirectSuperClasses(parentClass); + if (parentClass.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + return parentClass.front()->getName(); +} + +StringRef Availability::getQueryFnRetType() const { + return def->getValueAsString("queryFnRetType"); +} + +StringRef Availability::getQueryFnName() const { + return def->getValueAsString("queryFnName"); +} + +StringRef Availability::getMergeActionCode() const { + return def->getValueAsString("mergeAction"); +} + +StringRef Availability::getMergeInitializer() const { + return def->getValueAsString("initializer"); +} + +StringRef Availability::getMergeInstancePreparation() const { + return def->getValueAsString("instancePreparation"); +} + +StringRef Availability::getMergeInstance() const { + return def->getValueAsString("instance"); +} + +// Returns the availability spec of the given `def`. +std::vector getAvailabilities(const Record &def) { + std::vector availabilities; + + if (def.getValue("availability")) { + std::vector availDefs = + def.getValueAsListOfDefs("availability"); + availabilities.reserve(availDefs.size()); + for (const Record *avail : availDefs) + availabilities.emplace_back(avail); + } + + return availabilities; +} + +//===----------------------------------------------------------------------===// +// Tosa Availability Impl AutoGen +//===----------------------------------------------------------------------===// + +static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { + mlir::tblgen::FmtContext fctx; + fctx.addSubst("overall", "tblgen_overall"); + + std::vector opAvailabilities = + getAvailabilities(srcOp.getDef()); + + // First collect all availability classes this op should implement. + // All availability instances keep information for the generated interface and + // the instance's specific requirement. Here we remember a random instance so + // we can get the information regarding the generated interface. + llvm::StringMap availClasses; + for (const Availability &avail : opAvailabilities) + availClasses.try_emplace(avail.getClass(), avail); + + // Then generate implementation for each availability class. + for (const auto &availClass : availClasses) { + StringRef availClassName = availClass.getKey(); + Availability avail = availClass.getValue(); + + // Generate the implementation method signature. + os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(), + srcOp.getCppClassName(), avail.getQueryFnName()); + + // Create the variable for the final requirement and initialize it. + os << formatv(" {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(), + avail.getMergeInitializer()); + + // Update with the op's specific availability spec. + for (const Availability &avail : opAvailabilities) + if (avail.getClass() == availClassName && + (!avail.getMergeInstancePreparation().empty() || + !avail.getMergeActionCode().empty())) { + os << " {\n " + // Prepare this instance. + << avail.getMergeInstancePreparation() + << "\n " + // Merge this instance. + << std::string( + tgfmt(avail.getMergeActionCode(), + &fctx.addSubst("instance", avail.getMergeInstance()))) + << ";\n }\n"; + } + + os << " return tblgen_overall;\n"; + os << "}\n"; + } +} + +static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("Tosa Op Availability Implementations", os, + recordKeeper); + + auto defs = recordKeeper.getAllDerivedDefinitions("Tosa_Op"); + for (const auto *def : defs) { + Operator op(def); + if (def->getValueAsBit("autogenAvailability")) + emitAvailabilityImpl(op, os); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Op Availability Implementation Hook Registration +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genOpAvailabilityImpl("gen-tosa-avail-impls", + "Generate Tosa operation utility definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitAvailabilityImpl(records, os); + }); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 05385ba491525..6abfd29884d46 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -12241,6 +12241,42 @@ gentbl_cc_library( deps = [":TosaDialectTdFiles"], ) +gentbl_cc_library( + name = "MLIRTosaEnumsIncGen", + tbl_outs = [ + ( + ["-gen-enum-decls"], + "include/mlir/Dialect/Tosa/IR/TosaEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Tosa/IR/TosaOpBase.td", +) + +gentbl_cc_library( + name = "MLIRTosaAvailabilityIncGen", + tbl_outs = [ + ( + ["-gen-avail-interface-decls"], + "include/mlir/Dialect/Tosa/IR/TosaAvailability.h.inc", + ), + ( + ["-gen-avail-interface-defs"], + "include/mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc", + ), + ( + ["-gen-tosa-avail-impls"], + "include/mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Tosa/IR/TosaOps.td", +) + gentbl_cc_library( name = "TosaDialectBytecodeGen", strip_include_prefix = "include",