Skip to content

Commit

Permalink
[MLIR][DLTI] Introduce DLTIQueryInterface and impl for DLTI attrs
Browse files Browse the repository at this point in the history
This new interface is supposed to capture the core functionality
of DLTI: querying for values at keys. As such this new interface
unifies the ability to query DLTI attributes in a single method:
query(). All existing DLTI interfaces exposing their own query methods
now 1) now extend this new interface and 2) provide a default
implementation for `query()`.

As DLTIQueryInterface::query() returns an attribute, it naturally
enables recursive queries on nested DLTI attrs. A utility function,
`dlti::query()`, implements the logic for nested lookups.

A new `#dlti.map` attribute is introduced to capture the most generic
form of a finite DLTI-mapping. One of the benefits is that it allows
for more easily encoding hierachical information that is suitably
queryable, i.e. by means of nested attributes.

In line with the above, `transform.dlti.query` is modified so as to
take an arbitrary number of keys and to perform a nested lookup
using the above utility function.
  • Loading branch information
rolfmorel committed Aug 20, 2024
1 parent 7452014 commit e2c6861
Show file tree
Hide file tree
Showing 12 changed files with 454 additions and 169 deletions.
13 changes: 4 additions & 9 deletions mlir/include/mlir/Dialect/DLTI/DLTI.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,10 @@ class DataLayoutEntryAttrStorage;
} // namespace mlir
namespace mlir {
namespace dlti {
/// Find the first DataLayoutSpec associated to `op`, via either the
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
/// the interface, on `op` and else on `op`'s ancestors in turn.
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);

/// Find the first TargetSystemSpec associated to `op`, via either the
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
/// the interface, on `op` and else on `op`'s ancestors in turn.
TargetSystemSpecInterface getTargetSystemSpec(Operation *op);
/// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
/// query interface-implementing attrs, starting from attr obtained from `op`.
FailureOr<Attribute> query(Operation *op, ArrayRef<StringAttr> keys,
bool emitError = false);
} // namespace dlti
} // namespace mlir

Expand Down
105 changes: 81 additions & 24 deletions mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_DLTI_DLTIATTRS_TD

include "mlir/Dialect/DLTI/DLTI.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"

class DLTIAttr<string name, list<Trait> traits = [],
Expand All @@ -20,13 +21,8 @@ class DLTIAttr<string name, list<Trait> traits = [],
// DataLayoutEntryAttr
//===----------------------------------------------------------------------===//

def DataLayoutEntryTrait
: NativeAttrTrait<"DataLayoutEntryInterface::Trait"> {
let cppNamespace = "::mlir";
}

def DLTI_DataLayoutEntryAttr :
DLTIAttr<"DataLayoutEntry", [DataLayoutEntryTrait]> {
DLTIAttr<"DataLayoutEntry", [DataLayoutEntryInterface]> {
let summary = "An attribute to represent an entry of a data layout specification.";
let description = [{
A data layout entry attribute is a key-value pair where the key is a type or
Expand All @@ -53,13 +49,9 @@ def DLTI_DataLayoutEntryAttr :
//===----------------------------------------------------------------------===//
// DataLayoutSpecAttr
//===----------------------------------------------------------------------===//
def DataLayoutSpecTrait
: NativeAttrTrait<"DataLayoutSpecInterface::Trait"> {
let cppNamespace = "::mlir";
}

def DLTI_DataLayoutSpecAttr :
DLTIAttr<"DataLayoutSpec", [DataLayoutSpecTrait]> {
DLTIAttr<"DataLayoutSpec", [DataLayoutSpecInterface]> {
let summary = "An attribute to represent a data layout specification.";
let description = [{
A data layout specification is a list of entries that specify (partial) data
Expand All @@ -78,7 +70,7 @@ def DLTI_DataLayoutSpecAttr :
/// same key as the newer entries if the entries are compatible. Returns null
/// if the specifications are not compatible.
DataLayoutSpecAttr combineWith(ArrayRef<DataLayoutSpecInterface> specs) const;

/// Returns the endiannes identifier.
StringAttr getEndiannessIdentifier(MLIRContext *context) const;

Expand All @@ -93,20 +85,78 @@ def DLTI_DataLayoutSpecAttr :

/// Returns the stack alignment identifier.
StringAttr getStackAlignmentIdentifier(MLIRContext *context) const;

/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) {
return llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
}
}];
}

def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
let summary = "A mapping of DLTI-information by way of key-value pairs";
let description = [{
A Data Layout and Target Information map is a list of entries effectively
encoding a dictionary, mapping DLTI-related keys to DLTI-related values.

This attribute's main purpose is to facilate querying IR for arbitrary
key-value associations that encode DLTI. Facility functions exist to perform
recursive lookups on nested DLTI-map/query interface-implementing
attributes.

Consider the following flat encoding of a single-key dictionary
```
#dlti.map<#dlti.dl_entry<"CPU::cache::L1::size_in_bytes", 65536 : i32>>
```
versus nested maps, which make it possible to obtain sub-dictionaries of
related information (with the following example making use of other
attributes that also implement the `DLTIQueryInterface`):
```
#dlti.target_system_spec<"CPU":
#dlti.target_device_spec<#dlti.dl_entry<"cache",
#dlti.map<#dlti.dl_entry<"L1",
#dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>>,
#dlti.dl_entry<"L1d",
#dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>> >>>>
```

With the flat encoding, the implied structure of the key is ignored, that is
the only successful query (as expressed in the Transform Dialect) is:
`transform.dlti.query ["CPU::cache::L1::size_in_bytes"] at %op`,
where `%op` is a handle to an operation which associates the flat-encoding
`#dlti.map` attribute.

For querying nested dictionaries, the relevant keys need to be separately
provided. That is, if `%op` is an handle to an op which has the nesting
`#dlti.target_system_spec`-attribute from above attached, then
`transform.dlti.query ["CPU","cache","L1","size_in_bytes"] at %op` gives
back the first leaf value contained. To access the other leaf, we need to do
`transform.dlti.query ["CPU","cache","L1d","size_in_bytes"] at %op`.
```
}];
let parameters = (ins
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
);
let mnemonic = "map";
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";
let extraClassDeclaration = [{
/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) {
for (DataLayoutEntryInterface entry : getEntries())
if (entry.getKey() == key)
return entry.getValue();
return ::mlir::failure();
}
}];
}

//===----------------------------------------------------------------------===//
// TargetSystemSpecAttr
//===----------------------------------------------------------------------===//

def TargetSystemSpecTrait
: NativeAttrTrait<"TargetSystemSpecInterface::Trait"> {
let cppNamespace = "::mlir";
}

def DLTI_TargetSystemSpecAttr :
DLTIAttr<"TargetSystemSpec", [TargetSystemSpecTrait]> {
DLTIAttr<"TargetSystemSpec", [TargetSystemSpecInterface]> {
let summary = "An attribute to represent target system specification.";
let description = [{
A system specification describes the overall system containing
Expand Down Expand Up @@ -136,6 +186,11 @@ def DLTI_TargetSystemSpecAttr :
std::optional<TargetDeviceSpecInterface>
getDeviceSpecForDeviceID(
TargetSystemSpecInterface::DeviceID deviceID);

/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) const {
return llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
}
}];
let extraClassDefinition = [{
std::optional<TargetDeviceSpecInterface>
Expand All @@ -154,13 +209,8 @@ def DLTI_TargetSystemSpecAttr :
// TargetDeviceSpecAttr
//===----------------------------------------------------------------------===//

def TargetDeviceSpecTrait
: NativeAttrTrait<"TargetDeviceSpecInterface::Trait"> {
let cppNamespace = "::mlir";
}

def DLTI_TargetDeviceSpecAttr :
DLTIAttr<"TargetDeviceSpec", [TargetDeviceSpecTrait]> {
DLTIAttr<"TargetDeviceSpec", [TargetDeviceSpecInterface]> {
let summary = "An attribute to represent target device specification.";
let description = [{
Each device specification describes a single device and its
Expand All @@ -179,6 +229,13 @@ def DLTI_TargetDeviceSpecAttr :
let mnemonic = "target_device_spec";
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";

let extraClassDeclaration = [{
/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) const {
return llvm::cast<mlir::TargetDeviceSpecInterface>(*this).queryHelper(key);
}
}];
}

#endif // MLIR_DIALECT_DLTI_DLTIATTRS_TD
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/DLTI/DLTIBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@ def DLTI_Dialect : Dialect {
}];

let extraClassDeclaration = [{
// Top level attribute name.
// Top-level attribute name for arbitrary description.
constexpr const static ::llvm::StringLiteral
kMapAttrName = "dlti.map";

// Top-level attribute name for data layout description.
constexpr const static ::llvm::StringLiteral
kDataLayoutAttrName = "dlti.dl_spec";

// Top level attribute name for target system description
// Top-level attribute name for target system description.
constexpr const static ::llvm::StringLiteral
kTargetSystemDescAttrName = "dlti.target_system_spec";

// Top-level attribute name for target device description.
constexpr const static ::llvm::StringLiteral
kTargetDeviceDescAttrName = "dlti.target_device_spec";

Expand Down
40 changes: 24 additions & 16 deletions mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,40 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
let summary = "Return attribute (as param) associated to key via DTLI";
let description = [{
This op queries data layout and target information associated to payload
IR by way of the DLTI dialect. A lookup is performed for the given `key`
at the `target` op, with the DLTI dialect determining which interfaces and
attributes are consulted - first checking `target` and then its ancestors.
IR by way of the DLTI dialect.

When only `key` is provided, the lookup occurs with respect to the data
layout specification of DLTI. When `device` is provided, the lookup occurs
with respect to DLTI's target device specifications associated to a DLTI
system device specification.
A lookup is performed for the given `keys` at `target` op - or its closest
interface-implementing ancestor - by way of the `DLTIQueryInterface`, which
returns an attribute for a key. If more than one key is provided, the lookup
continues recursively, now on the returned attributes, with the condition
that these implement the above interface. For example if the payload IR is

```
module attributes {#dlti.map = #dlti.map<#dlti.dl_entry<"A",
#dlti.map<#dlti.dl_entry<"B", 42: int>>>} {
func.func private @f()
}
```
and we have that `%func` is a Tranform handle to op `@f`, then
`transform.dlti.query ["A", "B"] at %func` returns 42 as a param and
`transform.dlti.query ["A"] at %func` returns the `#dlti.map` attribute
containing just the key "B" and its value. Using `["B"]` or `["A","C"]` as
`keys` will yield an error.

#### Return modes

When succesful, the result, `associated_attr`, associates one attribute as a
param for each op in `target`'s payload.
When successful, the result, `associated_attr`, associates one attribute as
a param for each op in `target`'s payload.

If the lookup fails - as DLTI specifications or entries with the right
names are missing (i.e. the values of `device` and `key`) - a definite
failure is returned.
If the lookup fails - as no DLTI attributes/interfaces are found or entries
with the right names are missing - a silenceable failure is returned.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
OptionalAttr<StrAttr>:$device,
StrAttr:$key);
StrArrayAttr:$keys);
let results = (outs TransformParamTypeInterface:$associated_attr);
let assemblyFormat =
"(`:``:` $device^ `:``:`)? $key `at` $target attr-dict `:`"
"functional-type(operands, results)";
"$keys `at` $target attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
namespace mlir {
class DataLayout;
class DataLayoutEntryInterface;
class DLTIQueryInterface;
class TargetDeviceSpecInterface;
class TargetSystemSpecInterface;
using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
Expand Down
60 changes: 56 additions & 4 deletions mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,29 @@ include "mlir/IR/OpBase.td"
// Attribute interfaces
//===----------------------------------------------------------------------===//

def DLTIQueryInterface : AttrInterface<"DLTIQueryInterface"> {
let cppNamespace = "::mlir";

let description = [{
Attribute interface exposing querying-mechanism for key-value associations.

The central feature of DLTI attributes is to allow looking up values at
keys. This interface represent the core functionality to do so - as such
most DLTI attributes should be implementing this interface.

Note that as the `query` method returns an attribute, this attribute can
be recursively queried when it also implements this interface.
}];
let methods = [
InterfaceMethod<
/*description=*/"Returns the attribute associated with the key.",
/*retTy=*/"::mlir::FailureOr<::mlir::Attribute>",
/*methodName=*/"query",
/*args=*/(ins "::mlir::DataLayoutEntryKey":$key)
>
];
}

def DataLayoutEntryInterface : AttrInterface<"DataLayoutEntryInterface"> {
let cppNamespace = "::mlir";

Expand Down Expand Up @@ -68,7 +91,7 @@ def DataLayoutEntryInterface : AttrInterface<"DataLayoutEntryInterface"> {
}];
}

def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface", [DLTIQueryInterface]> {
let cppNamespace = "::mlir";

let description = [{
Expand Down Expand Up @@ -173,7 +196,7 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
/*defaultImplementation=*/[{
return ::mlir::detail::verifyDataLayoutSpec($_attr, loc);
}]
>,
>
];

let extraClassDeclaration = [{
Expand All @@ -184,6 +207,15 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
return getSpecForType(TypeID::get<Ty>());
}

/// Helper for default implementation of `DLTIQueryInterface`'s `query`.
inline ::mlir::FailureOr<::mlir::Attribute>
queryHelper(::mlir::DataLayoutEntryKey key) const {
for (DataLayoutEntryInterface entry : getEntries())
if (entry.getKey() == key)
return entry.getValue();
return ::mlir::failure();
}

/// Populates the given maps with lists of entries grouped by the type or
/// identifier they are associated with. Users are not expected to call this
/// method directly.
Expand All @@ -194,7 +226,7 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
}];
}

def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface", [DLTIQueryInterface]> {
let cppNamespace = "::mlir";

let description = [{
Expand Down Expand Up @@ -239,9 +271,20 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
/*defaultImplementation=*/[{ return ::mlir::success(); }]
>
];

let extraClassDeclaration = [{
/// Helper for default implementation of `DLTIQueryInterface`'s `query`.
::mlir::FailureOr<::mlir::Attribute>
queryHelper(::mlir::DataLayoutEntryKey key) const {
if (auto strKey = llvm::dyn_cast<StringAttr>(key))
if (DataLayoutEntryInterface spec = getSpecForIdentifier(strKey))
return spec.getValue();
return ::mlir::failure();
}
}];
}

def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTIQueryInterface]> {
let cppNamespace = "::mlir";

let description = [{
Expand Down Expand Up @@ -287,6 +330,15 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {

let extraClassDeclaration = [{
using DeviceID = StringAttr;

/// Helper for default implementation of `DLTIQueryInterface`'s `query`.
::mlir::FailureOr<::mlir::Attribute>
queryHelper(::mlir::DataLayoutEntryKey key) const {
if (auto strKey = llvm::dyn_cast<::mlir::StringAttr>(key))
if (auto deviceSpec = getDeviceSpecForDeviceID(strKey))
return *deviceSpec;
return ::mlir::failure();
}
}];
}

Expand Down
Loading

0 comments on commit e2c6861

Please sign in to comment.