Skip to content

Commit

Permalink
Fine-tune and add format comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Che-Yu Wu committed Mar 16, 2023
1 parent 761ea5e commit 21a0481
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class _ArchitectureInfo(object):
microarchitecture: str

def __str__(self):
return f"{self.type.name}-{self.architecture}-{self.microarchitecture}"
return f"{self.architecture}-{self.microarchitecture}"


class DeviceArchitecture(_ArchitectureInfo, Enum):
Expand Down Expand Up @@ -110,12 +110,12 @@ class DeviceSpec(object):
# Unique name of the device spec.
name: str

# Tags to describe the device.
tags: List[str]

# Device name. E.g., Pixel-6.
device_name: str

# Tags to describe the device spec.
tags: List[str]

# Host environment where the IREE runtime is running. For CPU device type,
# this is usually the same as the device that workloads are dispatched to.
# With a separate device, such as a GPU, however, the runtime and dispatched
Expand All @@ -136,13 +136,14 @@ def __str__(self):

@staticmethod
def build(id: str,
tags: Sequence[str],
device_name: str,
tags: Sequence[str],
host_environment: HostEnvironment,
architecture: DeviceArchitecture,
device_parameters: Optional[Sequence[str]] = None):
name = "{device_name}[{tags}]".format(device_name=device_name,
tags=",".join(tags))
tag_part = tags = ",".join(tags)
# Format: <device_name>[<tag>,...]
name = f"{device_name}[{tag_part}]"
device_parameters = [] if device_parameters is None else list(
device_parameters)
return DeviceSpec(id=id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ def build(id: str,
tags: Sequence[str],
compile_targets: Sequence[CompileTarget],
extra_flags: Optional[Sequence[str]] = None):
name = ",".join(str(target) for target in compile_targets)
name += "[" + ",".join(tags) + "]"
target_part = ",".join(str(target) for target in compile_targets)
tag_part = ",".join(tags)
# Format: [<target_name>,...][<tag>,...]
name = f"[{target_part}][{tag_part}]"
extra_flags = [] if extra_flags is None else list(extra_flags)
return CompileConfig(id=id,
name=name,
Expand Down Expand Up @@ -113,8 +115,10 @@ def build(id: str,
loader: RuntimeLoader,
driver: RuntimeDriver,
extra_flags: Optional[Sequence[str]] = None):
name = f"{driver.name}({loader.name})".lower()
name += "[" + ",".join(tags) + "]"
runtime_part = f"{driver.name}({loader.name})".lower()
tag_part = ",".join(tags)
# Format: <driver>(<loader>)[<tag>,...]
name = f"{runtime_part}[{tag_part}]"
extra_flags = [] if extra_flags is None else list(extra_flags)
return ModuleExecutionConfig(id=id,
name=name,
Expand Down Expand Up @@ -231,6 +235,7 @@ def from_model(model: common_definitions.Model):
raise ValueError(f"Unsupported model source type: {model.source_type}.")

composite_id = unique_ids.hash_composite_id([model.id, config.id])
# Format: <model_name>(<import_config_name>)
name = f"{model}({config})"
return ImportedModel(composite_id=composite_id,
name=name,
Expand Down Expand Up @@ -264,7 +269,8 @@ def materialize_compile_flags(self):
def build(imported_model: ImportedModel, compile_config: CompileConfig):
composite_id = unique_ids.hash_composite_id(
[imported_model.composite_id, compile_config.id])
name = f"{imported_model}_{compile_config}"
# Format: <imported_model_name> <compile_config_name>
name = f"{imported_model} {compile_config}"
return ModuleGenerationConfig(
composite_id=composite_id,
name=name,
Expand Down Expand Up @@ -321,7 +327,8 @@ def build(module_generation_config: ModuleGenerationConfig,
module_generation_config.composite_id, module_execution_config.id,
target_device_spec.id, input_data.id
])
name = f"{module_generation_config}_{module_execution_config}_{input_data}_{target_device_spec}"
# Format: <module_generation_config_name> <module_execution_config_name> with <input_data_name> @ <target_device_spec_name>
name = f"{module_generation_config} {module_execution_config} with {input_data} @ {target_device_spec}"
run_flags = generate_run_flags(
imported_model=module_generation_config.imported_model,
input_data=input_data,
Expand Down

0 comments on commit 21a0481

Please sign in to comment.