From 4433e5be2992e932d9c0177588f0b0d947658435 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Thu, 8 Aug 2024 03:01:08 +0000 Subject: [PATCH] Add f8E4M3 and f8E3M4 types support --- docs/spec.md | 6 +- stablehlo/dialect/Base.td | 4 +- stablehlo/dialect/Version.h | 2 +- stablehlo/dialect/VhloBytecode.cpp | 22 + stablehlo/dialect/VhloDialect.td | 1 + stablehlo/dialect/VhloTypes.cpp | 12 + stablehlo/dialect/VhloTypes.td | 6 + stablehlo/reference/Tensor.cpp | 22 +- stablehlo/reference/Types.cpp | 3 +- stablehlo/tests/interpret/constant.mlir | 16 + stablehlo/tests/interpret/dot_general.mlir | 23 + stablehlo/tests/ops_stablehlo.mlir | 32 +- stablehlo/tests/ops_stablehlo_quantized.mlir | 106 +- stablehlo/tests/ops_stablehlo_roundtrip.mlir | 2 + .../stablehlo_legalize_to_vhlo.1_6_0.mlir | 2851 +++++++++++++++++ .../stablehlo_legalize_to_vhlo.1_6_0.mlir.bc | Bin 0 -> 18992 bytes .../vhlo/stablehlo_legalize_to_vhlo.mlir | 16 + 17 files changed, 3052 insertions(+), 72 deletions(-) create mode 100644 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_6_0.mlir create mode 100644 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_6_0.mlir.bc diff --git a/docs/spec.md b/docs/spec.md index 57c0619550..b6b8a3ae14 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -245,8 +245,8 @@ BooleanType ::= 'i1' IntegerType ::= SignedIntegerType | UnsignedIntegerType SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64' UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64' -FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ' - | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64' +FloatType ::= 'f8E3M4' | 'f8E4M3' | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' + | 'f8E5M2' | 'f8E5M2FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64' ComplexType ::= 'complex' '<' ComplexElementType '>' ComplexElementType ::= 'f32' | 'f64' ``` @@ -264,6 +264,8 @@ values of type `tensor`). inclusive, and unsigned `uiN` types represent integer values from `0` to `2^N-1` inclusive. * **Floating-point types** can be one of the following: + * `f8E3M4`, `f8E4M3` and `f8E5M2` 8-bit floating point numbers following + IEEE-754 conventions. * `f8E4M3FN` and `f8E5M2` types corresponding to respectively the `E4M3` and `E5M2` encodings of the FP8 format described in [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433). diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index c58c122907..296b118c4a 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -42,8 +42,8 @@ def HLO_SInt : SignlessIntOfWidths<[2, 4, 8, 16, 32, 64]>; def HLO_UInt : UnsignedIntOfWidths<[2, 4, 8, 16, 32, 64]>; def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>; -def HLO_Float : AnyTypeOf<[F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, - F8E5M2FNUZ, F16, F32, F64, BF16]>; +def HLO_Float : AnyTypeOf<[F8E3M4, F8E4M3, F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ, + F8E5M2, F8E5M2FNUZ, F16, F32, F64, BF16]>; def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>; def HLO_Complex : Complex>; diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h index b4b2a646a2..86fff00252 100644 --- a/stablehlo/dialect/Version.h +++ b/stablehlo/dialect/Version.h @@ -38,7 +38,7 @@ class Version { static FailureOr fromString(llvm::StringRef versionRef); /// Return a Version representing the current VHLO dialect version. - static Version getCurrentVersion() { return Version(1, 5, 3); } + static Version getCurrentVersion() { return Version(1, 6, 0); } /// Return a Version representing the minimum supported VHLO dialect version. static Version getMinimumVersion() { return Version(0, 9, 0); } diff --git a/stablehlo/dialect/VhloBytecode.cpp b/stablehlo/dialect/VhloBytecode.cpp index 18965ef8f7..9352d3cc9f 100644 --- a/stablehlo/dialect/VhloBytecode.cpp +++ b/stablehlo/dialect/VhloBytecode.cpp @@ -215,6 +215,14 @@ enum TypeCode { /// } kFloatF64V1Type = 5, + /// FloatF8E3M4V1Type { + /// } + kFloatF8E3M4V1Type = 34, + + /// FloatF8E4M3V1Type { + /// } + kFloatF8E4M3V1Type = 33, + /// FloatF8E4M3FNV1Type { /// } kFloatF8E4M3FNV1Type = 6, @@ -689,9 +697,11 @@ const llvm::fltSemantics &getFloatSemantics(Type type) { if (isa(type)) return APFloat::IEEEhalf(); if (isa(type)) return APFloat::IEEEsingle(); if (isa(type)) return APFloat::IEEEdouble(); + if (isa(type)) return APFloat::Float8E3M4(); if (isa(type)) return APFloat::Float8E4M3FNUZ(); if (isa(type)) return APFloat::Float8E4M3B11FNUZ(); if (isa(type)) return APFloat::Float8E4M3FN(); + if (isa(type)) return APFloat::Float8E4M3(); if (isa(type)) return APFloat::Float8E5M2FNUZ(); if (isa(type)) return APFloat::Float8E5M2(); llvm::report_fatal_error("unsupported floating-point type"); @@ -958,6 +968,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const { return FloatF64V1Type::get(getContext()); case vhlo_encoding::kFloatF8E5M2V1Type: return FloatF8E5M2V1Type::get(getContext()); + case vhlo_encoding::kFloatF8E4M3V1Type: + return FloatF8E4M3V1Type::get(getContext()); case vhlo_encoding::kFloatF8E4M3FNV1Type: return FloatF8E4M3FNV1Type::get(getContext()); case vhlo_encoding::kFloatF8E5M2FNUZV1Type: @@ -966,6 +978,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const { return FloatF8E4M3FNUZV1Type::get(getContext()); case vhlo_encoding::kFloatF8E4M3B11FNUZV1Type: return FloatF8E4M3B11FNUZV1Type::get(getContext()); + case vhlo_encoding::kFloatF8E3M4V1Type: + return FloatF8E3M4V1Type::get(getContext()); case vhlo_encoding::kFunctionV1Type: return readFunctionV1Type(reader); case vhlo_encoding::kIndexV1Type: @@ -1046,6 +1060,14 @@ LogicalResult VhloBytecodeInterface::writeType( LOG_WRITE_CALL; return writer.writeVarInt(vhlo_encoding::kFloatF64V1Type), success(); }) + .Case([&](FloatF8E3M4V1Type) { + LOG_WRITE_CALL; + return writer.writeVarInt(vhlo_encoding::kFloatF8E3M4V1Type), success(); + }) + .Case([&](FloatF8E4M3V1Type) { + LOG_WRITE_CALL; + return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3V1Type), success(); + }) .Case([&](FloatF8E4M3FNV1Type) { LOG_WRITE_CALL; return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3FNV1Type), diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td index 6e10613d78..0c09b39337 100644 --- a/stablehlo/dialect/VhloDialect.td +++ b/stablehlo/dialect/VhloDialect.td @@ -43,6 +43,7 @@ def VHLO_Dialect : Dialect { 1.3.0: Extend `custom_call` op `backend_config` to support `DictionaryAttr`. 1.4.0: Add `tan` op to StableHLO opset. 1.5.0: Make collective ops (`all_reduce`, `all_gather`, `all_to_all`) variadic. + 1.6.0: Introduce `f8E4M3` and `f8E3M4` types. }]; let useDefaultAttributePrinterParser = 0; diff --git a/stablehlo/dialect/VhloTypes.cpp b/stablehlo/dialect/VhloTypes.cpp index 170a099062..0dd64668a7 100644 --- a/stablehlo/dialect/VhloTypes.cpp +++ b/stablehlo/dialect/VhloTypes.cpp @@ -84,6 +84,12 @@ void VhloTypeConverter::addBuiltinToVhloConversions() { [&](Float32Type type) { return FloatF32V1Type::get(type.getContext()); }); addConversion( [&](Float64Type type) { return FloatF64V1Type::get(type.getContext()); }); + addConversion([&](Float8E3M4Type type) { + return FloatF8E3M4V1Type::get(type.getContext()); + }); + addConversion([&](Float8E4M3Type type) { + return FloatF8E4M3V1Type::get(type.getContext()); + }); addConversion([&](Float8E4M3FNType type) { return FloatF8E4M3FNV1Type::get(type.getContext()); }); @@ -171,6 +177,12 @@ void VhloTypeConverter::addVhloToBuiltinConversions() { [&](FloatF32V1Type type) { return Float32Type::get(type.getContext()); }); addConversion( [&](FloatF64V1Type type) { return Float64Type::get(type.getContext()); }); + addConversion([&](FloatF8E3M4V1Type type) { + return Float8E3M4Type::get(type.getContext()); + }); + addConversion([&](FloatF8E4M3V1Type type) { + return Float8E4M3Type::get(type.getContext()); + }); addConversion([&](FloatF8E4M3FNV1Type type) { return Float8E4M3FNType::get(type.getContext()); }); diff --git a/stablehlo/dialect/VhloTypes.td b/stablehlo/dialect/VhloTypes.td index caa809fad8..a3d8d0bda1 100644 --- a/stablehlo/dialect/VhloTypes.td +++ b/stablehlo/dialect/VhloTypes.td @@ -79,6 +79,12 @@ def VHLO_FloatF32V1 : VHLO_TypeDef<"FloatF32V1", "f32_v1", "0.9.0", "current">; // Corresponds to the 'f64' FloatType from the StableHLO spec. def VHLO_FloatF64V1 : VHLO_TypeDef<"FloatF64V1","f64_v1", "0.9.0", "current">; +// Corresponds to the 'f8E3M4' FloatType from the StableHLO spec. +def VHLO_FloatF8E3M4V1 : VHLO_TypeDef<"FloatF8E3M4V1", "f8E3M4_v1", "1.6.0", "current">; + +// Corresponds to the 'f8E4M3' FloatType from the StableHLO spec. +def VHLO_FloatF8E4M3V1 : VHLO_TypeDef<"FloatF8E4M3V1", "f8E4M3_v1", "1.6.0", "current">; + // Corresponds to the 'f8E4M3FN' FloatType from the StableHLO spec. def VHLO_FloatF8E4M3FNV1 : VHLO_TypeDef<"FloatF8E4M3FNV1", "f8E4M3FN_v1", "0.9.0", "current">; diff --git a/stablehlo/reference/Tensor.cpp b/stablehlo/reference/Tensor.cpp index a1e66eef6c..97ac92e3a3 100644 --- a/stablehlo/reference/Tensor.cpp +++ b/stablehlo/reference/Tensor.cpp @@ -118,11 +118,21 @@ Element Tensor::get(const Index &index) const { getSizeInBytes(elementType) * flattenIndex(getShape(), index); // Handle floating-point types. + if (elementType.isFloat8E3M4()) { + auto elementData = reinterpret_cast(elementPtr); + return Element(elementType, APFloat(llvm::APFloatBase::Float8E3M4(), + APInt(8, *elementData))); + } if (elementType.isFloat8E4M3B11FNUZ()) { auto elementData = reinterpret_cast(elementPtr); return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3B11FNUZ(), APInt(8, *elementData))); } + if (elementType.isFloat8E4M3()) { + auto elementData = reinterpret_cast(elementPtr); + return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3(), + APInt(8, *elementData))); + } if (elementType.isFloat8E4M3FN()) { auto elementData = reinterpret_cast(elementPtr); return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FN(), @@ -252,7 +262,8 @@ void Tensor::set(const Index &index, const Element &element) { getSizeInBytes(elementType) * flattenIndex(getShape(), index); // Handle floating-point types. - if (elementType.isFloat8E4M3B11FNUZ() || elementType.isFloat8E4M3FN() || + if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() || + elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() || elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() || elementType.isFloat8E5M2FNUZ()) { auto elementData = reinterpret_cast(elementPtr); @@ -446,7 +457,8 @@ Tensor makeTensor(DenseElementsAttr attr) { auto elementType = type.getElementType(); // Handle floating-point types. - if (elementType.isFloat8E4M3B11FNUZ() || elementType.isFloat8E4M3FN() || + if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() || + elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() || elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() || elementType.isFloat8E5M2FNUZ()) { auto floatValues = llvm::map_to_vector( @@ -454,9 +466,9 @@ Tensor makeTensor(DenseElementsAttr attr) { return value.bitcastToAPInt().getZExtValue(); }); - // For f8E4M3B11FNUZ, f8E4M3FN, f8E4M3FNUZ, f8E5M2, and f8E5M2FNUZ - // floating-point types, we use uint8_t as their storage type because there - // are no builtin types for those. + // For f8E3M4, f8E4M3, f8E4M3FN, f8E4M3FNUZ, f8E4M3B11FNUZ, f8E5M2, and + // f8E5M2FNUZ floating-point types, we use uint8_t as their storage type + // because there are no builtin types for those. return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign( floatValues)); } diff --git a/stablehlo/reference/Types.cpp b/stablehlo/reference/Types.cpp index e55c520793..9944ca07c1 100644 --- a/stablehlo/reference/Types.cpp +++ b/stablehlo/reference/Types.cpp @@ -48,7 +48,8 @@ bool isSupportedIntegerType(Type type) { } bool isSupportedFloatType(Type type) { - return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() || + return type.isFloat8E3M4() || type.isFloat8E4M3B11FNUZ() || + type.isFloat8E4M3() || type.isFloat8E4M3FN() || type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || type.isFloat8E5M2FNUZ() || type.isF16() || type.isBF16() || type.isF32() || type.isF64(); diff --git a/stablehlo/tests/interpret/constant.mlir b/stablehlo/tests/interpret/constant.mlir index 26f00eb65a..2e24ba02f8 100644 --- a/stablehlo/tests/interpret/constant.mlir +++ b/stablehlo/tests/interpret/constant.mlir @@ -96,6 +96,14 @@ func.func @constant_op_test_ui64() { // ----- +func.func @constant_op_test_f8_e3m4() { + %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E3M4> + check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.09375, 3.125, 0x7F, 0xFF, 0.015625, -0.015625]> : tensor<10xf8E3M4> + func.return +} + +// ----- + func.func @constant_op_test_f8_e4m3b11_fnuz() { %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3B11FNUZ> check.expect_almost_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.101563, 3.25, 30.0, -30.0, 0.00012207, -0.00012207]> : tensor<10xf8E4M3B11FNUZ> @@ -104,6 +112,14 @@ func.func @constant_op_test_f8_e4m3b11_fnuz() { // ----- +func.func @constant_op_test_f8_e4m3() { + %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3> + check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.1015630, 3.25, 0x7F, 0xFF, 0.001953130, -0.001953130]> : tensor<10xf8E4M3> + func.return +} + +// ----- + func.func @constant_op_test_f8_e4m3_fn() { %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3FN> check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.1015630, 3.25, 0x7F, 0xFF, 0.001953130, -0.001953130]> : tensor<10xf8E4M3FN> diff --git a/stablehlo/tests/interpret/dot_general.mlir b/stablehlo/tests/interpret/dot_general.mlir index a16087cd19..31f73a6e47 100644 --- a/stablehlo/tests/interpret/dot_general.mlir +++ b/stablehlo/tests/interpret/dot_general.mlir @@ -31,3 +31,26 @@ func.func @dot_general_op_test_empty_dims() { [[4, 0], [0, 4]]]]> : tensor<2x2x2x2xi64> func.return } + +// ----- + +func.func @add_op_test_f8E3M4() { + %0 = stablehlo.constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf8E3M4> + %result = stablehlo.dot_general %0, %0, + contracting_dims = [0] x [0] + : (tensor<4xf8E3M4>, tensor<4xf8E3M4>) -> tensor + check.expect_almost_eq_const %result, dense<14.0> : tensor + func.return +} + +// ----- + +func.func @add_op_test_f8E4M3() { + %0 = stablehlo.constant dense<[0.0, 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, 7.0]> : tensor<8xf8E4M3> + %result = stablehlo.dot_general %0, %0, + contracting_dims = [0] x [0] + : (tensor<8xf8E4M3>, tensor<8xf8E4M3>) -> tensor + check.expect_almost_eq_const %result, dense<140.0> : tensor + func.return +} diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 8de4af11cb..2bbf1ddb19 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -2190,7 +2190,7 @@ func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2199,7 +2199,7 @@ func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor) - func.func @rng_normal_invalid_sigma_rank(%mu: tensor, %sigma: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2217,7 +2217,7 @@ func.func @rng_normal_invalid_shape_rank(%mu: tensor, %sigma: tensor) func.func @rng_normal_invalid_type(%arg0: tensor>, %arg1: tensor) { %cst = stablehlo.constant dense<7> : tensor<1xi64> - // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} + // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo}: (tensor>, tensor, tensor<1xi64>) -> tensor<7xf32> func.return } @@ -2252,7 +2252,7 @@ func.func @rng_uniform_invalid_shape(%arg0: tensor, %arg1: tensor, %ar func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2262,7 +2262,7 @@ func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> ten func.func @rng_uniform_invalid_b_rank(%a: tensor, %b: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2280,7 +2280,7 @@ func.func @rng_uniform_invalid_shape_rank(%a: tensor, %b: tensor) -> t func.func @rng_uniform_invalid_type(%a: tensor>, %b: tensor) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} + // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2828,7 +2828,7 @@ func.func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te // ----- func.func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { - // expected-error@+1 {{must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4xi32>'}} + // expected-error@+1 {{must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4xi32>'}} %0 = "stablehlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> func.return %0 : tensor<4xi32> } @@ -6055,7 +6055,7 @@ func.func @is_finite(%arg0: tensor<3xf32>) -> tensor<3xi1> { // ----- func.func @is_finite_int_input(%arg0: tensor<3xi32>) -> tensor<3xi1> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3xi32>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3xi32>'}} %0 = "stablehlo.is_finite"(%arg0) {} : (tensor<3xi32>) -> tensor<3xi1> func.return %0 : tensor<3xi1> } @@ -6103,6 +6103,22 @@ func.func @convert(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: func @convert_f8e3m4 +func.func @convert_f8e3m4(%arg0: tensor) -> tensor { + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @convert_f8e4m3 +func.func @convert_f8e4m3(%arg0: tensor) -> tensor { + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: func @convert_f8e4m3fn func.func @convert_f8e4m3fn(%arg0: tensor) -> tensor { %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor diff --git a/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/tests/ops_stablehlo_quantized.mlir index cb4563ce9f..89c8e4d00d 100644 --- a/stablehlo/tests/ops_stablehlo_quantized.mlir +++ b/stablehlo/tests/ops_stablehlo_quantized.mlir @@ -380,7 +380,7 @@ func.func @while_per_tensor_quantization(%arg0: tensor<4x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %abs_neg = "stablehlo.abs"(%arg0) : (tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>>) -> tensor<1x2x2x!quant.uniform:f32:0, {0.1:-30}>> func.return } @@ -388,7 +388,7 @@ func.func @negative_abs_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x4x!quant.uniform>'}} %all_gather = "stablehlo.all_gather"(%arg0) { all_gather_dim = 1 : i64, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> } : (tensor<2x4x!quant.uniform>) -> tensor<2x4x!quant.uniform> func.return } @@ -396,7 +396,7 @@ func.func @negative_all_gather_quantization(%arg0: tensor<2x4x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x4x!quant.uniform>'}} %all_to_all = "stablehlo.all_to_all"(%arg0) { split_dimension = 1 : i64, concat_dimension = 1 : i64, split_count = 2 : i64, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, channel_handle = #stablehlo.channel_handle} : (tensor<2x4x!quant.uniform>) -> tensor<2x4x!quant.uniform> func.return } @@ -404,7 +404,7 @@ func.func @negative_all_to_all_quantization(%arg0: tensor<2x4x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %atan2 = "stablehlo.atan2"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -412,7 +412,7 @@ func.func @negative_atan_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %cbrt = "stablehlo.cbrt"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -420,7 +420,7 @@ func.func @negative_bitcast_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %ceil = "stablehlo.ceil"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -428,7 +428,7 @@ func.func @negative_ceil_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %cholesky = "stablehlo.cholesky"(%arg0) { lower = true } : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -437,7 +437,7 @@ func.func @negative_cholesky_quantization(%arg0: tensor<1x2x2x!quant.uniform>) -> tensor<1x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x!quant.uniform>'}} %0 = "stablehlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> func.return %0: tensor<1x!quant.uniform> } @@ -445,7 +445,7 @@ func.func @negative_clamp_quantization(%arg0: tensor<1x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %collective_permute = "stablehlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, channel_handle = #stablehlo.channel_handle} : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -453,7 +453,7 @@ func.func @negative_collective_permute_quantization(%arg0: tensor<1x2x2x!quant.u // ----- func.func @negative_compare_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %compare = "stablehlo.compare"(%arg0, %arg1) { comparison_direction = #stablehlo, compare_type = #stablehlo } : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2xi1> func.return } @@ -461,7 +461,7 @@ func.func @negative_compare_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %concatenate = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<2x2x2x!quant.uniform> func.return } @@ -469,7 +469,7 @@ func.func @negative_concatenate_quantization(%arg0: tensor<1x2x2x!quant.uniform< // ----- func.func @negative_cosine_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %cosine = "stablehlo.cosine"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -477,7 +477,7 @@ func.func @negative_cosine_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %divide = "stablehlo.divide"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -485,7 +485,7 @@ func.func @negative_divide_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor, %arg2: tensor) -> tensor<1x4x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} %0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = array} : (tensor<3x4x!quant.uniform>, tensor, tensor) -> tensor<1x4x!quant.uniform> func.return %0 : tensor<1x4x!quant.uniform> } @@ -493,7 +493,7 @@ func.func @negative_dynamic_slice_quantization(%arg0: tensor<3x4x!quant.uniform< // ----- func.func @negative_exponential_minus_one_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %exponential_minus_one = "stablehlo.exponential_minus_one"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -501,7 +501,7 @@ func.func @negative_exponential_minus_one_quantization(%arg0: tensor<1x2x2x!quan // ----- func.func @negative_exponential_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %exponential_minus_one = "stablehlo.exponential"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -509,7 +509,7 @@ func.func @negative_exponential_quantization(%arg0: tensor<1x2x2x!quant.uniform< // ----- func.func @negative_floor_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %floor = "stablehlo.floor"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -517,7 +517,7 @@ func.func @negative_floor_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %is_finite = "stablehlo.is_finite"(%arg0) {} : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2xi1> func.return } @@ -525,7 +525,7 @@ func.func @negative_floor_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %log_plus_one = "stablehlo.log_plus_one"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -533,7 +533,7 @@ func.func @negative_log_plus_one_quantization(%arg0: tensor<1x2x2x!quant.uniform // ----- func.func @negative_logistic_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %logistic = "stablehlo.logistic"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -541,7 +541,7 @@ func.func @negative_logistic_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %log = "stablehlo.log"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -549,7 +549,7 @@ func.func @negative_log_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<4x!quant.uniform>) -> tensor<4x!quant.uniform> { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x!quant.uniform>'}} %map = "stablehlo.map"(%arg0, %arg1) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): "stablehlo.return"(%arg2) : (tensor>) -> () @@ -560,7 +560,7 @@ func.func @negative_map_quantization(%arg0: tensor<4x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %maximum = "stablehlo.maximum"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -568,7 +568,7 @@ func.func @negative_maximum_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %minimum = "stablehlo.minimum"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -576,7 +576,7 @@ func.func @negative_minimum_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %multiply = "stablehlo.multiply"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -584,7 +584,7 @@ func.func @negative_multiply_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %negate = "stablehlo.negate"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -592,14 +592,14 @@ func.func @negative_negate_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values or token, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values or token, but got 'tensor<1x2x2x!quant.uniform>'}} %optimization_barrier = "stablehlo.optimization_barrier"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } // ----- func.func @negative_pad_quantization(%arg0: tensor<1x2x3x!quant.uniform>, %arg1: tensor>) -> tensor<2x4x7x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x3x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x3x!quant.uniform>'}} %pad = "stablehlo.pad"(%arg0, %arg1) { edge_padding_low = array, edge_padding_high = array, @@ -611,7 +611,7 @@ func.func @negative_pad_quantization(%arg0: tensor<1x2x3x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %power = "stablehlo.power"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -619,7 +619,7 @@ func.func @negative_power_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor>) -> tensor> { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x!quant.uniform>'}} %reduce = "stablehlo.reduce"(%arg0, %arg1) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %1 = "stablehlo.add"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> @@ -633,7 +633,7 @@ func.func @reduce_quantization(%arg0: tensor<16x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %remainder = "stablehlo.remainder"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -641,7 +641,7 @@ func.func @negative_remainder_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %rsqrt = "stablehlo.rsqrt"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -649,7 +649,7 @@ func.func @negative_rsqrt_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %sine = "stablehlo.sine"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -657,7 +657,7 @@ func.func @negative_sine_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %sqrt = "stablehlo.sqrt"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -665,7 +665,7 @@ func.func @negative_sqrt_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %subtract = "stablehlo.subtract"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -673,7 +673,7 @@ func.func @negative_subtract_quantization(%arg0: tensor<1x2x2x!quant.uniform>){ - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<1x2x2x!quant.uniform>'}} %tanh = "stablehlo.tanh"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> func.return } @@ -681,7 +681,7 @@ func.func @negative_tanh_quantization(%arg0: tensor<1x2x2x!quant.uniform>, %scale: tensor<2x!quant.uniform>, %mean: tensor<2x!quant.uniform>, %variance: tensor<2x!quant.uniform>, %grad_output: tensor<2x2x2x2x!quant.uniform>) -> tensor<2x2x2x2x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x2x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x2x2x2x!quant.uniform>'}} %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x2x2x2x!quant.uniform>) -> (tensor<2x2x2x2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) @@ -691,7 +691,7 @@ func.func @negative_batch_norm_grad_quantization(%input: tensor<2x2x2x2x!quant.u // ----- func.func @negative_batch_norm_inference_quantization(%input: tensor<4x256x!quant.uniform>, %scale: tensor<256x!quant.uniform>, %offset: tensor<256x!quant.uniform>, %mean: tensor<256x!quant.uniform>, %variance: tensor<256x!quant.uniform>) -> (tensor<4x256x!quant.uniform>) { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x256x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x256x!quant.uniform>'}} %0 = "stablehlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) { epsilon = 1.001000e-05 : f32, feature_index = 1 : i64 @@ -701,7 +701,7 @@ func.func @negative_batch_norm_inference_quantization(%input: tensor<4x256x!quan // ----- func.func @negative_batch_norm_training_quantization(%input: tensor<2x2x2x2x!quant.uniform>, %scale: tensor<2x!quant.uniform>, %offset: tensor<2x!quant.uniform>) -> tensor<2x2x2x2x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x2x2x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x2x2x2x!quant.uniform>'}} %0:3 = "stablehlo.batch_norm_training" (%input, %scale, %offset) { epsilon = 0.001 : f32, feature_index = 1 : i64 @@ -712,7 +712,7 @@ func.func @negative_batch_norm_training_quantization(%input: tensor<2x2x2x2x!qua // ----- func.func @negative_dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x3x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x3x4x!quant.uniform>'}} %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0], @@ -727,7 +727,7 @@ func.func @negative_dot_general_quantization(%arg0: tensor<2x3x4x!quant.uniform< // ----- func.func @negative_dynamic_update_slice_pertensor_quantization(%operand: tensor<3x4x!quant.uniform>, %update: tensor<1x4x!quant.uniform>, %start_indices0: tensor, %start_indices1: tensor) -> tensor<3x4x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} %0 = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1) : (tensor<3x4x!quant.uniform>, tensor<1x4x!quant.uniform>, tensor, tensor) -> tensor<3x4x!quant.uniform> func.return %0 : tensor<3x4x!quant.uniform> } @@ -735,7 +735,7 @@ func.func @negative_dynamic_update_slice_pertensor_quantization(%operand: tensor // ----- func.func @negative_gather_quantization(%operand : tensor<*x!quant.uniform>, %start_indices : tensor<1x5x2xi32>) -> tensor<8x?x7x1x6x1x?x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<*x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<*x!quant.uniform>'}} %res = "stablehlo.gather"(%operand, %start_indices) { dimension_numbers = #stablehlo.gather< offset_dims = [0, 2, 3, 4, 5], @@ -752,7 +752,7 @@ func.func @negative_gather_quantization(%operand : tensor<*x!quant.uniform>) -> tensor<6x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<6x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<6x!quant.uniform>'}} %output = "stablehlo.reduce_precision"(%arg0) { exponent_bits = 5 : i32, mantissa_bits = 10 : i32 @@ -763,7 +763,7 @@ func.func @negative_reduce_precision_quantization(%arg0: tensor<6x!quant.uniform // ----- func.func @negative_reduce_scatter_quantization(%data: tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x16x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<4x16x!quant.uniform>'}} %0 = "stablehlo.reduce_scatter"(%data) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %1 = stablehlo.add %arg2, %arg3 : tensor> @@ -778,7 +778,7 @@ func.func @negative_reduce_scatter_quantization(%data: tensor<4x16x!quant.unifor // ----- func.func @negative_reduce_window_quantization(%arg0: tensor<2x17x31x7x!quant.uniform>, %arg1: tensor>) -> tensor<2x9x16x7x!quant.uniform> { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x17x31x7x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x17x31x7x!quant.uniform>'}} %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ ^bb0(%arg2: tensor>, %arg3: tensor>): %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor>, tensor>) -> tensor> @@ -796,7 +796,7 @@ func.func @negative_reduce_window_quantization(%arg0: tensor<2x17x31x7x!quant.un // ----- func.func @negative_reverse_quantization(%operand: tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x2x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x2x!quant.uniform>'}} %result = "stablehlo.reverse"(%operand) { dimensions = array } : (tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> @@ -806,7 +806,7 @@ func.func @negative_reverse_quantization(%operand: tensor<3x2x!quant.uniform>) -> tensor<2x!quant.uniform> { - // expected-error@+1 {{ operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x!quant.uniform>'}} + // expected-error@+1 {{ operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x!quant.uniform>'}} %0 = "stablehlo.round_nearest_afz"(%arg0) {} : (tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> func.return %0 : tensor<2x!quant.uniform> } @@ -814,7 +814,7 @@ func.func @negative_round_afz(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { - // expected-error@+1 {{ operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x!quant.uniform>'}} + // expected-error@+1 {{ operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x!quant.uniform>'}} %0 = "stablehlo.round_nearest_even"(%arg0) {} : (tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> func.return %0 : tensor<2x!quant.uniform> } @@ -822,7 +822,7 @@ func.func @negative_round_even(%arg0: tensor<2x!quant.uniform>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300x!quant.uniform>) -> tensor<200x100x300x!quant.uniform> { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<200x100x300x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<200x100x300x!quant.uniform>'}} %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor>, %arg4: tensor>): %1 = "stablehlo.add"(%arg3, %arg4) : (tensor>, tensor>) -> tensor> @@ -841,7 +841,7 @@ func.func @negative_scatter_quantization(%arg0: tensor<200x100x300x!quant.unifor // ----- func.func @negative_select_quantization(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3x!quant.uniform>, %arg2: tensor<2x3x!quant.uniform>) -> tensor<2x3x!quant.uniform> { - // expected-error@+1 {{operand #1 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x3x!quant.uniform>'}} + // expected-error@+1 {{operand #1 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<2x3x!quant.uniform>'}} %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3x!quant.uniform>, tensor<2x3x!quant.uniform>) -> tensor<2x3x!quant.uniform> func.return %0 : tensor<2x3x!quant.uniform> } @@ -849,7 +849,7 @@ func.func @negative_select_quantization(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3 // ----- func.func @negative_slice_quantization(%arg0: tensor<3x4x!quant.uniform>) -> tensor<1x2x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<3x4x!quant.uniform>'}} %0 = "stablehlo.slice"(%arg0) {start_indices = array, limit_indices = array, strides = array} : (tensor<3x4x!quant.uniform>) -> tensor<1x2x!quant.uniform> func.return %0 : tensor<1x2x!quant.uniform> } @@ -858,7 +858,7 @@ func.func @negative_slice_quantization(%arg0: tensor<3x4x!quant.uniform>, %input1: tensor<16x16x!quant.uniform>) { - // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x16x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be variadic of ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x16x!quant.uniform>'}} %0:2 = "stablehlo.sort"(%input0, %input1) ({ ^bb0(%arg0: tensor>, %arg1: tensor>, %arg2: tensor>, %arg3: tensor>): %7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo} : (tensor>, tensor>) -> tensor @@ -870,7 +870,7 @@ func.func @negative_sort_quantization(%input0: tensor<16x16x!quant.uniform>, %arg1: tensor<10x23x23x64x!quant.uniform>, %arg2: tensor>) -> tensor<10x24x24x64x!quant.uniform> { - // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<10x24x24x64x!quant.uniform>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 2/4/8/16/32/64-bit signless integer or 2/4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<10x24x24x64x!quant.uniform>'}} %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor>, %arg4: tensor>): %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor>, tensor>) -> tensor diff --git a/stablehlo/tests/ops_stablehlo_roundtrip.mlir b/stablehlo/tests/ops_stablehlo_roundtrip.mlir index b6da68f953..ab11c086f1 100644 --- a/stablehlo/tests/ops_stablehlo_roundtrip.mlir +++ b/stablehlo/tests/ops_stablehlo_roundtrip.mlir @@ -183,7 +183,9 @@ func.func @test_constants() { %cst_4 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %cst_5 = arith.constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32> %cst_6 = arith.constant dense<[[1, 2], [4, 8]]> : tensor<2x2xui32> + %cst_17 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E3M4> %cst_7 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E4M3B11FNUZ> + %cst_16 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E4M3> %cst_8 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E4M3FN> %cst_9 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E4M3FNUZ> %cst_10 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf8E5M2> diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_6_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_6_0.mlir new file mode 100644 index 0000000000..43b2d749fc --- /dev/null +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_6_0.mlir @@ -0,0 +1,2851 @@ +// RUN: stablehlo-opt --mlir-print-op-generic %s.bc | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-translate --serialize --target=1.6.0 | stablehlo-opt --mlir-print-op-generic | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-opt > %t.0 +// RUN: stablehlo-opt --strip-debuginfo %s > %t.1 +// RUN: diff %t.0 %t.1 +// RUN: stablehlo-translate --serialize --target=1.6.0 --strip-debuginfo %s > %t.2 +// RUN: diff %s.bc %t.2 +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode -debug-only=vhlo-bytecode %s 2>&1 | FileCheck --check-prefix=CHECK-WARN %s +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode %s | stablehlo-opt -debug-only=vhlo-bytecode 2>&1 | FileCheck --check-prefix=CHECK-WARN %s + +// CHECK-WARN-NOT: Not Implemented + +// ============ ATTRIBUTES ============ + +// CHECK-LABEL: "attr_comparison_direction_eq" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_eq(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ne" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_ne(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ge" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_ge(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_gt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_gt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_le" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_le(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_lt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_lt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_notype" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_notype(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + // CHECK: compare_type = #vhlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_float" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_float(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_totalorder" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_totalorder(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_signed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_signed(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_unsigned" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ConvDimensionNumbers aka #stablehlo.conv is covered below. + +// CHECK-LABEL: "attr_custom_call_api_version_unspecified" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_unspecified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 0 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_original" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_original(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 1 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_status_returning(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 2 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning_unified" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 3 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_dict" +// CHECK: #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = #vhlo.integer_v1<1 : i32>, #vhlo.string_v1<"attr2"> = #vhlo.integer_v1<2 : i32>} +func.func @attr_dict() attributes {stablehlo.attr = {attr1 = 1 : i32, attr2 = 2 : i32}} { + return +} + +// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +// CHECK: api_version = #vhlo +// CHECK-SAME: backend_config = #vhlo.dict_v1<{#vhlo.string_v1<"bar"> = #vhlo.integer_v1<42 : i32>}> +func.func @attr_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + backend_config= {bar = 42 : i32}, + api_version = 4 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + + +// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi_no_backend_config" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +// CHECK: api_version = #vhlo +// CHECK-SAME: backend_config = #vhlo.dict_v1<{}> +func.func @attr_custom_call_api_version_typed_ffi_no_backend_config(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + api_version = 4 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// DotDimensionNumbers aka #stablehlo.dot is covered below. + +// CHECK-LABEL: "attr_fft_type_fft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_ifft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_rfft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xf32>) -> tensor<9xcomplex> + func.return %0 : tensor<9xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_irfft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<9xcomplex>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// GatherDimensionNumbers aka #stablehlo.gather is covered below. + +// CHECK-LABEL: "attr_precision_config_default" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_high" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_highest" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_rng_algorithm_default" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_three_fry" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_rng_algorithm_three_fry(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_philox" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_distribution_uniform" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor<0xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_rng_distribution_normal" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor<0xindex>) -> tensor + func.return %0 : tensor +} + +// ScatterDimensionNumbers aka #stablehlo.scatter is covered below. + +// CHECK-LABEL: "attr_transpose_no_transpose" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_transpose_no_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_transpose" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_transpose_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_adjoint" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. + +// CHECK-LABEL: "attr_type_extensions_bounds" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_type_extensions_bounds(%arg0: tensor>) -> tensor> { + // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> () + func.return %arg0 : tensor> +} + + +// ============ DEFAULTS ============ + +// CHECK-LABEL: "default_all_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "default_all_gather_variadic" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_gather_variadic(%arg0: tensor<16x8xf32>, %arg1: tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) { + %0:2 = "stablehlo.all_gather"(%arg0, %arg1) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16x8xf32>, tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) + func.return %0#0, %0#1 : tensor<16x16xf32>, tensor<16x16xf32> +} + +// CHECK-LABEL: "default_all_reduce" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) + // CHECK-SAME: <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_all_to_all" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "default_all_to_all_variadic" +func.func @default_all_to_all_variadic(%arg0: tensor<4x16xf32>, %arg1: tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) { + %0:2 = "stablehlo.all_to_all"(%arg0, %arg1) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<4x16xf32>, tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) + func.return %0#0, %0#1 : tensor<16x4xf32>, tensor<20x4xf32> +} + +// CHECK-LABEL: "default_cholesky" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "default_collective_permute" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_collective_broadcast" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_collective_broadcast(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_broadcast_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_compare" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_composite" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_composite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ + // CHECK-SAME: composite_attributes = #vhlo.dict_v1<{}> + // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> + // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> + // CHECK-SAME: version = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.composite"(%arg0) { + name = "stablehlo.composite_target", + decomposition = @composite_target + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_convolution" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { + // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x6x6x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> + func.return %0 : tensor<1x6x6x16xf32> +} + +// CHECK-LABEL: "default_custom_call" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dot_general" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "default_dot" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "default_dynamic_broadcast_in_dim" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = array + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dynamic_conv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "default_dynamic_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + > + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +func.func @default_func(%arg0: tensor) -> tensor { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"default_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<""> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + func.return %arg0 : tensor +} + +// CHECK-LABEL: "default_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = array + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "default_infeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_outfeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<""> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_recv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_send" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_reduce_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "default_reduce_window" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x16x30x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + func.return %0 : tensor<2x16x30x7xf32> +} + +// CHECK-LABEL: "default_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + > + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK-LABEL: "default_select_and_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<10x23x23x64x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array + } : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "default_sort" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<-1 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// ============ OPS ============ + +// CHECK-LABEL: "op_abs" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_abs(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_add" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_add(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_after_all" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.after_all_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.after_all"(%arg0) : (!stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_all_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_all_reduce" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_all_reduce_with_promotable_types" +func.func @op_all_reduce_with_promotable_types(%operand: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v2"(%[[ARG0:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + + func.return %result : tensor +} + +// CHECK-LABEL: "default_all_reduce_variadic" +func.func @default_all_reduce_variadic(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.all_reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> (tensor) + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor, tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "op_all_to_all" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "op_and" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_atan2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_atan2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.atan2_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_batch_norm_grad" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) +func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_grad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_batch_norm_inference" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) +func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> { + // CHECK: "vhlo.batch_norm_inference_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1> + %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> + func.return %0 : tensor<16x16x16x16xf32> +} + +// CHECK-LABEL: "op_batch_norm_training" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_training_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_bitcast_convert" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_bitcast_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.bitcast_convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.bitcast_convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_broadcast_in_dim" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_in_dim_v1"(%[[ARG0]]) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_broadcast" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_v1"(%[[ARG0]]) <{ + // CHECK-SAME: broadcast_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast"(%arg0) { + broadcast_sizes = array + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_case" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_case(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cbrt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cbrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cbrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cbrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_ceil" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_ceil(%arg0: tensor) -> tensor { + // CHECK: "vhlo.ceil_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cholesky" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) { + lower = true + } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "op_clamp" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.clamp_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_count_leading_zeros" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { + // CHECK: "vhlo.count_leading_zeros_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.count_leading_zeros"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_collective_permute" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "op_compare" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_complex" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> { + // CHECK: "vhlo.complex_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.complex"(%arg0, %arg1) : (tensor, tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_composite" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_composite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ + // CHECK-SAME: composite_attributes = #vhlo.dict_v1<{#vhlo.string_v1<"my_int"> = #vhlo.integer_v1<1 : i64>, #vhlo.string_v1<"my_string"> = #vhlo.string_v1<"foo">}> + // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> + // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> + // CHECK-SAME: version = #vhlo.integer_v1<1 : i32> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.composite"(%arg0) { + name = "stablehlo.composite_target", + decomposition = @composite_target, + version = 1 : i32, + composite_attributes = { + my_string = "foo", + my_int = 1 : i64 + } + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_concatenate" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.concatenate_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1<8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.concatenate"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_constant" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_constant(%arg0: tensor) -> tensor { + // CHECK: "vhlo.constant_v1"() <{ + // CHECK-SAME: value = #vhlo.tensor_v1 : tensor> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1 + %0 = "stablehlo.constant"() { + value = dense<0.0> : tensor + } : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convert" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convolution" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> { + // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x7x7x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = array, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> + func.return %0 : tensor<1x7x7x16xf32> +} + +// CHECK-LABEL: "op_cosine" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cosine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cosine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cosine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_create_token" +func.func @op_create_token() -> !stablehlo.token { + // CHECK: "vhlo.create_token_v1"() : () -> !vhlo.token_v1 + %0 = "stablehlo.create_token"() : () -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_cross_replica_sum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cross-replica-sum_v1"(%[[ARG0]]) <{ + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cross-replica-sum"(%arg0) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_custom_call" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"\08\03\1A\02">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[#vhlo.string_v1<"foo">]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[ + // CHECK-SAME: #vhlo.output_operand_alias_v1< + // CHECK-SAME: outputTupleIndices = [], + // CHECK-SAME: operandIndex = 0, + // CHECK-SAME: operandTupleIndices = []>]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + has_side_effect = true, + backend_config = "\08\03\1A\02", + api_version = 2 : i32, + called_computations = [@foo], + operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [ + #stablehlo.output_operand_alias], + result_layouts = [dense<> : tensor<0xindex>] + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_custom_call_empty_result_layout" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func public @op_custom_call_empty_result_layout(%arg0: tensor) -> tensor { + // %0 = "vhlo.custom_call_v1"(%arg0) <{>}> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> + // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"empty_output">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]>, + // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> + %0 = "stablehlo.custom_call"(%arg0) <{ + api_version = 2 : i32, + call_target_name = "empty_output", + has_side_effect = true, + operand_layouts = [dense<> : tensor<0xindex>], + result_layouts = [] + }> : (tensor) -> tuple<> + return %arg0 : tensor +} + +// CHECK-LABEL: "op_divide" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.divide_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dot_general" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "op_dot" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot"(%arg0, %arg1) { + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_dynamic_broadcast_in_dim" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_conv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + window_strides = array, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "op_dynamic_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +// CHECK-LABEL: "op_dynamic_gather_with_batching_dims" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<4xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [1, 2], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [1, 2], + index_vector_dim = 2 + >, + indices_are_sorted = true + } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>, tensor<4xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +// CHECK-LABEL: "op_dynamic_iota" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_iota_v1"(%[[ARG0]]) <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_iota"(%arg0) { + iota_dimension = 0 : i64 + } : (tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_pad" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) +func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_pad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_reshape" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { + // CHECK: "vhlo.dynamic_slice_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) { + slice_sizes = array + } : (tensor<16xf32>, tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_dynamic_update_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.dynamic_update_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_einsum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.einsum_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab,bc->ac"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.einsum"(%arg0, %arg1) { + einsum_config = "ab,bc->ac" + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_exponential_minus_one" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_exponential_minus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_minus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_exponential" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_exponential(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_fft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + // CHECK: "vhlo.fft_v1"(%[[ARG0]]) <{ + // CHECK-SAME: fft_length = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: fft_type = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.complex_v1>) -> !vhlo.tensor_v1<16x!vhlo.complex_v1> + %0 = "stablehlo.fft"(%arg0) { + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "op_floor" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_floor(%arg0: tensor) -> tensor { + // CHECK: "vhlo.floor_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +func.func private @op_func(%arg0: tensor {stablehlo.arg = "0"}) -> (tensor {stablehlo.result = "0"}) { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.arg"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.result"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"op_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<"private"> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + + func.return %arg0 : tensor +} + +// CHECK-LABEL: "op_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "op_gather_with_batching_dims" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [1, 2], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [1, 2], + index_vector_dim = 2 + >, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "op_get_dimension_size" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_get_dimension_size(%arg0: tensor) -> tensor { + // CHECK: "vhlo.get_dimension_size_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_dimension_size"(%arg0) { + dimension = 0 : i64 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_get_tuple_element" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_get_tuple_element(%arg0: tuple, tensor>) -> tensor { + // CHECK: "vhlo.get_tuple_element_v1"(%[[ARG0]]) <{ + // CHECK-SAME: index = #vhlo.integer_v1<0 : i32> + // CHECK-SAME: }> : (!vhlo.tuple_v1, !vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_tuple_element"(%arg0) { + index = 0 : i32 + } : (tuple, tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_if" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.if_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG2]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.if"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + "stablehlo.return"(%arg2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_imag" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_imag(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.imag_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.imag"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_infeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"foo">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[#vhlo.array_v1<[]>]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) { + infeed_config = "foo", + layout = [[]] + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_iota" +func.func @op_iota() -> tensor<16xf32> { + // CHECK: "vhlo.iota_v1"() <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.iota"() { + iota_dimension = 0 : i64 + } : () -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_is_finite" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_is_finite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.is_finite_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.is_finite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_log(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log_plus_one" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_log_plus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_plus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log_plus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_logistic" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_logistic(%arg0: tensor) -> tensor { + // CHECK: "vhlo.logistic_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.logistic"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_map" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.map_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.abs_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.map"(%arg0) ({ + ^bb0(%arg1: tensor): + %1 = "stablehlo.abs"(%arg1) : (tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = array + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_maximum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_maximum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.maximum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_minimum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_minimum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.minimum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_multiply" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_multiply(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.multiply_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_negate" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_negate(%arg0: tensor) -> tensor { + // CHECK: "vhlo.negate_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.negate"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_not" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_not(%arg0: tensor) -> tensor { + // CHECK: "vhlo.not_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.not"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_optimization_barrier" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_optimization_barrier(%arg0: tensor) -> tensor { + // CHECK: "vhlo.optimization_barrier_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.optimization_barrier"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_or" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_or(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.or_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.or"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_outfeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<"foo"> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) { + outfeed_config = "foo" + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_pad" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.pad_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: edge_padding_high = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: edge_padding_low = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: interior_padding = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.pad"(%arg0, %arg1) { + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array + } : (tensor<8xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_popcnt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_popcnt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.popcnt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.popcnt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_power" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_power(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.power_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.power"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real_dynamic_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}) +func.func @op_real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.real_dynamic_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_real(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.real_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_recv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_reduce" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { + // CHECK: "vhlo.reduce_v1"(%[[ARG0]], %[[ARG1]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = array + } : (tensor<16xf32>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reduce_precision" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reduce_precision(%arg0: tensor) -> tensor { + // CHECK: "vhlo.reduce_precision_v1"(%[[ARG0]]) <{ + // CHECK-SAME: exponent_bits = #vhlo.integer_v1<8 : i32> + // CHECK-SAME: mantissa_bits = #vhlo.integer_v1<10 : i32> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.reduce_precision"(%arg0) { + exponent_bits = 8 : i32, + mantissa_bits = 10 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK_lABEL: "op_reduce_with_promotable_types" +func.func @op_reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf64>) { + // CHECK: "vhlo.reduce_v1"(%[[ARG0:.*]], %[[ARG1:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f64_v1> + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = array} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> + + func.return %0: tensor<4xf64> +} + +// CHECK-LABEL: "op_reduce_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK_lABEL: "op_reduce_scatter_with_promotable_types" +func.func @op_reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { + // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f64_v1> + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> + func.return %0 : tensor<4x4xf64> +} + + +// CHECK-LABEL: "op_reduce_window" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x9x16x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x9x16x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + base_dilations = array, + window_dilations = array, + padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x9x16x7xf32> + func.return %0 : tensor<2x9x16x7xf32> +} + +// CHECK-LABEL: "op_reduce_window_with_promotable_types" +func.func @op_reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, + %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) { + // CHECK: "vhlo.reduce_window_v1"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]], %[[VAL2:.*]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1, !vhlo.tensor_v1) -> (!vhlo.tensor_v1<2x2x!vhlo.f64_v1>, !vhlo.tensor_v1<2x2x!vhlo.f32_v1>) + %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, + %b1: tensor): + %2 = stablehlo.add %a0, %b0 : tensor + %3 = stablehlo.add %a1, %b1 : tensor + "stablehlo.return"(%2,%3) : (tensor, tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = array, + window_strides = array } + : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) + func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> +} + +// CHECK-LABEL: "op_remainder" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_remainder(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.remainder_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.remainder"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_replica_id" +func.func @op_replica_id() -> tensor { + // CHECK: "vhlo.replica_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.replica_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_partition_id" +func.func @op_partition_id() -> tensor { + // CHECK: "vhlo.partition_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.partition_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reshape" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { + // CHECK: "vhlo.reshape_v1"(%[[ARG0]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f32_v1> + %0 = "stablehlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// CHECK-LABEL: "op_return" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reverse" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reverse_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reverse"(%arg0) { + dimensions = array + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_rng_bit_generator" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor) { + // CHECK: "vhlo.rng_bit_generator_v1"(%[[ARG0]]) <{ + // CHECK-SAME: rng_algorithm = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> (!vhlo.tensor_v1, !vhlo.tensor_v1) + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "op_rng" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + // CHECK: "vhlo.rng_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: rng_distribution = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + rng_distribution = #stablehlo + } : (tensor, tensor, tensor<0xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_afz" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_round_nearest_afz(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_afz_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_afz"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_even" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_round_nearest_even(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_even_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_even"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_rsqrt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_rsqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.rsqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.rsqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK-LABEL: "op_scatter_with_batching_dims" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_scatter_with_batching_dims(%arg0: tensor<10x200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<10x200x100x300xf32> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [1, 2], + input_batching_dims = [0], + scatter_dims_to_operand_dims = [1, 2], + scatter_indices_batching_dims = [0], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<10x200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<10x200x100x300xf32> + func.return %0 : tensor<10x200x100x300xf32> +} + +// CHECK_lABEL: "op_scatter_with_promotable_types" +func.func @op_scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xf64> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f64_v1> + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + func.return %0 : tensor<200x100x300xf64> +} + +// CHECK-LABEL: "op_select_and_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "op_select_and_scatter_with_promotable_types" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_select_and_scatter_with_promotable_types(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf64> { + // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: %[[VAL:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK: "vhlo.return_v1"(%[[VAL]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f64_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf64> + func.return %0 : tensor<10x24x24x64xf64> +} + +// CHECK-LABEL: "op_select" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.select_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_send" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_set_dimension_size" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_set_dimension_size(%arg0: tensor, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.set_dimension_size_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_shift_left" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_shift_left(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_left_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_arithmetic" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_shift_right_arithmetic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_arithmetic_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_logical" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_shift_right_logical(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_logical_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sign" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sign(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sign_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sign"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sine" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { + // CHECK: "vhlo.slice_v1"(%[[ARG0]]) <{ + // CHECK-SAME: limit_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: strides = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<16xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_sort" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimension = 0 : i64, + is_stable = true + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_sqrt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_subtract" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_subtract(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.subtract_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_tan" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_tan(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tan_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.tan"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_tanh" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_tanh(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tanh_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_torch_index_select" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { + // CHECK: "vhlo.torch_index_select_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: batch_dims = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: dim = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x1x5x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<2x1x5x!vhlo.f32_v1> + %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> + func.return %0 : tensor<2x1x5xf32> +} + +// CHECK-LABEL: "op_transpose" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { + // CHECK: "vhlo.transpose_v1"(%[[ARG0]]) <{ + // CHECK-SAME: permutation = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x16x!vhlo.f32_v1> + %0 = "stablehlo.transpose"(%arg0) { + permutation = array + } : (tensor<16x8xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: "op_triangular_solve" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.triangular_solve_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: left_side = #vhlo.bool_v1, + // CHECK-SAME: lower = #vhlo.bool_v1, + // CHECK-SAME: transpose_a = #vhlo, + // CHECK-SAME: unit_diagonal = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_tuple" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_tuple(%arg0: tensor) -> tuple> { + // CHECK: "vhlo.tuple_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tuple_v1> + %0 = "stablehlo.tuple"(%arg0) : (tensor) -> tuple> + func.return %0 : tuple> +} + +// CHECK-LABEL: "op_unary_einsum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { + // CHECK: "vhlo.unary_einsum_v1"(%[[ARG0]]) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab->a"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x!vhlo.f32_v1> + %0 = "stablehlo.unary_einsum"(%arg0) { + einsum_config = "ab->a" + } : (tensor<8x16xf32>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> +} + +// CHECK-LABEL: "op_uniform_dequantize" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.uniform_dequantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_uniform_quantize" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_uniform_quantize(%arg0: tensor) -> tensor> { + // CHECK: "vhlo.uniform_quantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.uniform_quantize"(%arg0) : (tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_while" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_while(%arg0: tensor) -> tensor { + // CHECK: "vhlo.while_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1) + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.while"(%arg0) ({ + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0: tensor +} + +// CHECK-LABEL: "op_xor" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.xor_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ============ TYPES ============ + +// CHECK-LABEL: "type_i1" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i1(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i8" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui8" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E3M4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E5M2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FNUZ" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3B11FNUZ" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3B11FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2FNUZ" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E5M2FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_bf16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_complex_f32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_complex_f32(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_complex_f64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_complex_f64(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_dynamism_ranked" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_per_tensor_quantization" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_per_tensor_quantization(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_per_axis_quantization" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_per_axis_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG0]]) : (!vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>, !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>) -> !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1> + %0 = stablehlo.add %arg0, %arg0 : tensor<2x!quant.uniform> + func.return %0 : tensor<2x!quant.uniform> +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_callee" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> () + return %arg0 : !stablehlo.token +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_caller" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.call_v1"(%[[ARG0]]) <{callee = #vhlo.string_v1<"type_token_callee">} + // CHECK-SAME: (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = func.call @type_token_callee(%arg0) : (!stablehlo.token) -> !stablehlo.token + return %0 : !stablehlo.token +} + +// CHECK-LABEL: "type_tuple" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_tuple(%arg0: tuple>) -> tuple { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + // CHECK: (!vhlo.tuple_v1>) -> !vhlo.tuple_v1 + } : (tuple>) -> tuple + return %0 : tuple +} + +// ============ DEPENDENCIES ============ + +func.func @composite_target(%arg0: tensor) -> tensor { + return %arg0: tensor +} diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_6_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_6_0.mlir.bc new file mode 100644 index 0000000000000000000000000000000000000000..523a4159924b2f2b8fdbf892c45488cbcf21a78e GIT binary patch literal 18992 zcmcJ10gO~dy6!nur@HQ;)AT!Zn$FNE`ZXuRq~Q&72QF}d02v7F11_*Ygn)rg&rA?wWWg|xRs*#OsbhB=jjhJ<_Y?jTkS$yAL z=bY}-GvL1WZr(umU-f_W|Mk~jfBjW;=D@0d`@_1}`r%@Cf9Zez+nU18&KYxN{Ig+R zV!{-^H+htQ7f+*v7B9VR<(l;y?z?}_c<7PGo_Ok+=U#Z} zmDk>U``w>@@UxG9{`oI{`K#aj?z=zy=`Vl#A;QM9iEJ{P%4V=RY!S<|)ocUX%KBJ^ z?PPn|es+)@W=Gj^c9NZ9XV^J*fn8)@v&-yzc8%R&5k8(z{sL`B|d!tc!QkSDqevZczG9H)l zR7%duB;qm^PrHdkVq;q(m2%xwS|-w|G#GJwntxJ>G=NMdm74CxQ&Cw+r_+UmOyO@T zlW{X>odVeArW5Hnw4jSrI-N-aB|chbGWbH^YhJpetqt4^0VLQKw1pvQm)f;Ka|Sb=CDi zWgsC<+jM{FG6vp%3ry>(!?wu0P#5Fkg%`MZ?KSr+_t);<-5=a3%)LN@mM(a04l#9@sUyt&P90_H7*ofY zI>FRQrru!gAJr+QPBV3esk2O-W9mFp7nu5tsf$cqV(M$AzG3PzQ&*T9Vd{IPt}=H# zQ`eZf&eRR27*`Rl#&I>CyOS_h?oP&Vxthe)WUi)gcPd8A-Dwy!cV}SST+QTa4p;NJ zTEx{7uJT;1;O-)fm#cMLZQyD%ck>u7SKZuQfe~}HgR2U6S7Xpz?d0w{jGL>yTC>JV3lxjMqtQLc`0cQeMy-K`idce^oS?)G8O+}**{DehJeH{6BSFLQS%9Kzi_a0pjdxw^*Hb?)wm zL%4eYKCx89QsXQ&-cl1RHPKR&EH&9uQ!F*rQqwFo!%{OX7ZWqbQu8gf$Z`)ewZu|+ zORccnBXEVK)>-aRrZ!k=v*lvawp#9S%)jNHfI}>`!%`JX)hzcEqQFvnEVb8C`z*EJ zQU@$`&{Bsib=Y!GW8y9M3{yuf7xR0}a?imhmO5dnla_lP-muguOP#jV8B3kD)H%z= z)StK11xtNqsf(7nWT~$$_aef;axY=*mbzlOUt{2wi@>;Qxt9?Pmis;8!BRIY#e|9o z_bOsvxYrT>!bOBk5N?F4i9$^h?sx=-a3>%(ggc3=sluJY)imKw<7$R*XCgd=I|mUW z+(ify;pP!1!bNN>5$*;=i*UChV1(Pp)oS5Z5HrHviLepwUPO*iTZOwHkt1A$M4xaE zB6x&*81W-q1V&A`M-f6o?Gf&Agpg4CL@4D3Ng=bWD_H*l4WNKRO8Dgmwdgfi8Bp$>@5U8XvSbc%iGU?gdpW@cGE4`xi8 zY>A1Z5$9MWqYepmSR@i0O#&OUGAq)&g}0boy!93~ZxP2@xEQ$Mj3!AY1okpxeyMc1}yX+hbx zyLOXNxh-4B`LEO$@Tf5Lq*ftn8jP@r1h@Q5^Ihn*@VYY&IQR?92 z^%ossnp`^ETah#O#;E(8~^A%9j!dWbQoviX)zUia+g$_7UFh$hq^D@n;-)oBJ%^<%ru{zT4qZ zVMbv{lt-y=bbFu4Uv*^1f`z=pk*%+pyj9x?`*a(V$EZ$L2`OQIGzGDg#J5*O#i5$! z;OqiHPPW=KObmnw*tXW1bwp<;^$KHOcW@hXc$%jjdB6K6OvEZnI?<%r(n1BnX%GSp zmOGL!Pm|dFdJIWu<8P6A=qL4Ae3`)ETKJc&6rEJJlYuH~CLS>P+vL4>^r+u;><1sj z$lm)YWOJ=#_5+T52#$Ej^T+$}$HU|gO+Ik2GRX#;27|`S9r5AM_($Y@Or6Oe0p*xbNSQJskcB{}}5$krTqfmrV;hmgvB;ieZ5XVJEB=JB0ZY)0EcK z9o)PU5kJB7SkpNUx29>>dE5k~z%*INtiWUwhC^@&i#B1gCM@2BVfsUIi6$)Bgr%CW zmL@FSgf-i#ny^fhT(${wo3Pd#>XBe3k zMrMbRxnX2p7+KIj)|z~wv|`9V+%TbgakyPqquo!;rD0@Q7`e?yZZ~3ib2;)`-O`m| zWK|eh<0FD()*=(uW!C%1UkSM#3$BiA^u4t(YsB;BCO#|Vtvi~qElpUV2`e^XJst)x zC-Y_r0jj$bIx@F~v2rtZXBZo3#;Vfl;C2*A>eq(KQ%aWzb(%d14-Yk?BhBb8-&Umz zCvWcN@z5CWX~OPp!tQIr?r*{#kX8#4HDAFDGJ>Z0LEp()V|w0vh{qdFe%Lb)IYL9& zcsvV)AmbmA*6m(UrZ$+5LJF%q=dWxTA>IDM4r28&zoW8|%bSn$p+-kfNUIZG66GU| zuqUytg?oC+7ktoIo;RQ7OB;gE)NQ{lZ$7J2tiLdxt0Uj#&F90&3%+}9mVjRLrA%X4 z-h7G16d3=qFEwzNu{3YK!e0qS2T#0Om%{RFT7<8W2a)X7vCYir%A2qAr$cgY`Z{kG z#Bcf7rf-eKdGl@F9+G><51WCz9yS){&3E}T;h=iYbEbTBS&bQOdGn`yR~UX@TC`Oo z%7Xy-!0+nOWkze>{E+Vsb@ek3<4I#d-u#FT4dfATO?>RjnMq?_-u#5l74=O1^Sa#J zy!mOf+~*CsPJm5vU-&^K40nV13tI92AgI2S)+X}s-%TU4!TcqE6-G2fvZLdWcA2lF zwT0H?mXX~?VuSfB-a-Cp2#pC|iuPYi>tPaFewvv^dV~2J-bzwYYkgxg`)wm>-)bF@ zs^3X#BP|?ZbZjuc^~M<562A*+@q_vI(rWdVS_=)R5tV*QT#0|@ORva0)dtmgS;3EA|ep@!E-%r{**L8o2F~D z=e<2Tffr|PW`x<4@|`E4Y08NUuR`Abepq_v-zlhH=c zpkw@sOli#E6da}jd`xR#zs#F6ny{I=R|GyTva=>?m4bF|iWvH$K%W!z zY8qcQ_d36U{Gu6I6!fW+?GhHwws~Of-fm_ZUyxilv8)JmI>zU~a197%eRZHSYamUp zW8rMP0az&bo6k#|wV-!c{1iH~p()%NDAl>Tn}wa)2dtfzV6kZYJa6t`2>eEFssQzu z(dTW*%N5=^8a4OUk_>A7s#2$N6m8kf#BL=|BTTZL^rhK9spC% zxoLcuHxD+MeyD+AGdbLZ9Rb!!u`>I=!@GcwGJe`W6p`LzTGn&d2V_Uc#N%2PSg;kH zU`>0`N#J@Oyv?ZjWb_6=?N4tvI>nl{qtjY7a-+W=ong%z(pliMC~Sfq=^Sg`lFp;S zY-+&Rvh#Y+x-hE6XK0~4)@i)2=QG)uw2N9AYcywec6J&+#lVxzEA*0<#@NNT#(RLn zd+FD}^w7j?IXcHaD+H616f0`~_WXRKA@q8)5W|KyYDEn3Qram61K9<1WwB=7|h9~L` z6k(P#UenGFE#*nTrc(%h#*A0u7fo-Cs1b^FG3y8H^HKw-#}hy?r2d1dA=E0 z(m*j$pyA3t`WYY$)_Lgs`VaVVi;Z zT|O1+5*oJ#NX&Q=NXQM{0b&|YG|Th_Jz*1loP;R?%$vfm#utkQ@T7VW!sdx9Pv<58#$MbX}%N5sh^WK2kEAFxgyjoLX{%?|@j;yIF#XQgq}5Sza|q}iv?mDDn5}1ydTq>&4BJ7S(GGf8H$b?X#)EnDNK>>Q1-6A0 zj>C!2)A)8n#!xX1nY2?At{5p>5t5r~xOP<)c}cmCly>kWwL!FtEy zQ=GrsFbz&`3dCTrry+Dd4m1NtpE4R&$yc?`+$u2v76MbEE^AH`ant>&Jg*QG|@bh-4is5I6>~}%$pOf zrrb3tKy!|9*O)WnWS}~R9ml9bu^q}#Q$pClC~Rs7yK@va4VWK0WwNyq(K8yzwq|6e z)`tWnj8fj5WAWt;srkS*d0P=Kl8Y>FD+(Q^ziXuOAoiCudg^KJDX)7XZ<$7spwNg{ zG^7g6QmY%tmS$ufkWN~q6Zj-<+>tjoSR4-!>X(zvx<8C8Yi!D!TP-{}uW#Ai!1Oq7 z_UNmx(d9;T`A9=-N1$?{aeLmZSiG}=)dHzi#`An!wJ__3(U{5}&)9cERrw#|68BZQp_uy@|k*p?7>I>6q0 zPh*7;b|%2ye_vz85Oy}eKKxK)Jt6E|fPM6l#_;&K9_i-;?2}J4wk?ER2(V8-)mS-% zeHLI}e4(*BL)gUt`|?YTVHd0ScPYTIlAHDc~z~Q*_Ex?dGSm#ep_aG zHknDa&|7)!9SL;dq;qW@oo$J@nzvjjygitnn{Dq%r?Yq$uVebsmRv5^nq4s0o!8bj zuWeq(ta%F-&R#Hg;oOC}wzk%xHOW|PV9ToY3+Hvsowq>$%xx8}%`=_pL|bP^TP}@J zbfhvZsbqF~GMj00JCk@9v3+_BJ;$>x@l;DP-WiK06YWket+IIG5JPHLvpZXHszo9(KA%%KgLTim`Khf7KYM z_Kx(Ij9O`Uq*lSZxkleerH5#-x0gsybzrbqE9v6M&~SC2&{OR1*A4rIiS$-A8LpL! zm2D&a#adyg+P_l^4D}R;hfB4>&Y4E7vaQfv9xiMvRZ6wua1~5+0Y$Y^???|?&oqkN zL&O{^^_O~v3&l!rp)TL4)m_wjpm=wAU}S)bJ{A(%Rvg~$b>F+YQXDAv z6dD}rd$?NAV?p1jg)sC@mQ-tcXhS8h-cuaX!wO8ME_-%T%j(E*U#X;p%9WrL^x0Fb zpwUbtsCU<@#a`$dE|e>U-mgX&Ou*qeoFKRKuj7Y~+d z10z0nd$qqbbmwleCYYeRJumOtUhdbDBb9PrwKh<=You5iF7IjL^_H5s#oF#dsazTI zTsu56=&2a46)QuqMayCqd$wZ`y`{Sg{;Z+*;bLVwSplk@IMNLvOek1Gcli$OFVmER zQLgxOo7Z@I89`MmZ`(dx=&x=oVKqLJnFBFGL!{pFXWr!R% zS_Tm-Wo?&M>J6+`0uAC~DDcW)b(bePSRL%CXwMB6bvNFGRtJa6 z1LZwM-G6tnRx5kWE8a@1ln}|<%7GEA^TB={6a(c-uuukygETU7@R09ElnaCXBSVF% zHz8CMOrI9s|P<0*&S<+n`#+t0qGRLyR+@U9uZm=mq)k+m_ z0fzY7YF>;LF_4)gQ|zPMtrH~KUe7uWOV9|z#y{)8Mlu7bzGfJG)vDeyymU2OtRWPP ze>R49W8HW@r0i5Gg#;Vf&2LVc3f{g`9>PNIEo1HIU@zTe+#0Nuegxe1qu}t5NMdum zb)WrXCG}9T6fxp4f*QjzDBePB7)z|bT0`czx#&=NTcy-{b8G}zFNB8n>f&IzKuL^} zB3cb15%rbpYHrc6iY#2N6#M#w-cnx?IoMM&2HlQ*?nfYOO&&^Tr(*v$%m{YAW}VH@@b+4%ROqYizM0%0jsexX z8%Qsbf34i@9Wwl*!7WfqCqX1MP5^Ldz2XVhtG5PKrG zhYa6qI`w#mvYW69J8=>~Chcj$yv??WC` z>N+EaIQ0`=gVzZ4IxloE4D%Xg|6p6U!9P@vX+V3#80IMN($T-(W)$X~YwM_g;;q-| zBwVl1fw^93K9-MWdnfri*FU^B>SHbg4G5=mqdHi?wLz2KUQL2cOIPZr`k>~=2Z=TA z9=ouv3WK%k0J1MOGo22aqtG8iJ>-0OKeVWZ0TZFUlM18u)lsba7Ced{?AV&!ytixC zsEb1rk1BLQ(G&*=O>v;mG>B_zL#GYR9D5q^*!7J*2ziq7Zh>9{eV_@7URdiU^>dGh zc}F6@OedwF=-+I-DxIcuDRj>A>Nt#HWAV<_IE-m#XoXOhjcTDZFgUy$?T3(W3;o62 z2vB2iq#I{y67&vvy41L7`BhDBwTwvC$Mgyc)%^uBhEve-q$7Y%UpNQr01ed8xij1) zj-}zMe{c=6^x-zl2%mz9@12u%DRi=~AEieDRHbc<2%X47M=>1NLN(8sx*a8k!CHAI z4z*No4Bac;V#s5!0a|$Uje(fuKK+fWg{JxvtxLYQ>NgQ$;&$Sb`lZFM>kAE)amVq9 zemtOy-UUb({d*CWz1xy5RkcrUcDo`Nj-qexrcXtG;=D_mX3)Kj3S+Nz0oS8QmAVx@ z`Dm1og8VCDh#&FF+m~0Q z2TNW_-=uXB>!Rstkk8bl>28i&_*f=g*fp2#^Hd17jYjbn5%~|DGxq6|)<0$&^7N0{ zG{^wm)|2vfg|Uw%G+KXj(WrS(J{pfWMk91($NfpOaoG2$mp|)Cf}UDwAn=X<0ET$j zo_w-h(=iWlG&gw0<5ij;>uDZ64WhzGrSX`^t9uWa8V{4g+~5Hd@$@q(D&QHFcR4KM zi4k2A%QJoYFFyU(0H5O*W(6(g_)SIvnUQjx<4cW{JAJW{av(QS4g~xDb=j;{z6i?y zvTBysg;&n6uGeO-^7{;s+c(!SNG<8?%&!iy(7aLa?h8z46arKF0!xCY1bqQgX9O+0 z?PF-5fZ4edmpAP!?9c@|J>#*8?_K>+r~%^X=+I4|&`sQTzkY;+N3iXSiunE?-90Gu z7~ZCeoHvNw^q=@7W(@4c0dtrL-5zk^!}BNOpJryq;XMvBDwTZ47PC63@UAUqU5ilZ zJ+RhO;io8+sPI0ps~Y;C+bSXe;?MTl8r0k8H6}7e2P-m545VV#~`^iprnc ziEkoyMkXZvBfqHp)OL!Y>gTqzEmZx&PJG>H{VVxv^c0o9u@i597HR1Ct?j)1W>DQI ze`kw%&qw7(`K>MIU4Vt($={=N9;H9XKcaLFrP=aNHeX8g&$c+L8Go^PS5z*Of3?LK zFqX-`+2XV={N3iyM&%~?gDp-)Vkj_X^Jil6HW^{I9gja1r}8*v<9~2?Bcdzgb>(CP z|3kq9hSv-1n7m6)Mr-^61OHh8URIq-yr`VUYzbCO4#^piiN$UoaN=?%lXFmWs1}nW zay~jln>*wpW{cR%R9Zr%S7Pvd-qW-K+TSd^6pPE%Os->gtK-D7>DFwr9ppr0Y5VK# z)Zqqdy-RL}`j(EREp7}x*h;+JvKzJ}Q>mD|NA@v08WK z-Y09!b~N`nVlS6FVH<`eSIRvw98H$Xy);e-V^}Hok5x1}WR*Nmq)lGH_#|GEyoeFp30LBc$FGU? z4U?Ct#RKvRkq71X7#Mu_puCF0yqJ7QULznXuY2SMkrB!8widdJa5;|r@UR?DjCoZ=%ccXGK0^)%`& zY0MR^c??;U*r^mm+%h>b;k(7xWvmx-+-ZeUQAu%rJ^g;6SJ;zqL+QqU6q;5<0 z5i2TpShn34*#UbimaGwPuG|TpGd+fOdn~z^Dz`hR?6c&4sLgGS+!qq ztHfF(uX*IUM{W?otIk4`h=9UzLXJn`SxYfGS5Ck`q>SklFp=7{$w^+*$sUtho zM$NF)X9zhH#nte~ZE_BQ^MzbQbcv98AS>Xy6+*5i);b|KkebKkW+JQPR*!TO*eB!; zq7@-)M0X0ghv;4*_ffMoaz9lM2zij`At4VFJt8E2lOC7Hggj34gpenRz9Hl(qNjyC zL-eeW=ZKyc@&eJ%guF=fl8|2${YJ>kM6U?>J<+Q|UL$&4$QwjiL`EXoA>$&bj*rL* zL?=e%B%+feathI@5jl, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E3M4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor {