Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][DLTI] Enable types as keys in DLTI-query utils #105995

Merged
merged 1 commit into from
Aug 27, 2024

Conversation

rolfmorel
Copy link
Contributor

@rolfmorel rolfmorel commented Aug 25, 2024

Enable support for query functions - including transform.dlti.query - to take types as keys. As the data layout specific attributes already supported types as keys, this change enables querying such attributes in the expected way.

@llvmbot
Copy link
Collaborator

llvmbot commented Aug 25, 2024

@llvm/pr-subscribers-mlir-dlti

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Enable support for query functions - include transform.dlti.query - to take types as keys. As the data layout specific attributes already supported types as keys, this change enables querying such attributes in the expected way.


Full diff: https://github.com/llvm/llvm-project/pull/105995.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/DLTI/DLTI.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td (+5-4)
  • (modified) mlir/lib/Dialect/DLTI/DLTI.cpp (+22-4)
  • (modified) mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp (+10-1)
  • (modified) mlir/test/Dialect/DLTI/invalid.mlir (+8)
  • (modified) mlir/test/Dialect/DLTI/query.mlir (+88)
  • (modified) mlir/test/Dialect/DLTI/valid.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index a97eb523cb0631..f268fea340a6fb 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -26,7 +26,7 @@ namespace mlir {
 namespace dlti {
 /// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
 /// query interface-implementing attrs, starting from attr obtained from `op`.
-FailureOr<Attribute> query(Operation *op, ArrayRef<StringAttr> keys,
+FailureOr<Attribute> query(Operation *op, ArrayRef<DataLayoutEntryKey> keys,
                            bool emitError = false);
 } // namespace dlti
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
index 1b1bebfaab4e38..f25bb383912d45 100644
--- a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
+++ b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
@@ -26,9 +26,10 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
 
     A lookup is performed for the given `keys` at `target` op - or its closest
     interface-implementing ancestor - by way of the `DLTIQueryInterface`, which
-    returns an attribute for a key. If more than one key is provided, the lookup
-    continues recursively, now on the returned attributes, with the condition
-    that these implement the above interface. For example if the payload IR is
+    returns an attribute for a key. Each key should be either a (quoted) string
+    or a type. If more than one key is provided, the lookup continues
+    recursively, now on the returned attributes, with the condition that these
+    implement the above interface. For example if the payload IR is
 
     ```
     module attributes {#dlti.map = #dlti.map<#dlti.dl_entry<"A",
@@ -52,7 +53,7 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       StrArrayAttr:$keys);
+                       ArrayAttr:$keys);
   let results = (outs TransformParamTypeInterface:$associated_attr);
   let assemblyFormat =
       "$keys `at` $target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 7f8e11a1b73341..58f8799b0714d7 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -424,8 +424,8 @@ getClosestQueryable(Operation *op) {
   return std::pair(queryable, op);
 }
 
-FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
-                                 bool emitError) {
+FailureOr<Attribute>
+dlti::query(Operation *op, ArrayRef<DataLayoutEntryKey> keys, bool emitError) {
   auto [queryable, queryOp] = getClosestQueryable(op);
   Operation *reportOp = (queryOp ? queryOp : op);
 
@@ -438,6 +438,17 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
     return failure();
   }
 
+  auto keyToStr = [](DataLayoutEntryKey key) -> std::string {
+    if (auto strKey = llvm::dyn_cast<StringAttr>(key))
+      return "\"" + std::string(strKey.getValue()) + "\"";
+    if (auto typeKey = llvm::dyn_cast<Type>(key)) {
+      std::string buf;
+      llvm::raw_string_ostream(buf) << typeKey;
+      return buf;
+    }
+    llvm_unreachable("DataLayoutEntryKey was not `StringAttr` or `Type`");
+  };
+
   Attribute currentAttr = queryable;
   for (auto &&[idx, key] : llvm::enumerate(keys)) {
     if (auto map = llvm::dyn_cast<DLTIQueryInterface>(currentAttr)) {
@@ -446,17 +457,24 @@ FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
         if (emitError) {
           auto diag = op->emitError() << "target op of failed DLTI query";
           diag.attachNote(reportOp->getLoc())
-              << "key " << key << " has no DLTI-mapping per attr: " << map;
+              << "key " << keyToStr(key)
+              << " has no DLTI-mapping per attr: " << map;
         }
         return failure();
       }
       currentAttr = *maybeAttr;
     } else {
       if (emitError) {
+        std::string commaSeparatedKeys;
+        llvm::interleave(
+            keys.take_front(idx), // All prior keys.
+            [&](auto key) { commaSeparatedKeys += keyToStr(key); },
+            [&]() { commaSeparatedKeys += ","; });
+
         auto diag = op->emitError() << "target op of failed DLTI query";
         diag.attachNote(reportOp->getLoc())
             << "got non-DLTI-queryable attribute upon looking up keys ["
-            << keys.take_front(idx) << "] at op";
+            << commaSeparatedKeys << "] at op";
       }
       return failure();
     }
diff --git a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
index 90aef82bddff00..2f171a8375b46d 100644
--- a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
+++ b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
@@ -33,7 +33,16 @@ void transform::QueryOp::getEffects(
 DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
     transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results, TransformState &state) {
-  auto keys = SmallVector<StringAttr>(getKeys().getAsRange<StringAttr>());
+  auto keys = SmallVector<DataLayoutEntryKey>();
+  for (Attribute key : getKeys()) {
+    if (auto strKey = dyn_cast<StringAttr>(key))
+      keys.push_back(strKey);
+    else if (auto typeKey = dyn_cast<TypeAttr>(key))
+      keys.push_back(typeKey.getValue());
+    else
+      return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: "
+                                 "only StringAttr and TypeAttr are allowed");
+  }
 
   FailureOr<Attribute> result = dlti::query(target, keys, /*emitError=*/true);
 
diff --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir
index 05f919fa256713..4b04f0195ef823 100644
--- a/mlir/test/Dialect/DLTI/invalid.mlir
+++ b/mlir/test/Dialect/DLTI/invalid.mlir
@@ -33,6 +33,14 @@
 
 // -----
 
+// expected-error@below {{repeated layout entry key: 'i32'}}
+"test.unknown_op"() { test.unknown_attr = #dlti.map<
+  #dlti.dl_entry<i32, 42>,
+  #dlti.dl_entry<i32, 42>
+>} : () -> ()
+
+// -----
+
 // expected-error@below {{repeated layout entry key: 'i32'}}
 "test.unknown_op"() { test.unknown_attr = #dlti.dl_spec<
   #dlti.dl_entry<i32, 42>,
diff --git a/mlir/test/Dialect/DLTI/query.mlir b/mlir/test/Dialect/DLTI/query.mlir
index 10e91afd2ca7e1..e449c2c44bc617 100644
--- a/mlir/test/Dialect/DLTI/query.mlir
+++ b/mlir/test/Dialect/DLTI/query.mlir
@@ -17,6 +17,60 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-remark @below {{associated attr 42 : i32}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, 42 : i32>>} {
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+    %param = transform.dlti.query [i32] at %module : (!transform.any_op) -> !transform.any_param
+    transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-remark @below {{associated attr 32 : i32}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<i32, #dlti.map<#dlti.dl_entry<"width_in_bits", 32 : i32>>>>} {
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+    %param = transform.dlti.query [i32,"width_in_bits"] at %module : (!transform.any_op) -> !transform.any_param
+    transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-remark @below {{width in bits of i32 = 32 : i64}}
+// expected-remark @below {{width in bits of f64 = 64 : i64}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>, #dlti.dl_entry<f64, 64>>>>} {
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+    %i32bits = transform.dlti.query ["width_in_bits",i32] at %module : (!transform.any_op) -> !transform.any_param
+    %f64bits  = transform.dlti.query ["width_in_bits",f64] at %module : (!transform.any_op) -> !transform.any_param
+    transform.debug.emit_param_as_remark %i32bits, "width in bits of i32 =" at %module : !transform.any_param, !transform.any_op
+    transform.debug.emit_param_as_remark %f64bits, "width in bits of f64 =" at %module : !transform.any_param, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // expected-remark @below {{associated attr 42 : i32}}
 module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
   func.func private @f()
@@ -336,6 +390,23 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-note @below {{got non-DLTI-queryable attribute upon looking up keys [i32]}}
+module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<i32, 32 : i32>>} {
+  // expected-error @below {{target op of failed DLTI query}}
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{'transform.dlti.query' op failed to apply}}
+    %param = transform.dlti.query [i32,"width_in_bits"] at %func : (!transform.any_op) -> !transform.any_param
+    transform.yield
+  }
+}
+
+// -----
+
 module {
   // expected-error @below {{target op of failed DLTI query}}
   // expected-note @below {{no DLTI-queryable attrs on target op or any of its ancestors}}
@@ -353,6 +424,23 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// expected-note @below {{key i64 has no DLTI-mapping per attr: #dlti.map<#dlti.dl_entry<i32, 32 : i64>>}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"width_in_bits", #dlti.map<#dlti.dl_entry<i32, 32>>>>} {
+  // expected-error @below {{target op of failed DLTI query}}
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{'transform.dlti.query' op failed to apply}}
+    %param = transform.dlti.query ["width_in_bits",i64] at %func : (!transform.any_op) -> !transform.any_param
+    transform.yield
+  }
+}
+
+// -----
+
 module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
   func.func private @f()
 }
diff --git a/mlir/test/Dialect/DLTI/valid.mlir b/mlir/test/Dialect/DLTI/valid.mlir
index 4133eac5424ce8..023caf6ac5a05f 100644
--- a/mlir/test/Dialect/DLTI/valid.mlir
+++ b/mlir/test/Dialect/DLTI/valid.mlir
@@ -206,3 +206,18 @@ module attributes {
     "GPU": #dlti.target_device_spec<
       #dlti.dl_entry<"L1_cache_size_in_bytes", "128">>
   >} {}
+
+
+// -----
+
+// CHECK: "test.op_with_dlti_map"() ({
+// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42 : i64>>}
+"test.op_with_dlti_map"() ({
+}) { dlti.map = #dlti.map<#dlti.dl_entry<"dlti.unknown_id", 42>> } : () -> ()
+
+// -----
+
+// CHECK: "test.op_with_dlti_map"() ({
+// CHECK: }) {dlti.map = #dlti.map<#dlti.dl_entry<i32, 42 : i64>>}
+"test.op_with_dlti_map"() ({
+}) { dlti.map = #dlti.map<#dlti.dl_entry<i32, 42>> } : () -> ()
\ No newline at end of file

@rolfmorel
Copy link
Contributor Author

Hi @rengolin, @ftynse, @banach-space, @joker-eph, @adam-smnk & @Dinistro, as I think you are interested in making progress on making use of DLTI's target descriptors, I thought to ping you on this.

If any of you could help with review, that would be appreciated - thanks!

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the only use-case a list of types for some parent property?

Is this so much better than just having a string "f32" and then parsing it into a type?

Honest questions, I can't follow what you're trying to reach, here.

mlir/lib/Dialect/DLTI/DLTI.cpp Outdated Show resolved Hide resolved
mlir/test/Dialect/DLTI/query.mlir Outdated Show resolved Hide resolved
@rolfmorel
Copy link
Contributor Author

rolfmorel commented Aug 25, 2024

To me there are two main arguments for wanting this:

  1. Data layout entry keys are either StringAttrs or Types, for over three years. Given that, why not allow the query mechanism to work with any key it could encounter. There are a number of upstream test cases (from before I touched DLTI) where a type is a key of a #dlti.dl_entry (e.g. do a search for #dlti.dl_entry<i or #dlti.dl_entry<f).

  2. The DataLayout class of DLTI has functions like Attribute getEndianness() and uint64_t getStackAlignment() that get mirrored on DataLayoutSpecAttr as lookups on lists of #dlti.dl_entrys with the corresponding keys, in this case "dlti.endianness" and "dlti.stack_alignment". Missing from DataLayoutSpecAttr are lookups for the DataLayout functions like llvm::TypeSize getTypeSize(Type t) and uint64_t getTypePreferredAlignment(Type t). These could be implemented as nested lookups on DLTI attributes, with the first key the type and the second key the property's name (e.g. "dlti.type_size" and "dlti.preferred_alignment") or vice versa.
    In general, I would like for properties of types to be easily encodable in the DLTI attributes, e.g. as #dlti.map<#dlti.dl_entry<#my_dialect.special_vec<16xf32>, #dlti.map<#dlti.dl_entry<"strided", 4: i64>>>>, to facilitate developing cost models that might want to depend on particular properties of particular types, especially while they are under development.

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fine, +1 for completeness

mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/DLTI/DLTI.cpp Show resolved Hide resolved
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am concerned about the amount of string manipulation involved here.

StringAttr and Type are uniqued pointers in the context and we should be able to take advantage of this here for simple pointer-based keys here.

@rolfmorel
Copy link
Contributor Author

rolfmorel commented Aug 26, 2024

I think there's a misunderstanding here, @joker-eph : only the erroring-out path has any string manipulation, and only when requested, i.e. all string manipulation is guarded by emitError==true checks.

The key comparison itself is delegated to the query() method of attributes that implement DLTIQueryInterface, e.g. see the following from #104595:

if (entry.getKey() == key)
As this is == of PointerUnion<StringAttr,Type>, this does make use of (void *) pointer comparisons for equality checks. Note as well that currently all DLTI attributes have essentially this same query() implementation.

If you are concerned about the amount of string manipulation on the error-reporting path, I am happy to discuss that. If there's consensus in favour of less string manipulation even if that means less informative error messages, this is something I could work with on.

Enable support for query functions - including transform.dlti.query - to
take types as keys. As the data layout specific attributes already
supported types as keys, this change enables querying such attributes
in the expected way.
@rolfmorel
Copy link
Contributor Author

Thanks for the review, @rengolin, @adam-smnk and @joker-eph !

I just now did a rebase, a squash and a check-all. I think this PR is good to go.

If somebody could help with merging, that would be appreciated!

@joker-eph joker-eph merged commit 063e0bd into llvm:main Aug 27, 2024
8 checks passed
5c4lar pushed a commit to 5c4lar/llvm-project that referenced this pull request Aug 29, 2024
Enable support for query functions - including transform.dlti.query - to
take types as keys. As the data layout specific attributes already
supported types as keys, this change enables querying such attributes in
the expected way.
dmpolukhin pushed a commit to dmpolukhin/llvm-project that referenced this pull request Sep 2, 2024
Enable support for query functions - including transform.dlti.query - to
take types as keys. As the data layout specific attributes already
supported types as keys, this change enables querying such attributes in
the expected way.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants