diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index d6f8a3e03..88d9ef93f 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -40,13 +40,13 @@ jobs:
chmod +x bazel
sudo mv bazel /usr/local/bin/bazel
sudo apt install clang-9 patchelf
- python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r leaderboard/requirements.txt -r tests/requirements.txt
+ python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
if: matrix.os == 'ubuntu-latest'
- name: Install dependencies (macOS)
run: |
brew install bazelisk zlib
- python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r leaderboard/requirements.txt -r tests/requirements.txt
+ python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
env:
LDFLAGS: -L/usr/local/opt/zlib/lib
CPPFLAGS: -I/usr/local/opt/zlib/include
@@ -84,13 +84,13 @@ jobs:
chmod +x bazel
sudo mv bazel /usr/local/bin/bazel
sudo apt install clang-9 patchelf
- python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r leaderboard/requirements.txt -r tests/requirements.txt
+ python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
if: matrix.os == 'ubuntu-latest'
- name: Install dependencies (macOS)
run: |
brew install bazelisk zlib
- python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r leaderboard/requirements.txt -r tests/requirements.txt
+ python -m pip install -r compiler_gym/requirements.txt -r examples/requirements.txt -r tests/requirements.txt
env:
LDFLAGS: -L/usr/local/opt/zlib/lib
CPPFLAGS: -I/usr/local/opt/zlib/include
@@ -110,6 +110,21 @@ jobs:
CC: clang
CXX: clang++
BAZEL_BUILD_OPTS: --config=ci
+ if: matrix.os == 'macos-latest'
+
+ - name: Test with coverage
+ run: make install-test-cov
+ env:
+ CC: clang
+ CXX: clang++
+ BAZEL_BUILD_OPTS: --config=ci
+ if: matrix.os == 'ubuntu-latest'
+
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v1
+ with:
+ files: ./coverage.xml
+ if: matrix.os == 'ubuntu-latest'
- name: Uninstall
run: make purge
diff --git a/.gitignore b/.gitignore
index 82f3e4cfa..3d24ca5d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,12 +1,13 @@
+__pycache__
.DS_Store
.env
-/*.egg-info
/.act
/.clwb
/.vscode
+/*.egg-info
/bazel-*
/build
+/coverage.xml
/dist
/node_modules
/package-lock.json
-__pycache__
diff --git a/BUILD.bazel b/BUILD.bazel
index e71d72af8..4558ba5cc 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -14,8 +14,8 @@ exports_files([
py_library(
name = "CompilerGym",
data = [
- "//compiler_gym/third_party/cBench:benchmarks_list",
- "//compiler_gym/third_party/cBench:crc32",
+ "//compiler_gym/third_party/cbench:benchmarks_list",
+ "//compiler_gym/third_party/cbench:crc32",
],
deps = [
"//compiler_gym",
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 13032c32b..73776f516 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,90 @@
+## Release 0.1.8 (2021-04-30)
+
+This release makes some significant changes to the way that benchmarks are
+managed by introducing a new dataset API. This has enabled us to add support
+for millions of new benchmarks and a more efficient implementation of the LLVM
+environment, but it will require migrating old code to the new interfaces (see
+"Migration Checklist" below). Some of the key changes of this release are:
+
+- **[Core API change]** We have added a Python
+ [Benchmark](https://facebookresearch.github.io/CompilerGym/compiler_gym/datasets.html#compiler_gym.datasets.Benchmark)
+ class ([#190](https://github.com/facebookresearch/CompilerGym/pull/190)). The
+ `env.benchmark` attribute is now an instance of this class rather than a
+ string ([#222](https://github.com/facebookresearch/CompilerGym/pull/222)).
+- **[Core behavior change]** Environments will no longer select benchmarks
+  randomly. `env.reset()` now always selects the last-used benchmark,
+ unless the `benchmark` argument is provided or `env.benchmark` has been set.
+ If no benchmark is specified, a default is used.
+- **[API deprecations]** We have added a new
+ [Dataset](https://facebookresearch.github.io/CompilerGym/compiler_gym/datasets.html#compiler_gym.datasets.Dataset)
+ class hierarchy
+ ([#191](https://github.com/facebookresearch/CompilerGym/pull/191),
+ [#192](https://github.com/facebookresearch/CompilerGym/pull/192)). All
+ datasets are now available without needing to be downloaded first, and a new
+ [Datasets](https://facebookresearch.github.io/CompilerGym/compiler_gym/datasets.html#compiler_gym.datasets.Datasets)
+ class can be used to iterate over them
+ ([#200](https://github.com/facebookresearch/CompilerGym/pull/200)). We have
+ deprecated the old dataset management operations, the
+ `compiler_gym.bin.datasets` script, and removed the `--dataset` and
+ `--ls_benchmark` flags from the command line tools.
+- **[RPC interface change]** The `StartSession` RPC endpoint now accepts a list
+ of initial observations to compute. This removes the need for an immediate
+ call to `Step`, reducing environment reset time by 15-21%
+ ([#189](https://github.com/facebookresearch/CompilerGym/pull/189)).
+- [LLVM] We have added several new datasets of benchmarks, including the Csmith
+ and llvm-stress program generators
+ ([#207](https://github.com/facebookresearch/CompilerGym/pull/207)), a dataset
+ of OpenCL kernels
+ ([#208](https://github.com/facebookresearch/CompilerGym/pull/208)), and a
+ dataset of compilable C functions
+ ([#210](https://github.com/facebookresearch/CompilerGym/pull/210)). See [the
+ docs](https://facebookresearch.github.io/CompilerGym/llvm/index.html#datasets)
+ for an overview.
+- `CompilerEnv` now takes an optional `Logger` instance at construction time for
+ fine-grained control over logging output
+ ([#187](https://github.com/facebookresearch/CompilerGym/pull/187)).
+- [LLVM] The ModuleID and source_filename of LLVM-IR modules are now anonymized
+ to prevent unintentional overfitting to benchmarks by name
+ ([#171](https://github.com/facebookresearch/CompilerGym/pull/171)).
+- [docs] We have added a [Feature
+ Stability](https://facebookresearch.github.io/CompilerGym/about.html#feature-stability)
+ section to the documentation
+ ([#196](https://github.com/facebookresearch/CompilerGym/pull/196)).
+- Numerous bug fixes and improvements.
+
+Please use this checklist when updating code written for the previous CompilerGym release:
+
+* Review code that accesses the `env.benchmark` property and update it to
+  `env.benchmark.uri` if a string name is required. Setting this attribute from
+  a string (`env.benchmark = "benchmark://a-v0/b"`) and comparing it against a
+  string (`env.benchmark == "benchmark://a-v0/b"`) still work.
+* Review code that calls `env.reset()` without first setting a benchmark.
+ Previously, calling `env.reset()` would select a random benchmark. Now,
+ `env.reset()` always selects the last used benchmark, or a predetermined
+ default if none is specified.
+* Review code that relies on `env.benchmark` being `None` to select benchmarks
+ randomly. Now, `env.benchmark` is always set to the previously used benchmark,
+ or a predetermined default benchmark if none has been specified. Setting
+  `env.benchmark = None` will raise an error. Select a benchmark randomly by
+  sampling from the `env.datasets.benchmark_uris()` iterator (see the sketch
+  after this checklist).
+* Remove calls to `env.require_dataset()` and related operations. These are no
+ longer required.
+* Remove accesses to `env.benchmarks`. An iterator over available benchmark URIs
+ is now available at `env.datasets.benchmark_uris()`, but the list of URIs
+ cannot be relied on to be fully enumerable (the LLVM environments have over
+ 2^32 URIs).
+* Review code that accesses `env.observation_space` and update to
+ `env.observation_space_spec` where necessary
+ ([#228](https://github.com/facebookresearch/CompilerGym/pull/228)).
+* Update compiler service implementations to support the updated RPC interface
+ by removing the deprecated `GetBenchmarks` RPC endpoint and replacing it with
+ `Dataset` classes. See the [example
+ service](https://github.com/facebookresearch/CompilerGym/tree/development/examples/example_compiler_gym_service)
+ for details.
+* [LLVM] Update references to the `poj104-v0` dataset to `poj104-v1`.
+* [LLVM] Update references to the `cBench-v1` dataset to `cbench-v1`.
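+
+A minimal sketch of the migration steps above (the 100-URI slice is arbitrary,
+and passing a URI string to `env.reset(benchmark=...)` is assumed to be
+supported):
+
+```python
+import random
+from itertools import islice
+
+import gym
+
+import compiler_gym  # noqa: F401  (registers the CompilerGym environments)
+
+env = gym.make("llvm-autophase-ic-v0")
+
+# env.benchmark is now a Benchmark instance; use .uri when a string is needed.
+env.benchmark = "benchmark://cbench-v1/qsort"  # setting from a string still works
+env.reset()  # reuses the benchmark set above rather than choosing one at random
+print(env.benchmark.uri)
+
+# Datasets are available without being downloaded first, and the benchmark URI
+# space may be too large to enumerate fully, so take a slice before sampling.
+for dataset in env.datasets:
+    print(dataset.name)
+uris = list(islice(env.datasets.benchmark_uris(), 100))
+env.reset(benchmark=random.choice(uris))
+
+env.close()
+```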
+
## Release 0.1.7 (2021-04-01)
This release introduces [public
@@ -46,11 +133,11 @@ semantics validation, and improving the datasets. Many thanks to @JD-at-work,
- Added default reward spaces for `CompilerEnv` that are derived from scalar
observations (thanks @bwasti!)
- Added a new Q learning example (thanks @JD-at-work!).
-- *Deprecation:* The next release v0.1.5 will introduce a new datasets API that
- is easier to use and more flexible. In preparation for this, the `Dataset`
- class has been renamed to `LegacyDataset`, the following dataset operations
- have been marked deprecated: `activate()`, `deactivate()`, and `delete()`. The
- `GetBenchmarks()` RPC interface method has also been marked deprecated..
+- *Deprecation:* The v0.1.8 release will introduce a new datasets API that is
+ easier to use and more flexible. In preparation for this, the `Dataset` class
+ has been renamed to `LegacyDataset`, the following dataset operations have
+ been marked deprecated: `activate()`, `deactivate()`, and `delete()`. The
+ `GetBenchmarks()` RPC interface method has also been marked deprecated.
- [llvm] Improved semantics validation using LLVM's memory, thread, address, and
undefined behavior sanitizers.
- Numerous bug fixes and improvements.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 86f20bd9e..0d327c510 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,4 +1,16 @@
-# Contributing
+# Contributing
+
+**Table of Contents**
+
+- [How to Contribute](#how-to-contribute)
+- [Pull Requests](#pull-requests)
+- [Leaderboard Submissions](#leaderboard-submissions)
+- [Code Style](#code-style)
+- [Contributor License Agreement ("CLA")](#contributor-license-agreement-cla)
+
+---
+
+## How to Contribute
We want to make contributing to CompilerGym as easy and transparent
as possible. The most helpful ways to contribute are:
@@ -16,8 +28,10 @@ as possible. The most helpful ways to contribute are:
* Pull requests. Please see below for details. The easiest way to get stuck
is to grab an [unassigned "Good first issue"
ticket](https://github.com/facebookresearch/CompilerGym/issues?q=is%3Aopen+is%3Aissue+no%3Aassignee+label%3A%22Good+first+issue%22)!
- * Add new features not on the roadmap. Examples could include adding support
- for new compilers, producing research results using CompilerGym, etc.
+ * Add new features not on [the
+ roadmap](https://facebookresearch.github.io/CompilerGym/about.html#roadmap).
+ Examples could include adding support for new compilers, producing research
+ results using CompilerGym, etc.
## Pull Requests
@@ -32,9 +46,9 @@ We actively welcome your pull requests.
3. If you've added code that should be tested, add tests.
4. If you've changed APIs, update the [documentation](/docs/source).
5. Ensure the `make test` suite passes.
-6. Make sure your code lints (see "Code Style" below).
-7. If you haven't already, complete the Contributor License Agreement
- ("CLA").
+6. Make sure your code lints (see [Code Style](#code-style) below).
+7. If you haven't already, complete the [Contributor License Agreement
+ ("CLA")](#contributor-license-agreement-cla).
## Leaderboard Submissions
@@ -49,12 +63,13 @@ and file a [Pull Request](#pull-requests). Please include:
3. A write-up of your approach. You may use the
[submission template](/leaderboard/SUBMISSION_TEMPLATE.md) as a guide.
-We do not require that you submit the source code for your approach. Once you
-submit your pull request we will validate your results CSV files and may ask
-clarifying questions if we feel that those would be useful to improve
-reproducibility. Please [take a look
-here](https://github.com/facebookresearch/CompilerGym/pull/117) for an example
-of a well-formed pull request submission.
+Please make sure to update to the latest CompilerGym release prior to
+submission. We do not require that you submit the source code for your approach,
+though we encourage you to make it publicly available. Once you submit your
+pull request, we will validate your results CSV files and may ask clarifying
+questions if we feel that those would be useful to improve reproducibility.
+[Take a look here](https://github.com/facebookresearch/CompilerGym/pull/117) for
+an example of a well-formed pull request submission.
## Code Style
@@ -69,15 +84,17 @@ is simple:
style](https://google.github.io/styleguide/cppguide.html) with 100
character line length and `camelCaseFunctionNames()`.
-We use [pre-commit](/.pre-commit-config.yaml) to format our code to
-enforce these rules. Before submitting pull requests, please run
-pre-commit to ensure the code is correctly formatted.
+We use [pre-commit](https://pre-commit.com/) to ensure that code is formatted
+prior to committing. Before submitting pull requests, please run pre-commit. See
+the [config file](/.pre-commit-config.yaml) for installation and usage
+instructions.
Other common sense rules we encourage are:
* Prefer descriptive names over short ones.
* Split complex code into small units.
* When writing new features, add tests.
+* Make tests deterministic.
* Prefer easy-to-use code over easy-to-read, and easy-to-read code over
easy-to-write.
diff --git a/Makefile b/Makefile
index 85d53ec05..9d852972d 100644
--- a/Makefile
+++ b/Makefile
@@ -34,6 +34,12 @@ Post-installation Tests
usually not needed for interactive development since `make test` runs
the same tests without having to install anything.
+ make install-test-cov
+ The same as `make install-test`, but with python test coverage
+ reporting. A summary of test coverage is printed at the end of execution
+ and the full details are recorded in a coverage.xml file in the project
+ root directory.
+
make install-fuzz
Run the fuzz testing suite against an installed CompilerGym package.
Fuzz tests are tests that generate their own inputs and run in a loop
@@ -185,7 +191,7 @@ docs/source/contributing.rst: CONTRIBUTING.md
docs/source/installation.rst: README.md
echo "..\n Generated from $<. Do not edit!\n" > $@
- sed -n '/^## Installation/,$$p' $< | sed -n '/^## Trying/q;p' | $(PANDOC) --from=markdown --to=rst >> $@
+ sed -n '/^## Installation/,$$p' $< | sed -n '/^### Building/q;p' | $(PANDOC) --from=markdown --to=rst >> $@
GENERATED_DOCS := \
docs/source/changelog.rst \
@@ -215,20 +221,30 @@ test:
itest:
$(IBAZEL) $(BAZEL_OPTS) test $(BAZEL_TEST_OPTS) //...
-install-test-datasets:
- cd .. && $(PYTHON) -m compiler_gym.bin.datasets --env=llvm-v0 --download=cBench-v1 >/dev/null
-install-test: install-test-datasets
+# Since we can't run compiler_gym from the project root we need to jump through
+# some hoops to run pytest "out of tree" by creating an empty directory and
+# symlinking the test directory into it so that pytest can be invoked.
+define run_pytest_suite
mkdir -p /tmp/compiler_gym/wheel_tests
- rm -f /tmp/compiler_gym/wheel_tests/tests
+ rm -f /tmp/compiler_gym/wheel_tests/tests /tmp/compiler_gym/wheel_tests/tox.ini
ln -s $(ROOT)/tests /tmp/compiler_gym/wheel_tests
- cd /tmp/compiler_gym/wheel_tests && pytest tests -n auto -k "not fuzz"
+ ln -s $(ROOT)/tox.ini /tmp/compiler_gym/wheel_tests
+ cd /tmp/compiler_gym/wheel_tests && pytest tests $(1) --benchmark-disable -n auto -k "not fuzz"
+endef
+
+install-test:
+ $(call run_pytest_suite,)
+
+install-test-cov:
+ $(call run_pytest_suite,--cov=compiler_gym --cov-report=xml)
+ @mv /tmp/compiler_gym/wheel_tests/coverage.xml .
# The minimum number of seconds to run the fuzz tests in a loop for. Override
# this at the commandline, e.g. `FUZZ_SECONDS=1800 make fuzz`.
FUZZ_SECONDS ?= 300
-install-fuzz: install-test-datasets
+install-fuzz:
mkdir -p /tmp/compiler_gym/wheel_fuzz_tests
rm -f /tmp/compiler_gym/wheel_fuzz_tests/tests
ln -s $(ROOT)/tests /tmp/compiler_gym/wheel_fuzz_tests
diff --git a/README.md b/README.md
index 95504eeb2..f6ee38e96 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,10 @@
+
+
+
+
CompilerGym is a toolkit for exposing compiler optimization problems
for reinforcement learning. It allows machine learning researchers to
@@ -32,6 +36,7 @@ developers to expose new optimization problems for AI.
**Table of Contents**
+- [Features](#features)
- [Getting Started](#getting-started)
- [Installation](#installation)
- [Building from Source](#building-from-source)
@@ -42,6 +47,40 @@ developers to expose new optimization problems for AI.
- [Citation](#citation)
+# Features
+
+With CompilerGym, building ML models for compiler research problems is as easy
+as building ML models to play video games. Here are some highlights of key
+features:
+
+* **API:** uses the popular [Gym](https://gym.openai.com/) interface from OpenAI
+ — use Python to write your agent.
+
+* **Datasets:** wraps real-world programs (C++ programs, TensorFlow programs,
+  programs from GitHub, etc.) and a mainstream compiler
+ ([LLVM](https://llvm.org/)), providing millions of programs for training.
+
+* **Tasks and Actions:** interfaces with the [LLVM](https://llvm.org/) compiler for
+ one compiler research problem: phase ordering (more to come). It has a large
+ discrete action space.
+
+* **Representations:** provides raw representations of programs, as well as
+ multiple kinds of pre-computed features: you can focus on end-to-end deep
+ learning or features + boosted trees, all the way up to graph models.
+
+* **Rewards:** provides appropriate reward functions and loss functions out of
+ the box.
+
+* **Testing:** provides a validation process for correctness of results.
+
+* **Baselines:** provides some baselines and reports their performance.
+
+* **Competition:** provides [leaderboards](#leaderboards) for you to submit your
+ results.
+
+For a glimpse of what's to come, check out [our
+roadmap](https://github.com/facebookresearch/CompilerGym/projects/1).
+
# Getting Started
Starting with CompilerGym is simple. If you are not already familiar with the gym
@@ -83,7 +122,7 @@ Now proceed to [All platforms](#all-platforms) below.
On debian-based linux systems, install the required toolchain using:
```sh
-sudo apt install clang libtinfo5 libjpeg-dev patchelf
+sudo apt install clang-9 libtinfo5 libjpeg-dev patchelf
wget https://github.com/bazelbuild/bazelisk/releases/download/v1.7.5/bazelisk-linux-amd64 -O bazel
chmod +x bazel && mkdir -p ~/.local/bin && mv -v bazel ~/.local/bin
export PATH="$HOME/.local/bin:$PATH"
@@ -106,12 +145,15 @@ Then clone the CompilerGym source code using:
git clone https://github.com/facebookresearch/CompilerGym.git
cd CompilerGym
-Install the python development dependencies using:
+There are two primary git branches: `stable` tracks the latest release;
+`development` is for bleeding-edge features that may not yet be mature. Check
+out your preferred branch and install the Python development dependencies using:
+ git checkout stable
make init
The `make init` target only needs to be run once on initial setup, or when
-upgrading to a different CompilerGym release.
+pulling remote changes from the CompilerGym repository.
Run the test suite to confirm that everything is working:
@@ -140,15 +182,18 @@ In Python, import `compiler_gym` to use the environments:
>>> import gym
>>> import compiler_gym # imports the CompilerGym environments
>>> env = gym.make("llvm-autophase-ic-v0") # starts a new environment
->>> env.require_dataset("npb-v0") # downloads a set of programs
->>> env.reset() # starts a new compilation session with a random program
+>>> env.benchmark = "benchmark://cbench-v1/qsort" # select a program to compile
+>>> env.reset() # starts a new compilation session
>>> env.render() # prints the IR of the program
>>> env.step(env.action_space.sample()) # applies a random optimization, updates state/reward/actions
```
-See the
-[documentation website](http://facebookresearch.github.io/CompilerGym/) for
-tutorials, further details, and API reference.
+See the [documentation website](http://facebookresearch.github.io/CompilerGym/)
+for tutorials, further details, and API reference. Our
+[roadmap](https://facebookresearch.github.io/CompilerGym/about.html#roadmap) of
+planned features is public, and the
+[changelog](https://github.com/facebookresearch/CompilerGym/blob/development/CHANGELOG.md)
+summarizes shipped features.
# Leaderboards
@@ -169,7 +214,7 @@ count achieved scaled to the reduction achieved by LLVM's builtin `-Oz`
pipeline.
This leaderboard tracks the results achieved by algorithms on the `llvm-ic-v0`
-environment on the 23 benchmarks in the `cBench-v1` dataset.
+environment on the 23 benchmarks in the `cbench-v1` dataset.
| Author | Algorithm | Links | Date | Walltime (mean) | Codesize Reduction (geomean) |
| --- | --- | --- | --- | --- | --- |
@@ -178,7 +223,9 @@ environment on the 23 benchmarks in the `cBench-v1` dataset.
| Facebook | Greedy search | [write-up](leaderboard/llvm_instcount/e_greedy/README.md), [results](leaderboard/llvm_instcount/e_greedy/results_e0.csv) | 2021-03 | 169.237s | 1.055× |
| Facebook | Random search (t=60) | [write-up](leaderboard/llvm_instcount/random_search/README.md), [results](leaderboard/llvm_instcount/random_search/results_p125_t60.csv) | 2021-03 | 91.215s | 1.045× |
| Facebook | e-Greedy search (e=0.1) | [write-up](leaderboard/llvm_instcount/e_greedy/README.md), [results](leaderboard/llvm_instcount/e_greedy/results_e10.csv) | 2021-03 | 152.579s | 1.041× |
+| Jiadong Guo | Tabular Q (N=5000, H=10) | [write-up](leaderboard/llvm_instcount/tabular_q/README.md), [results](leaderboard/llvm_instcount/tabular_q/results-H10-N5000.csv) | 2021-04 | 2534.305s | 1.036× |
| Facebook | Random search (t=10) | [write-up](leaderboard/llvm_instcount/random_search/README.md), [results](leaderboard/llvm_instcount/random_search/results_p125_t10.csv) | 2021-03 | **42.939s** | 1.031× |
+| Jiadong Guo | Tabular Q (N=2000, H=5) | [write-up](leaderboard/llvm_instcount/tabular_q/README.md), [results](leaderboard/llvm_instcount/tabular_q/results-H5-N2000.csv) | 2021-04 | 694.105s | 0.988× |
# Contributing
diff --git a/VERSION b/VERSION
index 11808190d..699c6c6d4 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.1.7
+0.1.8
diff --git a/tests/benchmarks/BUILD b/benchmarks/BUILD
similarity index 91%
rename from tests/benchmarks/BUILD
rename to benchmarks/BUILD
index c6dda436b..f57da9c96 100644
--- a/tests/benchmarks/BUILD
+++ b/benchmarks/BUILD
@@ -10,6 +10,7 @@ py_test(
shard_count = 8,
deps = [
"//compiler_gym",
+ "//examples/example_compiler_gym_service",
"//tests:test_main",
"//tests/pytest_plugins:llvm",
],
@@ -27,7 +28,9 @@ py_binary(
py_test(
name = "parallelization_load_test_test",
+ timeout = "moderate",
srcs = ["parallelization_load_test_test.py"],
+ flaky = 1,
deps = [
":parallelization_load_test",
"//tests:test_main",
diff --git a/benchmarks/bench_test.py b/benchmarks/bench_test.py
new file mode 100644
index 000000000..9a5349358
--- /dev/null
+++ b/benchmarks/bench_test.py
@@ -0,0 +1,191 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Microbenchmarks for CompilerGym environments.
+
+To run these benchmarks, use an optimized build using bazel:
+
+ $ bazel test -c opt --test_output=streamed //benchmarks:bench_test
+
+A record of the benchmark results is stored in
+/tmp/compiler_gym/pytest_benchmark//_bench_test.json. Compare
+multiple runs using:
+
+ $ pytest-benchmark compare --group-by=name --sort=fullname \
+ /tmp/compiler_gym/pytest_benchmark/*/*_bench_test.json
+"""
+import gym
+import pytest
+
+import examples.example_compiler_gym_service as dummy
+from compiler_gym.envs import CompilerEnv, LlvmEnv, llvm
+from compiler_gym.service import CompilerGymServiceConnection
+from tests.pytest_plugins.llvm import OBSERVATION_SPACE_NAMES, REWARD_SPACE_NAMES
+from tests.test_main import main
+
+
+@pytest.fixture(
+ params=["llvm-v0", "example-cc-v0", "example-py-v0"],
+ ids=["llvm", "dummy-cc", "dummy-py"],
+)
+def env_id(request) -> str:
+ yield request.param
+
+
+@pytest.fixture(
+ params=["llvm-v0", "example-cc-v0", "example-py-v0"],
+ ids=["llvm", "dummy-cc", "dummy-py"],
+)
+def env(request) -> CompilerEnv:
+ yield request.param
+
+
+@pytest.mark.parametrize(
+ "env_id",
+ ["llvm-v0", "example-cc-v0", "example-py-v0"],
+ ids=["llvm", "dummy-cc", "dummy-py"],
+)
+def test_make_local(benchmark, env_id):
+ benchmark(lambda: gym.make(env_id).close())
+
+
+@pytest.mark.parametrize(
+ "args",
+ [
+ (llvm.LLVM_SERVICE_BINARY, LlvmEnv),
+ (dummy.EXAMPLE_CC_SERVICE_BINARY, CompilerEnv),
+ (dummy.EXAMPLE_PY_SERVICE_BINARY, CompilerEnv),
+ ],
+ ids=["llvm", "dummy-cc", "dummy-py"],
+)
+def test_make_service(benchmark, args):
+ service_binary, env_class = args
+ service = CompilerGymServiceConnection(service_binary)
+ try:
+ benchmark(lambda: env_class(service=service.connection.url).close())
+ finally:
+ service.close()
+
+
+@pytest.mark.parametrize(
+ "make_env",
+ [
+ lambda: gym.make("llvm-autophase-ic-v0", benchmark="cbench-v1/crc32"),
+ lambda: gym.make("llvm-autophase-ic-v0", benchmark="cbench-v1/jpeg-d"),
+ lambda: gym.make("example-cc-v0"),
+ lambda: gym.make("example-py-v0"),
+ ],
+ ids=["llvm;fast-benchmark", "llvm;slow-benchmark", "dummy-cc", "dummy-py"],
+)
+def test_reset(benchmark, make_env: CompilerEnv):
+ with make_env() as env:
+ benchmark(env.reset)
+
+
+@pytest.mark.parametrize(
+ "args",
+ [
+ (
+ lambda: gym.make("llvm-autophase-ic-v0", benchmark="cbench-v1/crc32"),
+ "-globaldce",
+ ),
+ (lambda: gym.make("llvm-autophase-ic-v0", benchmark="cbench-v1/crc32"), "-gvn"),
+ (
+ lambda: gym.make("llvm-autophase-ic-v0", benchmark="cbench-v1/jpeg-d"),
+ "-globaldce",
+ ),
+ (
+ lambda: gym.make("llvm-autophase-ic-v0", benchmark="cbench-v1/jpeg-d"),
+ "-gvn",
+ ),
+ (lambda: gym.make("example-cc-v0"), "a"),
+ (lambda: gym.make("example-py-v0"), "a"),
+ ],
+ ids=[
+ "llvm;fast-benchmark;fast-action",
+ "llvm;fast-benchmark;slow-action",
+ "llvm;slow-benchmark;fast-action",
+ "llvm;slow-benchmark;slow-action",
+ "dummy-cc",
+ "dummy-py",
+ ],
+)
+def test_step(benchmark, args):
+ make_env, action_name = args
+ with make_env() as env:
+ env.reset()
+ action = env.action_space[action_name]
+ benchmark(env.step, action)
+
+
+_args = dict(
+ {
+ f"llvm;{obs}": (lambda: gym.make("llvm-v0", benchmark="cbench-v1/qsort"), obs)
+ for obs in OBSERVATION_SPACE_NAMES
+ },
+ **{
+ "dummy-cc": (lambda: gym.make("example-cc-v0"), "ir"),
+ "dummy-py": (lambda: gym.make("example-py-v0"), "features"),
+ },
+)
+
+
+@pytest.mark.parametrize("args", _args.values(), ids=_args.keys())
+def test_observation(benchmark, args):
+ make_env, observation_space = args
+ with make_env() as env:
+ env.reset()
+ benchmark(lambda: env.observation[observation_space])
+
+
+_args = dict(
+ {
+ f"llvm;{reward}": (
+ lambda: gym.make("llvm-v0", benchmark="cbench-v1/qsort"),
+ reward,
+ )
+ for reward in REWARD_SPACE_NAMES
+ },
+ **{
+ "dummy-cc": (lambda: gym.make("example-cc-v0"), "runtime"),
+ "dummy-py": (lambda: gym.make("example-py-v0"), "runtime"),
+ },
+)
+
+
+@pytest.mark.parametrize("args", _args.values(), ids=_args.keys())
+def test_reward(benchmark, args):
+ make_env, reward_space = args
+ with make_env() as env:
+ env.reset()
+ benchmark(lambda: env.reward[reward_space])
+
+
+@pytest.mark.parametrize(
+ "make_env",
+ [
+ lambda: gym.make("llvm-autophase-ic-v0", benchmark="cbench-v1/crc32"),
+ lambda: gym.make("llvm-autophase-ic-v0", benchmark="cbench-v1/jpeg-d"),
+ # TODO: Example service does not yet support fork() operator.
+ # lambda: gym.make("example-cc-v0"),
+ # lambda: gym.make("example-py-v0"),
+ ],
+ ids=["llvm;fast-benchmark", "llvm;slow-benchmark"],
+)
+def test_fork(benchmark, make_env):
+ with make_env() as env:
+ env.reset()
+ benchmark(lambda: env.fork().close())
+
+
+if __name__ == "__main__":
+ main(
+ extra_pytest_args=[
+ "--benchmark-storage=/tmp/compiler_gym/pytest_benchmark",
+ "--benchmark-save=bench_test",
+ "--benchmark-sort=name",
+ "-x",
+ ],
+ debug_level=0,
+ )
diff --git a/tests/benchmarks/parallelization_load_test.py b/benchmarks/parallelization_load_test.py
similarity index 100%
rename from tests/benchmarks/parallelization_load_test.py
rename to benchmarks/parallelization_load_test.py
diff --git a/tests/benchmarks/parallelization_load_test_test.py b/benchmarks/parallelization_load_test_test.py
similarity index 81%
rename from tests/benchmarks/parallelization_load_test_test.py
rename to benchmarks/parallelization_load_test_test.py
index 979990d0f..35d73ad27 100644
--- a/tests/benchmarks/parallelization_load_test_test.py
+++ b/benchmarks/parallelization_load_test_test.py
@@ -2,14 +2,14 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-"""Smoke test for //tests/benchmarks:parallelization_load_test."""
+"""Smoke test for //benchmarks:parallelization_load_test."""
from pathlib import Path
from absl import flags
+from benchmarks.parallelization_load_test import main as load_test
from compiler_gym.util.capture_output import capture_output
-from tests.benchmarks.parallelization_load_test import main as load_test
-from tests.pytest_plugins.common import skip_on_ci
+from tests.pytest_plugins.common import set_command_line_flags, skip_on_ci
from tests.test_main import main
FLAGS = flags.FLAGS
@@ -21,12 +21,11 @@
def test_load_test(env, tmpwd):
del env # Unused.
del tmpwd # Unused.
- FLAGS.unparse_flags()
- FLAGS(
+ set_command_line_flags(
[
"arv0",
"--env=llvm-v0",
- "--benchmark=cBench-v1/crc32",
+ "--benchmark=cbench-v1/crc32",
"--max_nproc=3",
"--nproc_increment=1",
"--num_steps=2",
diff --git a/compiler_gym/BUILD b/compiler_gym/BUILD
index deeaf4ff6..8b3da5205 100644
--- a/compiler_gym/BUILD
+++ b/compiler_gym/BUILD
@@ -24,6 +24,10 @@ py_library(
name = "compiler_env_state",
srcs = ["compiler_env_state.py"],
visibility = ["//compiler_gym:__subpackages__"],
+ deps = [
+ "//compiler_gym/datasets:uri",
+ "//compiler_gym/util",
+ ],
)
py_library(
@@ -54,6 +58,7 @@ py_library(
srcs = ["validate.py"],
visibility = ["//compiler_gym:__subpackages__"],
deps = [
+ ":validation_error",
":validation_result",
"//compiler_gym/envs:compiler_env",
"//compiler_gym/spaces",
@@ -61,12 +66,19 @@ py_library(
],
)
+py_library(
+ name = "validation_error",
+ srcs = ["validation_error.py"],
+ visibility = ["//compiler_gym:__subpackages__"],
+)
+
py_library(
name = "validation_result",
srcs = ["validation_result.py"],
visibility = ["//compiler_gym:__subpackages__"],
deps = [
":compiler_env_state",
+ ":validation_error",
"//compiler_gym/util",
],
)
diff --git a/compiler_gym/__init__.py b/compiler_gym/__init__.py
index 33ba6a6ea..0110aa9fd 100644
--- a/compiler_gym/__init__.py
+++ b/compiler_gym/__init__.py
@@ -29,7 +29,11 @@
"compiler_gym` will work."
) from e
-from compiler_gym.compiler_env_state import CompilerEnvState
+from compiler_gym.compiler_env_state import (
+ CompilerEnvState,
+ CompilerEnvStateReader,
+ CompilerEnvStateWriter,
+)
from compiler_gym.envs import COMPILER_GYM_ENVS, CompilerEnv, observation_t, step_t
from compiler_gym.random_search import random_search
from compiler_gym.util.debug_util import (
@@ -44,7 +48,8 @@
transient_cache_path,
)
from compiler_gym.validate import validate_states
-from compiler_gym.validation_result import ValidationError, ValidationResult
+from compiler_gym.validation_error import ValidationError
+from compiler_gym.validation_result import ValidationResult
# The top-level compiler_gym API.
__all__ = [
@@ -53,6 +58,8 @@
"COMPILER_GYM_ENVS",
"CompilerEnv",
"CompilerEnvState",
+ "CompilerEnvStateWriter",
+ "CompilerEnvStateReader",
"download",
"get_debug_level",
"get_logging_level",
diff --git a/compiler_gym/bin/BUILD b/compiler_gym/bin/BUILD
index 3ae0fb971..92cfc45a8 100644
--- a/compiler_gym/bin/BUILD
+++ b/compiler_gym/bin/BUILD
@@ -22,7 +22,8 @@ py_binary(
srcs = ["datasets.py"],
visibility = ["//visibility:public"],
deps = [
- "//compiler_gym/datasets:dataset",
+ ":service",
+ "//compiler_gym/datasets",
"//compiler_gym/envs",
"//compiler_gym/util",
"//compiler_gym/util/flags:env_from_flags",
@@ -39,7 +40,6 @@ py_binary(
"//compiler_gym/util",
"//compiler_gym/util/flags:benchmark_from_flags",
"//compiler_gym/util/flags:env_from_flags",
- "//compiler_gym/util/flags:ls_benchmark",
],
)
@@ -60,7 +60,6 @@ py_binary(
"//compiler_gym:random_search",
"//compiler_gym/util/flags:benchmark_from_flags",
"//compiler_gym/util/flags:env_from_flags",
- "//compiler_gym/util/flags:ls_benchmark",
"//compiler_gym/util/flags:nproc",
"//compiler_gym/util/flags:output_dir",
],
@@ -83,6 +82,7 @@ py_binary(
srcs = ["service.py"],
visibility = ["//visibility:public"],
deps = [
+ "//compiler_gym/datasets",
"//compiler_gym/envs",
"//compiler_gym/spaces",
"//compiler_gym/util",
@@ -97,7 +97,6 @@ py_binary(
deps = [
"//compiler_gym:validate",
"//compiler_gym/util",
- "//compiler_gym/util/flags:dataset",
"//compiler_gym/util/flags:env_from_flags",
"//compiler_gym/util/flags:nproc",
],
diff --git a/compiler_gym/bin/datasets.py b/compiler_gym/bin/datasets.py
index c6a2495ae..cde0dffce 100644
--- a/compiler_gym/bin/datasets.py
+++ b/compiler_gym/bin/datasets.py
@@ -6,11 +6,8 @@
.. code-block::
- $ python -m compiler_gym.bin.datasets --env= [command...]
-
-Where :code:`command` is one of :code:`--download=`,
-:code:`--activate=`, :code:`--deactivate=`,
-and :code:`--delete=`.
+ $ python -m compiler_gym.bin.datasets --env= \
+ [--download=] [--delete=]
Listing installed datasets
@@ -22,38 +19,21 @@
.. code-block::
$ python -m compiler_gym.bin.datasets --env=llvm-v0
- llvm-v0 benchmarks site dir: /home/user/.local/share/compiler_gym/llvm/10.0.0/bitcode_benchmarks
-
- +-------------------+--------------+-----------------+----------------+
- | Active Datasets | License | #. Benchmarks | Size on disk |
- +===================+==============+=================+================+
- | cBench-v1 | BSD 3-Clause | 23 | 10.1 MB |
- +-------------------+--------------+-----------------+----------------+
- | Total | | 23 | 10.1 MB |
- +-------------------+--------------+-----------------+----------------+
- These benchmarks are ready for use. Deactivate them using `--deactivate=`.
-
- +---------------------+-----------+-----------------+----------------+
- | Inactive Datasets | License | #. Benchmarks | Size on disk |
- +=====================+===========+=================+================+
- | Total | | 0 | 0 Bytes |
- +---------------------+-----------+-----------------+----------------+
- These benchmarks may be activated using `--activate=`.
-
- +------------------------+---------------------------------+-----------------+----------------+
- | Downloadable Dataset | License | #. Benchmarks | Size on disk |
- +========================+=================================+=================+================+
- | blas-v0 | BSD 3-Clause | 300 | 4.0 MB |
- +------------------------+---------------------------------+-----------------+----------------+
- | polybench-v0 | BSD 3-Clause | 27 | 162.6 kB |
- +------------------------+---------------------------------+-----------------+----------------+
- These benchmarks may be installed using `--download= --activate=`.
+
+ +-------------------+---------------------+-----------------+----------------+
+ | Active Datasets | Description | #. Benchmarks | Size on disk |
+ +===================+=====================+=================+================+
+ | cbench-v1 | Runnable C programs | 23 | 10.1 MB |
+ +-------------------+---------------------+-----------------+----------------+
+ | Total | | 23 | 10.1 MB |
+ +-------------------+---------------------+-----------------+----------------+
Downloading datasets
--------------------
-Use :code:`--download` to download a dataset from the list of available datasets:
+Use :code:`--download` to download a dataset from the list of available
+datasets:
.. code-block::
@@ -73,24 +53,12 @@
$ python -m compiler_gym.bin.datasets --env=llvm-v0 --download_all
-:code:`--download` accepts the URL of any :code:`.tar.bz2` file to support custom datasets:
-
-.. code-block::
-
- $ python -m comiler_gym.bin.benchmarks --env=llvm-v0 --download=https://example.com/dataset.tar.bz2
-
Or use the :code:`file:///` URI to install a local archive file:
.. code-block::
$ python -m compiler_gym.bin.datasets --env=llvm-v0 --download=file:////tmp/dataset.tar.bz2
-The list of datasets that are available to download may be extended by calling
-:meth:`CompilerEnv.register_dataset() `
-on a :code:`CompilerEnv` instance.
-
-To programmatically download datasets, see
-:meth:`CompilerEnv.require_dataset() `.
Activating and deactivating datasets
------------------------------------
@@ -100,12 +68,12 @@
This can be useful if you have many datasets downloaded and you would like to
limit the benchmarks that can be selected randomly by an environment.
-Activate or deactivate datasets using the :code:`--activate` and :code:`--deactivate`
-flags, respectively:
+Activate or deactivate datasets using the :code:`--activate` and
+:code:`--deactivate` flags, respectively:
.. code-block::
- $ python -m comiler_gym.bin.benchmarks --env=llvm-v0 --activate=npb-v0,github-v0 --deactivate=cBench-v1
+ $ python -m compiler_gym.bin.datasets --env=llvm-v0 --activate=npb-v0,github-v0 --deactivate=cbench-v1
The :code:`--activate_all` and :code:`--deactivate_all` flags can be used as a
shortcut to activate or deactivate every downloaded dataset:
@@ -131,23 +99,14 @@
A :code:`--delete_all` flag can be used to delete all of the locally installed
datasets.
"""
-import os
import sys
-from pathlib import Path
-from typing import Tuple
-import humanize
from absl import app, flags
+from deprecated.sphinx import deprecated
-from compiler_gym.datasets.dataset import (
- LegacyDataset,
- activate,
- deactivate,
- delete,
- require,
-)
+from compiler_gym.bin.service import summarize_datasets
+from compiler_gym.datasets.dataset import activate, deactivate, delete
from compiler_gym.util.flags.env_from_flags import env_from_flags
-from compiler_gym.util.tabulate import tabulate
flags.DEFINE_list(
"download",
@@ -175,31 +134,15 @@
FLAGS = flags.FLAGS
-def get_count_and_size_of_directory_contents(root: Path) -> Tuple[int, int]:
- """Return the number of files and combined size of a directory."""
- count, size = 0, 0
- for root, _, files in os.walk(str(root)):
- count += len(files)
- size += sum(os.path.getsize(f"{root}/{file}") for file in files)
- return count, size
-
-
-def enumerate_directory(name: str, path: Path):
- rows = []
- for path in path.iterdir():
- if not path.is_file() or not path.name.endswith(".json"):
- continue
- dataset = LegacyDataset.from_json_file(path)
- rows.append(
- (dataset.name, dataset.license, dataset.file_count, dataset.size_bytes)
- )
- rows.append(("Total", "", sum(r[2] for r in rows), sum(r[3] for r in rows)))
- return tabulate(
- [(n, l, humanize.intcomma(f), humanize.naturalsize(s)) for n, l, f, s in rows],
- headers=(name, "License", "#. Benchmarks", "Size on disk"),
- )
-
-
+@deprecated(
+ version="0.1.8",
+ reason=(
+ "Command-line management of datasets is deprecated. Please use "
+ ":mod:`compiler_gym.bin.service` to print a tabular overview of the "
+ "available datasets. For management of datasets, use the "
+ ":class:`env.datasets ` property."
+ ),
+)
def main(argv):
"""Main entry point."""
if len(argv) != 1:
@@ -207,28 +150,20 @@ def main(argv):
env = env_from_flags()
try:
- if not env.datasets_site_path:
- raise app.UsageError("Environment has no benchmarks site path")
-
- env.datasets_site_path.mkdir(parents=True, exist_ok=True)
- env.inactive_datasets_site_path.mkdir(parents=True, exist_ok=True)
-
invalidated_manifest = False
for name_or_url in FLAGS.download:
- require(env, name_or_url)
+ env.datasets.install(name_or_url)
if FLAGS.download_all:
- for dataset in env.available_datasets:
- require(env, dataset)
+ for dataset in env.datasets:
+ dataset.install()
for name in FLAGS.activate:
activate(env, name)
invalidated_manifest = True
if FLAGS.activate_all:
- for path in env.inactive_datasets_site_path.iterdir():
- activate(env, path.name)
invalidated_manifest = True
for name in FLAGS.deactivate:
@@ -236,8 +171,6 @@ def main(argv):
invalidated_manifest = True
if FLAGS.deactivate_all:
- for path in env.datasets_site_path.iterdir():
- deactivate(env, path.name)
invalidated_manifest = True
for name in FLAGS.delete:
@@ -246,41 +179,8 @@ def main(argv):
if invalidated_manifest:
env.make_manifest_file()
- print(f"{env.spec.id} benchmarks site dir: {env.datasets_site_path}")
- print()
- print(
- enumerate_directory("Active Datasets", env.datasets_site_path),
- )
- print(
- "These benchmarks are ready for use. Deactivate them using `--deactivate=`."
- )
- print()
- print(enumerate_directory("Inactive Datasets", env.inactive_datasets_site_path))
- print("These benchmarks may be activated using `--activate=`.")
- print()
- print(
- tabulate(
- sorted(
- [
- (
- d.name,
- d.license,
- humanize.intcomma(d.file_count),
- humanize.naturalsize(d.size_bytes),
- )
- for d in env.available_datasets.values()
- ]
- ),
- headers=(
- "Downloadable Dataset",
- "License",
- "#. Benchmarks",
- "Size on disk",
- ),
- )
- )
print(
- "These benchmarks may be installed using `--download= --activate=`."
+ summarize_datasets(env.datasets),
)
finally:
env.close()
diff --git a/compiler_gym/bin/manual_env.py b/compiler_gym/bin/manual_env.py
index dd1dce7a2..45df0b0a1 100644
--- a/compiler_gym/bin/manual_env.py
+++ b/compiler_gym/bin/manual_env.py
@@ -16,33 +16,11 @@
CompilerGym Shell Tutorial
**************************
-This program gives a basic shell through which many of commands from
-CompilerGym can be executed. CompilerGym provides a simple Python interface to
-various compiler functions, enabling programs to be compiled in different ways
-and to make queries about those programs. The goal is to have a simple system
-for machine learning in compilers.
-
-Downloading a Dataset
----------------------
-When entering the Shell, the environment (compiler choice) will have already
-been made on the command line. The benchmark or program to be compiled may not
-yet be set. Before setting a benchmark, however, the corresponding dataset must
-be downloaded. You may have already downloaded a dataset through the
-compiler_gym.bin.datasets command, but if not, you can do that from this shell.
-
-To download a dataset, call:
-
-.. code-block::
-
- compilergym:NO-BENCHMARK> require_dataset
-
-The command and the dataset name should tab-complete for you (most things will
-tab-complete in the shell). You can also see what datasets are available with
-this command:
-
-.. code-block::
-
- compilergym:NO-BENCHMARK> list_datasets
+This program gives a basic shell through which many of the commands from
+CompilerGym can be executed. CompilerGym provides a simple Python interface to
+various compiler functions, enabling programs to be compiled in different ways
+and to make queries about those programs. The goal is to have a simple system
+for machine learning in compilers.
Setting a Benchmark, Reward and Observation
-------------------------------------------
@@ -51,34 +29,41 @@
.. code-block::
- compilergym:NO-BENCHMARK> set_benchmark
+ compiler_gym:cbench-v1/qsort> set_benchmark
When a benchmark is set, the prompt will update with the name of the benchmark.
Supposing that is "bench", then the prompt would be:
.. code-block::
- compilergym:bench>
+ compiler_gym:bench>
-The list of available benchmarks can be shown with:
+The list of available benchmarks can be shown with the following command,
+though it is limited to the first 200 benchmarks:
.. code-block::
- compilergym:bench> list_benchmarks
+ compiler_gym:bench> list_benchmarks
+
+You can also see what datasets are available with this command:
+
+.. code-block::
+
+ compiler_gym:cbench-v1/qsort> list_datasets
The default reward and observation can be similarly set with:
.. code-block::
- compilergym:bench> set_default_reward
- compilergym:bench> set_default_observation
+ compiler_gym:bench> set_default_reward
+ compiler_gym:bench> set_default_observation
And lists of the choices are available with:
.. code-block::
- compilergym:bench> list_rewards
- compilergym:bench> list_observations
+ compiler_gym:bench> list_rewards
+ compiler_gym:bench> list_observations
The default rewards and observations will be reported every time an action is
taken. So, if, for example, you want to see how the instruction count of the
@@ -95,27 +80,26 @@
(currently an LLVM opt pass) on the intermediate representation of the program.
Each action acts on the result of the previous action and so on.
-So, for example, to apply first the 'tail call elimination' pass, then the
-'loop unrolling' pass we call two actions:
+So, for example, to apply first the 'tail call elimination' pass, then the 'loop
+unrolling' pass we call two actions:
.. code-block::
- compilergym:bench> action -tailcallelim
- compilergym:bench> action -loop-unroll
+ compiler_gym:bench> action -tailcallelim
+ compiler_gym:bench> action -loop-unroll
-Each action will report its default reward.
-Note that multiple actions can be placed on a single line, so that the above is
-equivalent to:
+Each action will report its default reward. Note that multiple actions can be
+placed on a single line, so that the above is equivalent to:
.. code-block::
- compilergym:bench> action -tailcallelim -loop-unroll
+ compiler_gym:bench> action -tailcallelim -loop-unroll
You can choose a random action, by using just a '-' as the action name:
.. code-block::
- compilergym:bench> action -
+ compiler_gym:bench> action -
Since an empty line on the shell repeats the last action, you can execute many
random actions by typing that line first then holding down return.
@@ -125,7 +109,7 @@
.. code-block::
- compilergym:bench> stack
+ compiler_gym:bench> stack
This will show for each action if it had an effect (as computed by the
underlying compiler), whether this terminated the compiler, and what the per-action
@@ -135,20 +119,20 @@
.. code-block::
- compilergym:bench> undo
+ compiler_gym:bench> undo
All actions in the stack can be undone at once by:
.. code-block::
- compilergym:bench> reset
+ compiler_gym:bench> reset
You can find out what the effect of each action would be by calling this
command:
.. code-block::
- compilergym:bench> try_all_actions
+ compiler_gym:bench> try_all_actions
This will show a table with the reward for each action, sorted by best first.
@@ -157,12 +141,12 @@
.. code-block::
- compilergym:bench> simplify_stack
+ compiler_gym:bench> simplify_stack
This will redo the entire stack, keeping only those actions which previously
gave good rewards. (Note this doesn't mean that the simplified stack will only
-have positive rewards, some negative actions may be necessary set up for a
-later positive reward.)
+have positive rewards, some negative actions may be necessary to set up a later
+positive reward.)
Current Status
--------------
@@ -174,20 +158,20 @@
.. code-block::
- compilergym:bench> reward
+ compiler_gym:bench> reward
You can see various observations with:
.. code-block::
- compilergym:bench> observation
+ compiler_gym:bench> observation
Finally, you can print the equivalent command line for achieving the same
behaviour as the actions through the standard system shell:
.. code-block::
- compilergym:bench> commandline
+ compiler_gym:bench> commandline
Searching
---------
@@ -198,7 +182,7 @@
.. code-block::
- compilergym:bench> action -
+ compiler_gym:bench> action -
Multiple steps can be taken by holding down the return key.
@@ -207,14 +191,14 @@
.. code-block::
- compilergym:bench> hill_climb
+ compiler_gym:bench> hill_climb
A simple greedy search tries all possible actions and takes the one with the
highest reward, stopping when no action has a positive reward:
.. code-block::
- compilergym:bench> greedy
+ compiler_gym:bench> greedy
Miscellaneous
-------------
@@ -222,7 +206,7 @@
.. code-block::
- compilergym:bench> breakpoint
+ compiler_gym:bench> breakpoint
Which drops into the python debugger. This is very useful if you want to see
what is going on internally. There is a 'self.env' object that represents the
@@ -232,7 +216,7 @@
.. code-block::
- compilergym:bench> exit
+ compiler_gym:bench> exit
Drops out of the shell. :code:`Ctrl-D` should have the same effect.
"""
@@ -240,10 +224,10 @@
import random
import readline
import sys
+from itertools import islice
from absl import app, flags
-import compiler_gym.util.flags.ls_benchmark # noqa Flag definition.
from compiler_gym.envs import CompilerEnv
from compiler_gym.util.flags.benchmark_from_flags import benchmark_from_flags
from compiler_gym.util.flags.env_from_flags import env_from_flags
@@ -301,10 +285,12 @@ def __init__(self, env: CompilerEnv):
self.env = env
- self.init_benchmarks()
-
# Get the benchmarks
- self.benchmarks = sorted(self.env.benchmarks)
+ self.benchmarks = []
+ for dataset in self.env.datasets:
+ self.benchmarks += islice(dataset.benchmark_uris(), 50)
+ self.benchmarks.sort()
+
# Strip default benchmark:// protocol.
for i, benchmark in enumerate(self.benchmarks):
if benchmark.startswith("benchmark://"):
@@ -320,6 +306,12 @@ def __init__(self, env: CompilerEnv):
self.set_prompt()
+ def __del__(self):
+ """Tidy up in case postloop() is not called."""
+ if self.env:
+ self.env.close()
+ self.env = None
+
def do_tutorial(self, arg):
"""Print the turorial"""
print(tutorial)
@@ -335,24 +327,12 @@ def postloop(self):
self.env.close()
self.env = None
- def init_benchmarks(self):
- """Initialise the set of benchmarks"""
- # Get the benchmarks
- self.benchmarks = sorted(self.env.benchmarks)
- # Strip default benchmark:// protocol.
- for i, benchmark in enumerate(self.benchmarks):
- if benchmark.startswith("benchmark://"):
- self.benchmarks[i] = benchmark[len("benchmark://") :]
-
def set_prompt(self):
"""Set the prompt - shows the benchmark name"""
- if self.env.benchmark:
- bname = self.env.benchmark
- if bname.startswith("benchmark://"):
- bname = bname[len("benchmark://") :]
- else:
- bname = "NO-BENCHMARK"
- prompt = f"compilergym:{bname}>"
+ benchmark_name = self.env.benchmark.uri
+ if benchmark_name.startswith("benchmark://"):
+ benchmark_name = benchmark_name[len("benchmark://") :]
+ prompt = f"compiler_gym:{benchmark_name}>"
self.prompt = f"\n{emph(prompt)} "
def simple_complete(self, text, options):
@@ -363,38 +343,16 @@ def simple_complete(self, text, options):
return options
def get_datasets(self):
- """Get the list of available datasets"""
- return sorted([k for k in self.env.available_datasets])
+ """Get the list of datasets"""
+ return sorted([k.name for k in self.env.datasets.datasets()])
def do_list_datasets(self, arg):
- """List all of the available datasets"""
+ """List all of the datasets"""
print(", ".join(self.get_datasets()))
- def complete_require_dataset(self, text, line, begidx, endidx):
- """Complete the require_benchmark argument"""
- return self.simple_complete(text, self.get_datasets())
-
- def do_require_dataset(self, arg):
- """Require dataset
- The argument is the name of the dataset to require.
- """
- if self.get_datasets().count(arg):
- with Timer(f"Downloaded dataset {arg}"):
- self.env.require_dataset(arg)
- self.init_benchmarks()
- else:
- print("Unknown dataset, '" + arg + "'")
- print("Available datasets are listed with command, list_available_datasets")
-
def do_list_benchmarks(self, arg):
- """List all of the available benchmarks"""
- if not self.benchmarks:
- doc_root_url = "https://facebookresearch.github.io/CompilerGym/"
- install_url = doc_root_url + "getting_started.html#installing-benchmarks"
- print("No benchmarks available. See " + install_url)
- print("Datasets can be installed with command, require_dataset")
- else:
- print(", ".join(self.benchmarks))
+ """List the benchmarks"""
+ print(", ".join(self.benchmarks))
def complete_set_benchmark(self, text, line, begidx, endidx):
"""Complete the set_benchmark argument"""
@@ -409,27 +367,27 @@ def do_set_benchmark(self, arg):
Use '-' for a random benchmark.
"""
if arg == "-":
- arg = random.choice(self.benchmarks)
+ arg = self.env.datasets.benchmark().uri
print(f"set_benchmark {arg}")
- if self.benchmarks.count(arg):
+ try:
+ benchmark = self.env.datasets.benchmark(arg)
self.stack.clear()
# Set the current benchmark
with Timer() as timer:
- observation = self.env.reset(benchmark=arg)
+ observation = self.env.reset(benchmark=benchmark)
print(f"Reset {self.env.benchmark} environment in {timer}")
if self.env.observation_space and observation is not None:
print(
- f"Observation: {self.env.observation_space.to_string(observation)}"
+ f"Observation: {self.env.observation_space_spec.to_string(observation)}"
)
self.set_prompt()
-
- else:
+ except LookupError:
print("Unknown benchmark, '" + arg + "'")
- print("Bencmarks are listed with command, list_benchmarks")
+ print("Benchmarks are listed with command, list_benchmarks")
def get_actions(self):
"""Get the list of actions"""
@@ -451,10 +409,6 @@ def do_action(self, arg):
Tab completion will be used if available.
Use '-' for a random action.
"""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
if self.stack and self.stack[-1].done:
print(
"No action possible, last action ended by the environment with error:",
@@ -497,7 +451,7 @@ def do_action(self, arg):
# Print the observation, if available.
if self.env.observation_space and observation is not None:
print(
- f"Observation: {self.env.observation_space.to_string(observation)}"
+ f"Observation: {self.env.observation_space_spec.to_string(observation)}"
)
# Print the reward, if available.
@@ -557,10 +511,6 @@ def do_hill_climb(self, arg):
An argument, if given, should be the number of steps to take.
The search will try to improve the default reward. Please call set_default_reward if needed.
"""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
if not self.env.reward_space:
print("No default reward set. Call set_default_reward")
return
@@ -617,10 +567,6 @@ def get_action_rewards(self):
def do_try_all_actions(self, args):
"""Tries all actions from this position and reports the results in sorted order by reward"""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
if not self.env.reward_space:
print("No default reward set. Call set_default_reward")
return
@@ -646,10 +592,6 @@ def do_greedy(self, arg):
An argument, if given, should be the number of steps to take.
The search will try to improve the default reward. Please call set_default_reward if needed.
"""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
if not self.env.reward_space:
print("No default reward set. Call set_default_reward")
return
@@ -690,12 +632,8 @@ def do_observation(self, arg):
The name should come from the list of observations printed by the command list_observations.
Tab completion will be used if available.
"""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
if arg == "" and self.env.observation_space:
- arg = self.env.observation_space.id
+ arg = self.env.observation_space_spec.id
if self.observations.count(arg):
with Timer() as timer:
@@ -718,10 +656,6 @@ def do_set_default_observation(self, arg):
With no argument it will set to None.
This command will rerun the actions on the stack.
"""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
arg = arg.strip()
if not arg or self.observations.count(arg):
with Timer() as timer:
@@ -746,10 +680,6 @@ def do_reward(self, arg):
The name should come from the list of rewards printed by the command list_rewards.
Tab completion will be used if available.
"""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
if arg == "" and self.env.reward_space:
arg = self.env.reward_space.id
@@ -772,10 +702,6 @@ def do_set_default_reward(self, arg):
With no argument it will set to None.
This command will rerun the actions on the stack.
"""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
arg = arg.strip()
if not arg or self.rewards.count(arg):
with Timer(f"Reward {arg}"):
@@ -791,10 +717,6 @@ def do_commandline(self, arg):
def do_stack(self, arg):
"""Show the environments on the stack. The current environment is the first shown."""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
rows = []
total = 0
for i, hist in enumerate(self.stack):
@@ -821,10 +743,6 @@ def do_simplify_stack(self, arg):
being removed that previously had a negative reward being necessary for a later action to
have a positive reward. This means you might see non-positive rewards on the stack afterwards.
"""
- if not self.env.benchmark:
- print("No benchmark set, please call the set_benchmark command")
- return
-
self.env.reset()
old_stack = self.stack
self.stack = []
@@ -889,15 +807,7 @@ def main(argv):
if len(argv) != 1:
raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")
- if FLAGS.ls_benchmark:
- benchmark = benchmark_from_flags()
- env = env_from_flags(benchmark)
- print("\n".join(sorted(env.benchmarks)))
- env.close()
- return
-
with Timer("Initialized environment"):
- # FIXME Chris, I don't seem to actually get a benchmark
benchmark = benchmark_from_flags()
env = env_from_flags(benchmark)
diff --git a/compiler_gym/bin/random_search.py b/compiler_gym/bin/random_search.py
index 1f9ff9cf6..c0e3f4927 100644
--- a/compiler_gym/bin/random_search.py
+++ b/compiler_gym/bin/random_search.py
@@ -17,8 +17,8 @@
.. code-block::
- $ python -m compiler_gym.bin.random_search --env=llvm-ic-v0 --benchmark=cBench-v1/dijkstra --runtime=60
- Started 16 worker threads for benchmark benchmark://cBench-v1/dijkstra (410 instructions) using reward IrInstructionCountOz.
+ $ python -m compiler_gym.bin.random_search --env=llvm-ic-v0 --benchmark=cbench-v1/dijkstra --runtime=60
+ Started 16 worker threads for benchmark benchmark://cbench-v1/dijkstra (410 instructions) using reward IrInstructionCountOz.
=== Running for a minute ===
Runtime: a minute. Num steps: 470,407 (7,780 / sec). Num episodes: 4,616 (76 / sec). Num restarts: 0.
Best reward: 101.59% (96 passes, found after 35 seconds)
@@ -58,7 +58,6 @@
from absl import app, flags
-import compiler_gym.util.flags.ls_benchmark # noqa Flag definition.
import compiler_gym.util.flags.nproc # noqa Flag definition.
import compiler_gym.util.flags.output_dir # noqa Flag definition.
from compiler_gym.random_search import random_search
@@ -93,11 +92,6 @@ def main(argv):
if len(argv) != 1:
raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")
- if FLAGS.ls_benchmark:
- env = env_from_flags()
- print("\n".join(sorted(env.benchmarks)))
- env.close()
- return
if FLAGS.ls_reward:
env = env_from_flags()
print("\n".join(sorted(env.reward.indices.keys())))
@@ -112,8 +106,6 @@ def make_env():
env = make_env()
try:
env.reset()
- if not env.benchmark:
- raise app.UsageError("No benchmark specified.")
finally:
env.close()
diff --git a/compiler_gym/bin/service.py b/compiler_gym/bin/service.py
index 259d62ecf..ec83b4866 100644
--- a/compiler_gym/bin/service.py
+++ b/compiler_gym/bin/service.py
@@ -20,87 +20,154 @@
.. code-block::
- $ python -m compiler_gym.bin.service --env=<env> [--heading_level=<level>]
+ $ python -m compiler_gym.bin.service --env=<env>
For example:
.. code-block::
$ python -m compiler_gym.bin.service --env=llvm-v0
- # CompilerGym Service `/path/to/compiler_gym/envs/llvm/service/compiler_gym-llvm-service`
- ## Programs
-
- +------------------------+
- | Benchmark |
- +========================+
- | benchmark://npb-v0/1 |
- +------------------------+
+ Datasets
+ --------
+ +----------------------------+--------------------------+------------------------------+
+ | Dataset | Num. Benchmarks [#f1]_ | Description |
+ +============================+==========================+==============================+
+ | benchmark://anghabench-v0 | 1,042,976 | Compile-only C/C++ functions |
+ +----------------------------+--------------------------+------------------------------+
+ | benchmark://blas-v0 | 300 | Basic linear algebra kernels |
+ +----------------------------+--------------------------+------------------------------+
...
- ## Action Spaces
-
-
- ### `PassesAll` (Commandline)
-
- +---------------------------------------+-----------------------------------+-------------------------------+
- | Action | Flag | Description |
- +=======================================+===================================+===============================+
- | AddDiscriminatorsPass | `-add-discriminators` | Add DWARF path discriminators |
- +---------------------------------------+-----------------------------------+-------------------------------+
+ Observation Spaces
+ ------------------
+
+ +--------------------------+----------------------------------------------+
+ | Observation space | Shape |
+ +==========================+==============================================+
+ | Autophase | `Box(0, 9223372036854775807, (56,), int64)` |
+ +--------------------------+----------------------------------------------+
+ | AutophaseDict | `Dict(ArgsPhi:int<0,inf>, BB03Phi:int<0,...` |
+ +--------------------------+----------------------------------------------+
+ | BitcodeFile | `str_list<>[0,4096.0])` |
+ +--------------------------+----------------------------------------------+
...
-The capabilities of an unmanaged service can be queried using the
-:code:`--service` flag. For example, query a service running at
-:code:`localhost:8080`:
+The output is tabular summaries of the environment's datasets, observation
+spaces, reward spaces, and action spaces, using reStructuredText syntax
+(https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html#tables).
+
+To query the capabilities of an unmanaged service, use :code:`--service`. For
+example, query a service running at :code:`localhost:8080` using:
.. code-block::
$ python -m compiler_gym.bin.service --service=localhost:8080
-Or query the capabilities of a binary that implements the RPC service interface
-using:
+To query the capability of a binary that implements the RPC service interface,
+use the :code:`--local_service_binary` flag:
.. code-block::
$ python -m compiler_gym.bin.service --local_service_binary=/path/to/service/binary
"""
+import sys
+from typing import Iterable
+
+import humanize
from absl import app, flags
+from compiler_gym.datasets import Dataset
from compiler_gym.envs import CompilerEnv
from compiler_gym.spaces import Commandline
from compiler_gym.util.flags.env_from_flags import env_from_flags
from compiler_gym.util.tabulate import tabulate
from compiler_gym.util.truncate import truncate
-flags.DEFINE_integer(
- "heading_level",
- 1,
- "The base level for generated markdown headers, in the range [1,4].",
+flags.DEFINE_string(
+ "heading_underline_char",
+ "-",
+ "The character to repeat to underline headings.",
)
FLAGS = flags.FLAGS
-def header(message: str, level: int):
- prefix = "#" * level
- return f"\n\n{prefix} {message}\n"
+def header(message: str):
+ underline = FLAGS.heading_underline_char * (
+ len(message) // len(FLAGS.heading_underline_char)
+ )
+ return f"\n\n{message}\n{underline}\n"
+
+
+def shape2str(shape, n: int = 80):
+ string = str(shape)
+ if len(string) > n:
+ return f"`{string[:n-4]}` ..."
+ return f"`{string}`"
+
+
+def summarize_datasets(datasets: Iterable[Dataset]) -> str:
+ rows = []
+ # Override the default iteration order of datasets.
+ for dataset in sorted(datasets, key=lambda d: d.name):
+ # Raw numeric values here, formatted below.
+ description = truncate(dataset.description, max_line_len=60)
+ links = ", ".join(
+ f"`{name} <{url}>`__" for name, url in sorted(dataset.references.items())
+ )
+ if links:
+ description = f"{description} [{links}]"
+ rows.append(
+ (
+ dataset.name,
+ dataset.size,
+ description,
+ dataset.validatable,
+ )
+ )
+ rows.append(("Total", sum(r[1] for r in rows), "", ""))
+ return (
+ tabulate(
+ [
+ (
+ n,
+ humanize.intcomma(f) if f >= 0 else "∞",
+ l,
+ v,
+ )
+ for n, f, l, v in rows
+ ],
+ headers=(
+ "Dataset",
+ "Num. Benchmarks [#f1]_",
+ "Description",
+ "Validatable [#f2]_",
+ ),
+ )
+ + f"""
+
+.. [#f1] Values obtained on {sys.platform}. Datasets are platform-specific.
+.. [#f2] A **validatable** dataset is one where the behavior of the benchmarks
+ can be checked by compiling the programs to binaries and executing
+ them. If the benchmarks crash, or are found to have different behavior,
+ then validation fails. This type of validation is used to check that
+ the compiler has not broken the semantics of the program.
+ See :mod:`compiler_gym.bin.validate`.
+"""
+ )
-def print_service_capabilities(env: CompilerEnv, base_heading_level: int = 1):
+def print_service_capabilities(env: CompilerEnv):
"""Discover and print the capabilities of a CompilerGym service.
:param env: An environment.
"""
- print(header(f"CompilerGym Service `{env.service}`", base_heading_level).strip())
- print(header("Programs", base_heading_level + 1))
+ print(header("Datasets"))
print(
- tabulate(
- [(p,) for p in sorted(env.benchmarks)],
- headers=("Benchmark",),
- )
+ summarize_datasets(env.datasets),
)
- print(header("Observation Spaces", base_heading_level + 1))
+ print(header("Observation Spaces"))
print(
tabulate(
sorted(
@@ -112,7 +179,7 @@ def print_service_capabilities(env: CompilerEnv, base_heading_level: int = 1):
headers=("Observation space", "Shape"),
)
)
- print(header("Reward Spaces", base_heading_level + 1))
+ print(header("Reward Spaces"))
print(
tabulate(
[
@@ -135,14 +202,8 @@ def print_service_capabilities(env: CompilerEnv, base_heading_level: int = 1):
)
)
- print(header("Action Spaces", base_heading_level + 1).rstrip())
for action_space in env.action_spaces:
- print(
- header(
- f"`{action_space.name}` ({type(action_space).__name__})",
- base_heading_level + 2,
- )
- )
+ print(header(f"{action_space.name} Action Space"))
# Special handling for commandline action spaces to print additional
# information.
if isinstance(action_space, Commandline):
@@ -167,11 +228,10 @@ def print_service_capabilities(env: CompilerEnv, base_heading_level: int = 1):
def main(argv):
"""Main entry point."""
assert len(argv) == 1, f"Unrecognized flags: {argv[1:]}"
- assert 0 < FLAGS.heading_level <= 4, "--heading_level must be in range [1,4]"
env = env_from_flags()
try:
- print_service_capabilities(env, base_heading_level=FLAGS.heading_level)
+ print_service_capabilities(env)
finally:
env.close()
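
For reference, a minimal sketch of the reStructuredText-style heading that the new header() helper prints, assuming the default "-" value of --heading_underline_char. The standalone function below mirrors the logic in the hunk above rather than importing it:

    def header(message: str, underline_char: str = "-") -> str:
        # Repeat the underline character to match the heading length. For a
        # single-character underline this is equivalent to the integer-division
        # formula used in the hunk above.
        return f"\n\n{message}\n{underline_char * len(message)}\n"

    print(header("Datasets"))
    # Prints:
    #
    # Datasets
    # --------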
diff --git a/compiler_gym/bin/validate.py b/compiler_gym/bin/validate.py
index 6fc19b0e6..ba498b349 100644
--- a/compiler_gym/bin/validate.py
+++ b/compiler_gym/bin/validate.py
@@ -10,7 +10,7 @@
$ cat << EOF |
benchmark,reward,walltime,commandline
- cBench-v1/crc32,0,1.2,opt input.bc -o output.bc
+ cbench-v1/crc32,0,1.2,opt input.bc -o output.bc
EOF
python -m compiler_gym.bin.validate --env=llvm-ic-v0 -
@@ -25,21 +25,15 @@
------------
The correct format for generating input states can be generated using
-:func:`env.state.to_csv() `. The
-input CSV must start with a header row. A valid header row can be generated
-using
-:func:`env.state.csv_header() `.
-
-Full example:
-
->>> env = gym.make("llvm-v0")
->>> env.reset()
->>> env.step(0)
->>> print(env.state.csv_header())
-benchmark,reward,walltime,commandline
->>> print(env.state.to_csv())
-benchmark://cBench-v1/rijndael,,20.53565216064453,opt -add-discriminators input.bc -o output.bc
-%
+:class:`CompilerEnvStateWriter `. For
+example:
+
+ >>> env = gym.make("llvm-autophase-ic-v0")
+ >>> env.reset()
+ >>> env.step(env.action_space.sample())
+ >>> with CompilerEnvStateWriter(open("results.csv", "w")) as writer:
+ ... writer.write_state(env.state)
+
Output Format
-------------
@@ -60,17 +54,15 @@
import json
import re
import sys
-from typing import Iterable
import numpy as np
from absl import app, flags
-import compiler_gym.util.flags.dataset # noqa Flag definition.
import compiler_gym.util.flags.nproc # noqa Flag definition.
-from compiler_gym.envs.compiler_env import CompilerEnvState
+from compiler_gym.compiler_env_state import CompilerEnvState, CompilerEnvStateReader
from compiler_gym.util.flags.env_from_flags import env_from_flags
from compiler_gym.util.shell_format import emph, plural
-from compiler_gym.util.statistics import geometric_mean
+from compiler_gym.util.statistics import arithmetic_mean, geometric_mean, stdev
from compiler_gym.validate import ValidationResult, validate_states
flags.DEFINE_boolean(
@@ -122,31 +114,10 @@ def to_string(result: ValidationResult, name_col_width: int) -> str:
return f"✅ {name:<{name_col_width}} {result.state.reward:9.4f}"
-def arithmetic_mean(values):
- """Zero-length-safe arithmetic mean."""
- if not values:
- return 0
- return sum(values) / len(values)
-
-
-def stdev(values):
- """Zero-length-safe standard deviation."""
- return np.std(values or [0])
-
-
-def read_states_from_paths(paths: Iterable[str]) -> Iterable[CompilerEnvState]:
- for path in paths:
- if path == "-":
- yield from CompilerEnvState.read_csv_file(sys.stdin)
- else:
- with open(path) as f:
- yield from CompilerEnvState.read_csv_file(f)
-
-
def main(argv):
"""Main entry point."""
try:
- states = list(read_states_from_paths(argv[1:]))
+ states = list(CompilerEnvStateReader.read_paths(argv[1:]))
except ValueError as e:
print(e, file=sys.stderr)
sys.exit(1)
@@ -177,7 +148,6 @@ def main(argv):
validation_results = validate_states(
env_from_flags,
states,
- datasets=FLAGS.dataset,
nproc=FLAGS.nproc,
inorder=FLAGS.inorder,
)
@@ -238,16 +208,16 @@ def progress_message(i):
)
progress_message(len(states))
- json_log = []
+ result_dicts = []
- def dump_json_log():
+ def dump_result_dicts_to_json():
with open(FLAGS.validation_logfile, "w") as f:
- json.dump(json_log, f)
+ json.dump(result_dicts, f)
for i, result in enumerate(validation_results, start=1):
intermediate_print("\r\033[K", to_string(result, name_col_width), sep="")
progress_message(len(states) - i)
- json_log.append(result.json())
+ result_dicts.append(result.dict())
if not result.okay():
error_count += 1
@@ -256,9 +226,9 @@ def dump_json_log():
walltimes.append(result.state.walltime)
if not i % 10:
- dump_json_log()
+ dump_result_dicts_to_json()
- dump_json_log()
+ dump_result_dicts_to_json()
# Print a summary footer.
intermediate_print("\r\033[K----", "-" * name_col_width, "-----------", sep="")
diff --git a/compiler_gym/compiler_env_state.py b/compiler_gym/compiler_env_state.py
index 43bd6dfc3..90c7ea80e 100644
--- a/compiler_gym/compiler_env_state.py
+++ b/compiler_gym/compiler_env_state.py
@@ -4,18 +4,16 @@
# LICENSE file in the root directory of this source tree.
"""This module defines a class to represent a compiler environment state."""
import csv
-from io import StringIO
-from typing import Any, Dict, Iterable, NamedTuple, Optional
+import sys
+from typing import Iterable, List, Optional, TextIO
+from pydantic import BaseModel, Field, validator
-def _to_csv(*columns) -> str:
- buf = StringIO()
- writer = csv.writer(buf)
- writer.writerow(columns)
- return buf.getvalue().rstrip()
+from compiler_gym.datasets.uri import BENCHMARK_URI_PATTERN
+from compiler_gym.util.truncate import truncate
-class CompilerEnvState(NamedTuple):
+class CompilerEnvState(BaseModel):
"""The representation of a compiler environment state.
The state of an environment is defined as a benchmark and a sequence of
@@ -23,88 +21,40 @@ class CompilerEnvState(NamedTuple):
contains the information required to reproduce the result.
"""
- benchmark: str
- """The name of the benchmark used for this episode."""
+ benchmark: str = Field(
+ allow_mutation=False,
+ regex=BENCHMARK_URI_PATTERN,
+ examples=[
+ "benchmark://cbench-v1/crc32",
+ "generator://csmith-v0/0",
+ ],
+ )
+ """The URI of the benchmark used for this episode."""
commandline: str
"""The list of actions that produced this state, as a commandline."""
walltime: float
- """The walltime of the episode."""
+ """The walltime of the episode in seconds. Must be non-negative."""
- reward: Optional[float] = None
- """The cumulative reward for this episode."""
+ reward: Optional[float] = Field(
+ required=False,
+ default=None,
+ allow_mutation=True,
+ )
+ """The cumulative reward for this episode. Optional."""
+
+ @validator("walltime")
+ def walltime_nonnegative(cls, v):
+ if v is not None:
+ assert v >= 0, "Walltime cannot be negative"
+ return v
@property
def has_reward(self) -> bool:
"""Return whether the state has a reward value."""
return self.reward is not None
- @staticmethod
- def csv_header() -> str:
- """Return the header string for the CSV-format.
-
- :return: A comma-separated string.
- """
- return _to_csv("benchmark", "reward", "walltime", "commandline")
-
- def json(self):
- """Return the state as JSON."""
- return self._asdict() # pylint: disable=no-member
-
- def to_csv(self) -> str:
- """Serialize a state to a comma separated list of values.
-
- :return: A comma-separated string.
- """
- return _to_csv(self.benchmark, self.reward, self.walltime, self.commandline)
-
- @classmethod
- def from_json(cls, data: Dict[str, Any]) -> "CompilerEnvState":
- """Construct a state from a JSON dictionary."""
- return cls(**data)
-
- @classmethod
- def from_csv(cls, csv_string: str) -> "CompilerEnvState":
- """Construct a state from a comma separated list of values."""
- reader = csv.reader(StringIO(csv_string))
- for line in reader:
- try:
- benchmark, reward, walltime, commandline = line
- break
- except ValueError as e:
- raise ValueError(f"Failed to parse input: `{csv_string}`: {e}") from e
- else:
- raise ValueError(f"Failed to parse input: `{csv_string}`")
- return cls(
- benchmark=benchmark,
- reward=None if reward == "" else float(reward),
- walltime=0 if walltime == "" else float(walltime),
- commandline=commandline,
- )
-
- @classmethod
- def read_csv_file(cls, in_file) -> Iterable["CompilerEnvState"]:
- """Read states from a CSV file.
-
- :param in_file: A file object.
- :returns: A generator of :class:`CompilerEnvState` instances.
- :raises ValueError: If input parsing fails.
- """
- # TODO(cummins): Check schema of DictReader and, on failure, fallback
- # to from_csv() per-line.
- # TODO(cummins): Accept a URL for in_file and read from web.
- data = in_file.readlines()
- for line in csv.DictReader(data):
- try:
- line["reward"] = float(line["reward"]) if line.get("reward") else None
- line["walltime"] = (
- float(line["walltime"]) if line.get("walltime") else None
- )
- yield CompilerEnvState(**line)
- except (TypeError, KeyError) as e:
- raise ValueError(f"Failed to parse input: `{e}`") from e
-
def __eq__(self, rhs) -> bool:
if not isinstance(rhs, CompilerEnvState):
return False
@@ -125,3 +75,155 @@ def __eq__(self, rhs) -> bool:
def __ne__(self, rhs) -> bool:
return not self == rhs
+
+ class Config:
+ validate_assignment = True
+
+
+class CompilerEnvStateWriter:
+ """Serialize compiler environment states to CSV.
+
+ Example use:
+
+ >>> with CompilerEnvStateWriter(open("results.csv", "w")) as writer:
+ ... writer.write_state(env.state)
+ """
+
+ def __init__(self, f: TextIO, header: bool = True):
+ """Constructor.
+
+ :param f: The file to write to.
+ :param header: Whether to include a header row.
+ """
+ self.f = f
+ self.writer = csv.writer(self.f, lineterminator="\n")
+ self.header = header
+
+ def write_state(self, state: CompilerEnvState, flush: bool = False) -> None:
+ """Write the state to file.
+
+ :param state: A compiler environment state.
+
+ :param flush: Write to file immediately.
+ """
+ if self.header:
+ self.writer.writerow(("benchmark", "reward", "walltime", "commandline"))
+ self.header = False
+ self.writer.writerow(
+ (state.benchmark, state.reward, state.walltime, state.commandline)
+ )
+ if flush:
+ self.f.flush()
+
+ def __enter__(self):
+ """Support with-statement for the writer."""
+ return self
+
+ def __exit__(self, *args):
+ """Support with-statement for the writer."""
+ self.f.close()
+
+
+class CompilerEnvStateReader:
+ """Read states from a CSV file.
+
+ Example usage:
+
+ >>> with CompilerEnvStateReader(open("results.csv")) as reader:
+ ... for state in reader:
+ ... print(state)
+ """
+
+ def __init__(self, f: TextIO):
+ """Constructor.
+
+ :param f: The file to read.
+ """
+ self.f = f
+ self.reader = csv.reader(self.f)
+
+ def __iter__(self) -> Iterable[CompilerEnvState]:
+ """Read the states from the file."""
+ columns_in_order = ["benchmark", "reward", "walltime", "commandline"]
+ # Read the CSV and coerce the columns into the expected order.
+ for (
+ benchmark,
+ reward,
+ walltime,
+ commandline,
+ ) in self._iterate_columns_in_order(self.reader, columns_in_order):
+ yield CompilerEnvState(
+ benchmark=benchmark,
+ reward=None if reward == "" else float(reward),
+ walltime=0 if walltime == "" else float(walltime),
+ commandline=commandline,
+ )
+
+ @staticmethod
+ def _iterate_columns_in_order(
+ reader: csv.reader, columns: List[str]
+ ) -> Iterable[List[str]]:
+ """Read the input CSV and return each row in the given column order.
+
+ Supports CSVs both with and without a header. If no header, columns are
+ expected to be in the correct order. Else the header row is used to
+ determine column order.
+
+ Header row detection is case insensitive.
+
+ :param reader: The CSV file to read.
+
+ :param columns: A list of column names in the order that they are
+ expected.
+
+ :return: An iterator over rows.
+ """
+ try:
+ row = next(reader)
+ except StopIteration:
+ # Empty file.
+ return
+
+ if len(row) != len(columns):
+ raise ValueError(
+ f"Expected {len(columns)} columns in the first row of CSV: {truncate(row)}"
+ )
+
+ # Convert the maybe-header columns to lowercase for case-insensitive
+ # comparison.
+ maybe_header = [v.lower() for v in row]
+ if set(maybe_header) == set(columns):
+ # The first row matches the expected column names, so use it to
+ # determine the column order.
+ column_order = [maybe_header.index(v) for v in columns]
+ yield from ([row[v] for v in column_order] for row in reader)
+ else:
+ # The first row isn't a header, so assume that all rows are in
+ # expected column order.
+ yield row
+ yield from reader
+
+ def __enter__(self):
+ """Support with-statement for the reader."""
+ return self
+
+ def __exit__(self, *args):
+ """Support with-statement for the reader."""
+ self.f.close()
+
+ @staticmethod
+ def read_paths(paths: Iterable[str]) -> Iterable[CompilerEnvState]:
+ """Read a states from a list of file paths.
+
+ Read states from stdin using a special path :code:`"-"`.
+
+ :param: A list of paths.
+
+ :return: A generator of compiler env states.
+ """
+ for path in paths:
+ if path == "-":
+ yield from iter(CompilerEnvStateReader(sys.stdin))
+ else:
+ with open(path) as f:
+ yield from iter(CompilerEnvStateReader(f))
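
A minimal round-trip sketch of the new state serialization classes, assuming the compiler_gym.compiler_env_state module path used by the imports earlier in this diff; the benchmark URI, commandline, walltime, and reward values are illustrative:

    from compiler_gym.compiler_env_state import (
        CompilerEnvState,
        CompilerEnvStateReader,
        CompilerEnvStateWriter,
    )

    state = CompilerEnvState(
        benchmark="benchmark://cbench-v1/crc32",
        commandline="opt input.bc -o output.bc",
        walltime=1.2,
        reward=0.5,
    )

    # Write the state as CSV. The file is opened in text mode because the
    # writer is backed by csv.writer.
    with CompilerEnvStateWriter(open("results.csv", "w")) as writer:
        writer.write_state(state, flush=True)

    # Read the states back. The header row is detected automatically, and
    # "-" can be used in place of a path to read from stdin.
    for state in CompilerEnvStateReader.read_paths(["results.csv"]):
        print(state.benchmark, state.reward, state.walltime)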
diff --git a/compiler_gym/datasets/BUILD b/compiler_gym/datasets/BUILD
index 6bd60d690..994045993 100644
--- a/compiler_gym/datasets/BUILD
+++ b/compiler_gym/datasets/BUILD
@@ -6,18 +6,25 @@ load("@rules_python//python:defs.bzl", "py_library")
py_library(
name = "datasets",
- srcs = ["__init__.py"],
+ srcs = [
+ "__init__.py",
+ "benchmark.py",
+ "dataset.py",
+ "datasets.py",
+ "files_dataset.py",
+ "tar_dataset.py",
+ ],
visibility = ["//visibility:public"],
deps = [
- ":dataset",
+ ":uri",
+ "//compiler_gym:validation_result",
+ "//compiler_gym/service/proto",
+ "//compiler_gym/util",
],
)
py_library(
- name = "dataset",
- srcs = ["dataset.py"],
+ name = "uri",
+ srcs = ["uri.py"],
visibility = ["//compiler_gym:__subpackages__"],
- deps = [
- "//compiler_gym/util",
- ],
)
diff --git a/compiler_gym/datasets/__init__.py b/compiler_gym/datasets/__init__.py
index b0dc9440c..df9b009a4 100644
--- a/compiler_gym/datasets/__init__.py
+++ b/compiler_gym/datasets/__init__.py
@@ -3,12 +3,35 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Manage datasets of benchmarks."""
+from compiler_gym.datasets.benchmark import (
+ Benchmark,
+ BenchmarkInitError,
+ BenchmarkSource,
+)
from compiler_gym.datasets.dataset import (
- LegacyDataset,
+ Dataset,
+ DatasetInitError,
activate,
deactivate,
delete,
require,
)
+from compiler_gym.datasets.datasets import Datasets
+from compiler_gym.datasets.files_dataset import FilesDataset
+from compiler_gym.datasets.tar_dataset import TarDataset, TarDatasetWithManifest
-__all__ = ["LegacyDataset", "require", "activate", "deactivate", "delete"]
+__all__ = [
+ "activate",
+ "Benchmark",
+ "BenchmarkInitError",
+ "BenchmarkSource",
+ "Dataset",
+ "DatasetInitError",
+ "Datasets",
+ "deactivate",
+ "delete",
+ "FilesDataset",
+ "require",
+ "TarDataset",
+ "TarDatasetWithManifest",
+]
diff --git a/compiler_gym/datasets/benchmark.py b/compiler_gym/datasets/benchmark.py
new file mode 100644
index 000000000..7a50c21c2
--- /dev/null
+++ b/compiler_gym/datasets/benchmark.py
@@ -0,0 +1,339 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from concurrent.futures import as_completed
+from pathlib import Path
+from typing import Callable, Iterable, List, NamedTuple, Optional, Union
+
+from compiler_gym.service.proto import Benchmark as BenchmarkProto
+from compiler_gym.service.proto import File
+from compiler_gym.util import thread_pool
+from compiler_gym.util.decorators import memoized_property
+from compiler_gym.validation_error import ValidationError
+
+# A validation callback is a function that takes a single CompilerEnv instance
+# as its argument and returns an iterable sequence of zero or more
+# ValidationError tuples.
+ValidationCallback = Callable[["CompilerEnv"], Iterable[ValidationError]] # noqa: F821
+
+
+class BenchmarkSource(NamedTuple):
+ """A source file that is used to generate a benchmark. A benchmark may
+ comprise many source files.
+
+ .. warning::
+
+ The :class:`BenchmarkSource `
+ class is new and is likely to change in the future.
+ """
+
+ filename: str
+ """The name of the file."""
+
+ contents: bytes
+ """The contents of the file as a byte array."""
+
+ def __repr__(self) -> str:
+ return str(self.filename)
+
+
+class Benchmark(object):
+ """A benchmark represents a particular program that is being compiled.
+
+ A benchmark is a program that can be used by a :class:`CompilerEnv
+ ` as a program to optimize. A benchmark
+ comprises the data that is fed into the compiler, identified by a URI.
+
+ Benchmarks are not normally instantiated directly. Instead, benchmarks are
+ instantiated using :meth:`env.datasets.benchmark(uri)
+ `:
+
+ >>> env.datasets.benchmark("benchmark://npb-v0/20")
+ benchmark://npb-v0/20
+
+ The available benchmark URIs can be queried using
+ :meth:`env.datasets.benchmark_uris()
+ `.
+
+ >>> next(env.datasets.benchmark_uris())
+ 'benchmark://cbench-v1/adpcm'
+
+ Compiler environments may provide additional helper functions for generating
+ benchmarks, such as :meth:`env.make_benchmark()
+ ` for LLVM.
+
+ A Benchmark instance wraps an instance of the :code:`Benchmark` protocol
+ buffer from the `RPC interface
+ `_
+ with additional functionality. The data underlying benchmarks should be
+ considered immutable. New attributes cannot be assigned to Benchmark
+ instances.
+
+ The benchmark for an environment can be set during :meth:`env.reset()
+ `. The currently active benchmark can
+ be queried using :attr:`env.benchmark
+ `:
+
+ >>> env = gym.make("llvm-v0")
+ >>> env.reset(benchmark="benchmark://cbench-v1/crc32")
+ >>> env.benchmark
+ benchmark://cbench-v1/crc32
+
+
+ """
+
+ __slots__ = ["_proto", "_validation_callbacks", "_sources"]
+
+ def __init__(
+ self,
+ proto: BenchmarkProto,
+ validation_callbacks: Optional[List[ValidationCallback]] = None,
+ sources: Optional[List[BenchmarkSource]] = None,
+ ):
+ self._proto = proto
+ self._validation_callbacks = validation_callbacks or []
+ self._sources = list(sources or [])
+
+ def __repr__(self) -> str:
+ return str(self.uri)
+
+ @property
+ def uri(self) -> str:
+ """The URI of the benchmark.
+
+ Benchmark URIs should be unique, that is, two URIs with the same
+ value should resolve to the same benchmark. However, URIs do not
+ uniquely describe a benchmark: multiple identical benchmarks could
+ have different URIs.
+
+ :return: A URI string. :type: string
+ """
+ return self._proto.uri
+
+ @property
+ def proto(self) -> BenchmarkProto:
+ """The protocol buffer representing the benchmark.
+
+ :return: A Benchmark message.
+ :type: :code:`Benchmark`
+ """
+ return self._proto
+
+ @property
+ def sources(self) -> Iterable[BenchmarkSource]:
+ """The original source code used to produce this benchmark, as a list of
+ :class:`BenchmarkSource `
+ instances.
+
+ :return: A sequence of source files.
+
+ :type: :code:`Iterable[BenchmarkSource]`
+
+ .. warning::
+
+ The :meth:`Benchmark.sources
+ ` property is new and is
+ likely to change in the future.
+ """
+ return (BenchmarkSource(*x) for x in self._sources)
+
+ def is_validatable(self) -> bool:
+ """Whether the benchmark has any validation callbacks registered.
+
+ :return: :code:`True` if the benchmark has at least one validation
+ callback.
+ """
+ return self._validation_callbacks != []
+
+ def validate(self, env: "CompilerEnv") -> List[ValidationError]: # noqa: F821
+ """Run the validation callbacks and return any errors.
+
+ If no errors are returned, validation has succeeded:
+
+ >>> benchmark.validate(env)
+ []
+
+ If an error occurs, a :class:`ValidationError
+ ` tuple will describe the type of the
+ error, and optionally contain other data:
+
+ >>> benchmark.validate(env)
+ [ValidationError(type="RuntimeError")]
+
+ Multiple :class:`ValidationError ` errors
+ may be returned to indicate multiple errors.
+
+ This is a synchronous version of :meth:`ivalidate()
+ ` that blocks until all
+ results are ready:
+
+ >>> benchmark.validate(env) == list(benchmark.ivalidate(env))
+ True
+
+ :param env: The :class:`CompilerEnv `
+ instance that is being validated.
+
+ :return: A list of zero or more :class:`ValidationError
+ ` tuples that occurred during
+ validation.
+ """
+ return list(self.ivalidate(env))
+
+ def ivalidate(self, env: "CompilerEnv") -> Iterable[ValidationError]: # noqa: F821
+ """Run the validation callbacks and return a generator of errors.
+
+ This is an asynchronous version of :meth:`validate()
+ ` that returns immediately.
+
+ :param env: A :class:`CompilerEnv `
+ instance to validate.
+
+ :return: A generator of :class:`ValidationError
+ ` tuples that occur during validation.
+ """
+ executor = thread_pool.get_thread_pool_executor()
+ futures = (
+ executor.submit(validator, env) for validator in self.validation_callbacks()
+ )
+ for future in as_completed(futures):
+ result: Iterable[ValidationError] = future.result()
+ if result:
+ yield from result
+
+ def validation_callbacks(
+ self,
+ ) -> List[ValidationCallback]:
+ """Return the list of registered validation callbacks.
+
+ :return: A list of callables. See :meth:`add_validation_callback()
+ `.
+ """
+ return self._validation_callbacks
+
+ def add_source(self, source: BenchmarkSource) -> None:
+ """Register a new source file for this benchmark.
+
+ :param source: The :class:`BenchmarkSource
+ ` to register.
+ """
+ self._sources.append(source)
+
+ def add_validation_callback(
+ self,
+ validation_callback: ValidationCallback,
+ ) -> None:
+ """Register a new validation callback that will be executed on
+ :meth:`validate() `.
+
+ :param validation_callback: A callback that accepts a single
+ :class:`CompilerEnv ` argument and
+ returns an iterable sequence of zero or more :class:`ValidationError
+ ` tuples. Validation callbacks must be
+ thread safe and must not modify the environment.
+ """
+ self._validation_callbacks.append(validation_callback)
+
+ def write_sources_to_directory(self, directory: Path) -> int:
+ """Write the source files for this benchmark to the given directory.
+
+ This writes each of the :attr:`benchmark.sources
+ ` files to disk.
+
+ If the benchmark has no sources, no files are written.
+
+ :param directory: The directory to write results to. If it does not
+ exist, it is created.
+
+ :return: The number of files written.
+ """
+ directory = Path(directory)
+ directory.mkdir(exist_ok=True, parents=True)
+ uniq_paths = set()
+ for filename, contents in self.sources:
+ path = directory / filename
+ uniq_paths.add(path)
+ path.parent.mkdir(exist_ok=True, parents=True)
+ with open(path, "wb") as f:
+ f.write(contents)
+
+ return len(uniq_paths)
+
+ @classmethod
+ def from_file(cls, uri: str, path: Path):
+ """Construct a benchmark from a file.
+
+ :param uri: The URI of the benchmark.
+
+ :param path: A filesystem path.
+
+ :raise FileNotFoundError: If the path does not exist.
+
+ :return: A :class:`Benchmark `
+ instance.
+ """
+ path = Path(path)
+ if not path.is_file():
+ raise FileNotFoundError(path)
+ return cls(
+ proto=BenchmarkProto(
+ uri=uri, program=File(uri=f"file:///{path.absolute()}")
+ ),
+ )
+
+ @classmethod
+ def from_file_contents(cls, uri: str, data: bytes):
+ """Construct a benchmark from raw data.
+
+ :param uri: The URI of the benchmark.
+
+ :param data: An array of bytes that will be passed to the compiler
+ service.
+ """
+ return cls(proto=BenchmarkProto(uri=uri, program=File(contents=data)))
+
+ def __eq__(self, other: Union[str, "Benchmark"]):
+ if isinstance(other, Benchmark):
+ return self.uri == other.uri
+ else:
+ return self.uri == other
+
+ def __ne__(self, other: Union[str, "Benchmark"]):
+ return not self == other
+
+
+class BenchmarkInitError(OSError):
+ """Base class for errors raised if a benchmark fails to initialize."""
+
+
+class BenchmarkWithSource(Benchmark):
+ """A benchmark which has a single source file."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._src_name = None
+ self._src_path = None
+
+ @classmethod
+ def create(
+ cls, uri: str, input_path: Path, src_name: str, src_path: Path
+ ) -> Benchmark:
+ """Create a benchmark from paths."""
+ benchmark = cls.from_file(uri, input_path)
+ benchmark._src_name = src_name # pylint: disable=protected-access
+ benchmark._src_path = src_path # pylint: disable=protected-access
+ return benchmark
+
+ @memoized_property
+ def sources( # pylint: disable=invalid-overridden-method
+ self,
+ ) -> Iterable[BenchmarkSource]:
+ with open(self._src_path, "rb") as f:
+ return [
+ BenchmarkSource(filename=self._src_name, contents=f.read()),
+ ]
+
+ @property
+ def source(self) -> str:
+ """Return the single source file contents as a string."""
+ return list(self.sources)[0].contents.decode("utf-8")
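
A short sketch of building a benchmark with the new class, using only the constructors and methods defined above; the URI, source filename, and the no-op validation callback are hypothetical:

    from compiler_gym.datasets import Benchmark, BenchmarkSource

    src = b"int main() { return 0; }"  # Illustrative program contents.

    benchmark = Benchmark.from_file_contents("benchmark://example-v0/hello", src)
    benchmark.add_source(BenchmarkSource(filename="hello.c", contents=src))

    def no_op_validator(env):
        # A real callback would execute the compiled program and return
        # ValidationError tuples describing any failures.
        return []

    benchmark.add_validation_callback(no_op_validator)
    assert benchmark.is_validatable()
    print(list(benchmark.sources))  # -> [hello.c]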
diff --git a/compiler_gym/datasets/dataset.py b/compiler_gym/datasets/dataset.py
index 72608d63e..2045a42b5 100644
--- a/compiler_gym/datasets/dataset.py
+++ b/compiler_gym/datasets/dataset.py
@@ -2,280 +2,469 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import io
-import json
+import logging
import os
import shutil
-import tarfile
import warnings
from pathlib import Path
-from typing import List, NamedTuple, Optional, Union
+from typing import Dict, Iterable, Optional, Union
-import fasteners
-from deprecated.sphinx import deprecated
+from deprecated.sphinx import deprecated as mark_deprecated
-from compiler_gym.util.download import download
+from compiler_gym.datasets.benchmark import Benchmark
+from compiler_gym.datasets.uri import DATASET_NAME_RE
+from compiler_gym.util.debug_util import get_logging_level
-class LegacyDataset(NamedTuple):
- """A collection of benchmarks for use by an environment.
+class Dataset(object):
+ """A dataset is a collection of benchmarks.
- .. deprecated:: 0.1.4
- The next release of CompilerGym will introduce a new API for describing
- datasets with extended functionality. See
- `here `_ for
- more information.
+ The Dataset class has methods for installing and managing groups of
+ benchmarks, for listing the available benchmark URIs, and for instantiating
+ :class:`Benchmark ` objects.
+
+ The Dataset class is an abstract base for implementing datasets. At a
+ minimum, subclasses must implement the :meth:`benchmark()
+ ` and :meth:`benchmark_uris()
+ ` methods, and :meth:`size
+ `. Other methods such as
+ :meth:`install() ` may be overridden where
+ helpful.
"""
- name: str
- """The name of the dataset."""
+ def __init__(
+ self,
+ name: str,
+ description: str,
+ license: str, # pylint: disable=redefined-builtin
+ site_data_base: Path,
+ benchmark_class=Benchmark,
+ references: Optional[Dict[str, str]] = None,
+ deprecated: Optional[str] = None,
+ sort_order: int = 0,
+ logger: Optional[logging.Logger] = None,
+ validatable: str = "No",
+ ):
+ """Constructor.
+
+ :param name: The name of the dataset. Must conform to the pattern
+ :code:`{{protocol}}://{{name}}-v{{version}}`.
+
+ :param description: A short human-readable description of the dataset.
+
+ :param license: The name of the dataset's license.
+
+ :param site_data_base: The base path of a directory that will be used to
+ store installed files.
+
+ :param benchmark_class: The class to use when instantiating benchmarks.
+ It must have the same constructor signature as :class:`Benchmark
+ `.
+
+ :param references: A dictionary of useful named URLs for this dataset
+ containing extra information, download links, papers, etc.
+
+ :param deprecated: Mark the dataset as deprecated and issue a warning
+ when :meth:`install() `,
+ including the given message. Deprecated datasets are excluded from
+ the :meth:`datasets() `
+ iterator by default.
+
+ :param sort_order: An optional numeric value that should be used to
+ order this dataset relative to others. Lowest value sorts first.
+
+ :param validatable: Whether the dataset is validatable. A validatable
+ dataset is one where the behavior of the benchmarks can be checked
+ by compiling the programs to binaries and executing them. If the
+ benchmarks crash, or are found to have different behavior, then
+ validation fails. This type of validation is used to check that the
+ compiler has not broken the semantics of the program. This value
+ takes a string and is used for documentation purposes only.
+ Suggested values are "Yes", "No", or "Partial".
+
+ :raises ValueError: If :code:`name` does not match the expected type.
+ """
+ self._name = name
+ components = DATASET_NAME_RE.match(name)
+ if not components:
+ raise ValueError(
+ f"Invalid dataset name: '{name}'. "
+ "Dataset name must be in the form: '{{protocol}}://{{name}}-v{{version}}'"
+ )
+ self._description = description
+ self._license = license
+ self._protocol = components.group("dataset_protocol")
+ self._version = int(components.group("dataset_version"))
+ self._references = references or {}
+ self._deprecation_message = deprecated
+ self._validatable = validatable
- license: str
- """The license of the dataset."""
+ self._logger = logger
+ self.sort_order = sort_order
+ self.benchmark_class = benchmark_class
- file_count: int
- """The number of files in the unpacked dataset."""
+ # Set up the site data name.
+ basename = components.group("dataset_name")
+ self._site_data_path = Path(site_data_base).resolve() / self.protocol / basename
- size_bytes: int
- """The size of the unpacked dataset in bytes."""
+ def __repr__(self):
+ return self.name
- url: str = ""
- """A URL where the dataset can be downloaded from. May be an empty string."""
+ @property
+ def logger(self) -> logging.Logger:
+ """The logger for this dataset.
- sha256: str = ""
- """The sha256 checksum of the dataset archive. If provided, this is used to
- verify the contents of the dataset upon download.
- """
+ :type: logging.Logger
+ """
+ # NOTE(cummins): Default logger instantiation is deferred since in
+ # Python 3.6 the Logger instance contains an un-pickle-able Rlock()
+ # which can prevent gym.make() from working. This is a workaround that
+ # can be removed once Python 3.6 support is dropped.
+ if self._logger is None:
+ self._logger = logging.getLogger("compiler_gym.datasets")
+ self._logger.setLevel(get_logging_level())
+ return self._logger
+
+ @property
+ def name(self) -> str:
+ """The name of the dataset.
+
+ :type: str
+ """
+ return self._name
+
+ @property
+ def description(self) -> str:
+ """A short human-readable description of the dataset.
+
+ :type: str
+ """
+ return self._description
- compiler: str = ""
- """The name of the compiler that this dataset supports."""
+ @property
+ def license(self) -> str:
+ """The name of the license of the dataset.
+
+ :type: str
+ """
+ return self._license
+
+ @property
+ def protocol(self) -> str:
+ """The URI protocol that is used to identify benchmarks in this dataset.
+
+ :type: str
+ """
+ return self._protocol
- description: str = ""
- """An optional human-readable description of the dataset."""
+ @property
+ def version(self) -> int:
+ """The version tag for this dataset.
+
+ :type: int
+ """
+ return self._version
- platforms: List[str] = ["macos", "linux"]
- """A list of platforms supported by this dataset. Allowed platforms 'macos' and 'linux'."""
+ @property
+ def references(self) -> Dict[str, str]:
+ """A dictionary of useful named URLs for this dataset containing extra
+ information, download links, papers, etc.
- deprecated_since: str = ""
- """The CompilerGym release in which this dataset was deprecated."""
+ For example:
+
+ >>> dataset.references
+ {'Paper': 'https://arxiv.org/pdf/1407.3487.pdf',
+ 'Homepage': 'https://ctuning.org/wiki/index.php/CTools:CBench'}
+
+ :type: Dict[str, str]
+ """
+ return self._references
@property
def deprecated(self) -> bool:
- """Whether the dataset is deprecated."""
- return bool(self.deprecated_since)
+ """Whether the dataset is included in the iterable sequence of datasets
+ of a containing :class:`Datasets `
+ collection.
- @classmethod
- def from_json_file(cls, path: Path) -> "LegacyDataset":
- """Construct a dataset form a JSON metadata file.
+ :type: bool
+ """
+ return self._deprecation_message is not None
+
+ @property
+ def validatable(self) -> str:
+ """Whether the dataset is validatable. A validatable dataset is one
+ where the behavior of the benchmarks can be checked by compiling the
+ programs to binaries and executing them. If the benchmarks crash, or are
+ found to have different behavior, then validation fails. This type of
+ validation is used to check that the compiler has not broken the
+ semantics of the program.
+
+ This property takes a string and is used for documentation purposes
+ only. Suggested values are "Yes", "No", or "Partial".
+
+ :type: str
+ """
+ return self._validatable
+
+ @property
+ def site_data_path(self) -> Path:
+ """The filesystem path used to store persistent dataset files.
+
+ This directory may not exist.
- :param path: Path of the JSON metadata.
- :return: A LegacyDataset instance.
+ :type: Path
"""
- try:
- with open(str(path), "rb") as f:
- data = json.load(f)
- except json.decoder.JSONDecodeError as e:
- raise OSError(
- f"Failed to read dataset metadata file:\n"
- f"Path: {path}\n"
- f"Error: {e}"
+ return self._site_data_path
+
+ @property
+ def site_data_size_in_bytes(self) -> int:
+ """The total size of the on-disk data used by this dataset.
+
+ :type: int
+ """
+ if not self.site_data_path.is_dir():
+ return 0
+
+ total_size = 0
+ for dirname, _, filenames in os.walk(self.site_data_path):
+ total_size += sum(
+ os.path.getsize(os.path.join(dirname, f)) for f in filenames
)
- return cls(**data)
+ return total_size
+
+ # We use Union[int, float] to represent the size because infinite size is
+ # represented by math.inf, which is a float. For all other sizes this should
+ # be an int.
+ @property
+ def size(self) -> Union[int, float]:
+ """The number of benchmarks in the dataset. If the number of benchmarks
+ is unbounded, for example because the dataset represents a program
+ generator that can produce an infinite number of programs, the value is
+ :code:`math.inf`.
- def to_json_file(self, path: Path) -> Path:
- """Write the dataset metadata to a JSON file.
+ :type: Union[int, float]
+ """
+ return 0
+
+ def __len__(self) -> Union[int, float]:
+ """The number of benchmarks in the dataset.
+
+ This is the same as :meth:`Dataset.size
+ `:
+
+ >>> len(dataset) == dataset.size
+ True
+
+ :return: An integer, or :code:`math.inf`.
+ """
+ return self.size
+
+ @property
+ def installed(self) -> bool:
+ """Whether the dataset is installed locally. Installation occurs
+ automatically on first use, or by calling :meth:`install()
+ `.
- :param path: Path of the file to write.
- :return: The path of the written file.
+ :type: bool
"""
- with open(str(path), "wb") as f:
- json.dump(self._asdict(), f)
- return path
+ return True
+
+ def install(self) -> None:
+ """Install this dataset locally.
+
+ Implementing this method is optional. If implementing this method, you
+ must call :code:`super().install()` first.
+ This method should not perform redundant work. This method should first
+ detect whether any work needs to be done so that repeated calls to
+ :code:`install()` will complete quickly.
+ """
+ if self.deprecated:
+ warnings.warn(
+ f"Dataset '{self.name}' is marked as deprecated. {self._deprecation_message}",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
-@deprecated(
+ def uninstall(self) -> None:
+ """Remove any local data for this benchmark.
+
+ This method undoes the work of :meth:`install()
+ `. The dataset can still be used
+ after calling this method.
+ """
+ if self.site_data_path.is_dir():
+ shutil.rmtree(self.site_data_path)
+
+ def benchmarks(self) -> Iterable[Benchmark]:
+ """Enumerate the (possibly infinite) benchmarks lazily.
+
+ Iteration order is consistent across runs. The order of
+ :meth:`benchmarks() ` and
+ :meth:`benchmark_uris() `
+ is the same.
+
+ If the number of benchmarks in the dataset is infinite
+ (:code:`len(dataset) == math.inf`), the iterable returned by this method
+ will continue indefinitely.
+
+ :return: An iterable sequence of :class:`Benchmark
+ ` instances.
+ """
+ # Default implementation. Subclasses may wish to provide an alternative
+ # implementation that is optimized to specific use cases.
+ yield from (self.benchmark(uri) for uri in self.benchmark_uris())
+
+ def __iter__(self) -> Iterable[Benchmark]:
+ """Enumerate the (possibly infinite) benchmarks lazily.
+
+ This is the same as :meth:`Dataset.benchmarks()
+ `:
+
+ >>> from itertools import islice
+ >>> list(islice(dataset, 100)) == list(islice(dataset.benchmarks(), 100))
+ True
+
+ :return: An iterable sequence of :meth:`Benchmark
+ ` instances.
+ """
+ yield from self.benchmarks()
+
+ def benchmark_uris(self) -> Iterable[str]:
+ """Enumerate the (possibly infinite) benchmark URIs.
+
+ Iteration order is consistent across runs. The order of
+ :meth:`benchmarks() ` and
+ :meth:`benchmark_uris() `
+ is the same.
+
+ If the number of benchmarks in the dataset is infinite
+ (:code:`len(dataset) == math.inf`), the iterable returned by this method
+ will continue indefinitely.
+
+ :return: An iterable sequence of benchmark URI strings.
+ """
+ raise NotImplementedError("abstract class")
+
+ def benchmark(self, uri: str) -> Benchmark:
+ """Select a benchmark.
+
+ :param uri: The URI of the benchmark to return.
+
+ :return: A :class:`Benchmark `
+ instance.
+
+ :raise LookupError: If :code:`uri` is not found.
+ """
+ raise NotImplementedError("abstract class")
+
+ def __getitem__(self, uri: str) -> Benchmark:
+ """Select a benchmark by URI.
+
+ This is the same as :meth:`Dataset.benchmark(uri)
+ `:
+
+ >>> dataset["benchmark://cbench-v1/crc32"] == dataset.benchmark("benchmark://cbench-v1/crc32")
+ True
+
+ :return: A :class:`Benchmark `
+ instance.
+
+ :raise LookupError: If :code:`uri` does not exist.
+ """
+ return self.benchmark(uri)
+
+
+class DatasetInitError(OSError):
+ """Base class for errors raised if a dataset fails to initialize."""
+
+
+@mark_deprecated(
version="0.1.4",
reason=(
- "Activating datasets will be removed in v0.1.5. "
+ "Datasets are now automatically activated. "
"`More information `_."
),
)
-def activate(env, name: str) -> bool:
- """Move a directory from the inactive to active benchmark directory.
+def activate(env, dataset: Union[str, Dataset]) -> bool:
+ """Deprecated function for managing datasets.
+
+ :param dataset: The name of the dataset to download, or a :class:`Dataset
+ ` instance.
- :param: The name of a dataset.
:return: :code:`True` if the dataset was activated, else :code:`False` if
already active.
+
:raises ValueError: If there is no dataset with that name.
"""
- with fasteners.InterProcessLock(env.datasets_site_path / "LOCK"):
- if (env.datasets_site_path / name).exists():
- # There is already an active benchmark set with this name.
- return False
- if not (env.inactive_datasets_site_path / name).exists():
- raise ValueError(f"Inactive dataset not found: {name}")
- os.rename(env.inactive_datasets_site_path / name, env.datasets_site_path / name)
- os.rename(
- env.inactive_datasets_site_path / f"{name}.json",
- env.datasets_site_path / f"{name}.json",
- )
- return True
+ return False
-@deprecated(
+@mark_deprecated(
version="0.1.4",
reason=(
- "Deleting datasets will be removed in v0.1.5. "
+ "Please use :meth:`del env.datasets[dataset] `. "
"`More information `_."
),
)
-def delete(env, name: str) -> bool:
- """Delete a directory in the inactive benchmark directory.
+def delete(env, dataset: Union[str, Dataset]) -> bool:
+ """Deprecated function for managing datasets.
+
+ Please use :meth:`del env.datasets[dataset]
+ `.
+
+ :param dataset: The name of the dataset to download, or a :class:`Dataset
+ ` instance.
- :param: The name of a dataset.
:return: :code:`True` if the dataset was deleted, else :code:`False` if
already deleted.
"""
- with fasteners.InterProcessLock(env.datasets_site_path / "LOCK"):
- deleted = False
- if (env.datasets_site_path / name).exists():
- shutil.rmtree(str(env.datasets_site_path / name))
- os.unlink(str(env.datasets_site_path / f"{name}.json"))
- deleted = True
- if (env.inactive_datasets_site_path / name).exists():
- shutil.rmtree(str(env.inactive_datasets_site_path / name))
- os.unlink(str(env.inactive_datasets_site_path / f"{name}.json"))
- deleted = True
- return deleted
-
-
-@deprecated(
+ del env.datasets[dataset]
+ return False
+
+
+@mark_deprecated(
version="0.1.4",
reason=(
- "Deactivating datasets will be removed in v0.1.5. "
+ "Please use :meth:`env.datasets.deactivate() `. "
"`More information `_."
),
)
-def deactivate(env, name: str) -> bool:
- """Move a directory from active to the inactive benchmark directory.
+def deactivate(env, dataset: Union[str, Dataset]) -> bool:
+ """Deprecated function for managing datasets.
+
+ Please use :meth:`del env.datasets[dataset]
+ `.
+
+ :param dataset: The name of the dataset to download, or a :class:`Dataset
+ ` instance.
- :param: The name of a dataset.
:return: :code:`True` if the dataset was deactivated, else :code:`False` if
already inactive.
"""
- with fasteners.InterProcessLock(env.datasets_site_path / "LOCK"):
- if not (env.datasets_site_path / name).exists():
- return False
- os.rename(env.datasets_site_path / name, env.inactive_datasets_site_path / name)
- os.rename(
- env.datasets_site_path / f"{name}.json",
- env.inactive_datasets_site_path / f"{name}.json",
- )
- return True
+ del env.datasets[dataset]
+ return False
-def require(env, dataset: Union[str, LegacyDataset]) -> bool:
- """Require that the given dataset is available to the environment.
+@mark_deprecated(
+ version="0.1.7",
+ reason=(
+ "Datasets are now installed automatically, there is no need to call :code:`require()`. "
+ "`More information `_."
+ ),
+)
+def require(env, dataset: Union[str, Dataset]) -> bool:
+ """Deprecated function for managing datasets.
- This will download and activate the dataset if it is not already installed.
- After calling this function, benchmarks from the dataset will be available
- to use.
+ Datasets are now installed automatically. See :class:`env.datasets
+ `.
- Example usage:
+ :param env: The environment that this dataset is required for.
- >>> env = gym.make("llvm-v0")
- >>> require(env, "blas-v0")
- >>> env.reset(benchmark="blas-v0/1")
+ :param dataset: The name of the dataset to download, or a :class:`Dataset
+ ` instance.
- :param env: The environment that this dataset is required for.
- :param dataset: The name of the dataset to download, the URL of the dataset,
- or a :class:`LegacyDataset` instance.
:return: :code:`True` if the dataset was downloaded, or :code:`False` if the
dataset was already available.
"""
-
- def download_and_unpack_archive(
- url: str, sha256: Optional[str] = None
- ) -> LegacyDataset:
- json_files_before = {
- f
- for f in env.inactive_datasets_site_path.iterdir()
- if f.is_file() and f.name.endswith(".json")
- }
- tar_data = io.BytesIO(download(url, sha256))
- with tarfile.open(fileobj=tar_data, mode="r:bz2") as arc:
- arc.extractall(str(env.inactive_datasets_site_path))
- json_files_after = {
- f
- for f in env.inactive_datasets_site_path.iterdir()
- if f.is_file() and f.name.endswith(".json")
- }
- new_json = json_files_after - json_files_before
- if not len(new_json):
- raise OSError(f"Downloaded dataset {url} contains no metadata JSON file")
- return LegacyDataset.from_json_file(list(new_json)[0])
-
- def unpack_local_archive(path: Path) -> LegacyDataset:
- if not path.is_file():
- raise FileNotFoundError(f"File not found: {path}")
- json_files_before = {
- f
- for f in env.inactive_datasets_site_path.iterdir()
- if f.is_file() and f.name.endswith(".json")
- }
- with tarfile.open(str(path), "r:bz2") as arc:
- arc.extractall(str(env.inactive_datasets_site_path))
- json_files_after = {
- f
- for f in env.inactive_datasets_site_path.iterdir()
- if f.is_file() and f.name.endswith(".json")
- }
- new_json = json_files_after - json_files_before
- if not len(new_json):
- raise OSError(f"Downloaded dataset {url} contains no metadata JSON file")
- return LegacyDataset.from_json_file(list(new_json)[0])
-
- with fasteners.InterProcessLock(env.datasets_site_path / "LOCK"):
- # Resolve the name and URL of the dataset.
- sha256 = None
- if isinstance(dataset, LegacyDataset):
- name, url = dataset.name, dataset.url
- elif isinstance(dataset, str):
- # Check if we have already downloaded the dataset.
- if "://" in dataset:
- name, url = None, dataset
- dataset: Optional[LegacyDataset] = None
- else:
- try:
- dataset: Optional[LegacyDataset] = env.available_datasets[dataset]
- except KeyError:
- raise ValueError(f"Dataset not found: {dataset}")
- name, url, sha256 = dataset.name, dataset.url, dataset.sha256
- else:
- raise TypeError(
- f"require() called with unsupported type: {type(dataset).__name__}"
- )
-
- if dataset and dataset.deprecated:
- warnings.warn(
- f"Dataset '{dataset.name}' is deprecated as of CompilerGym "
- f"release {dataset.deprecated_since}, please update to the "
- "latest available version",
- DeprecationWarning,
- )
-
- # Check if we have already downloaded the dataset.
- if name:
- if (env.datasets_site_path / name).is_dir():
- # Dataset is already downloaded and active.
- return False
- elif not (env.inactive_datasets_site_path / name).is_dir():
- # Dataset is downloaded but inactive.
- name = download_and_unpack_archive(url, sha256=sha256).name
- elif url.startswith("file:///"):
- name = unpack_local_archive(Path(url[len("file:///") :])).name
- else:
- name = download_and_unpack_archive(url, sha256=sha256).name
-
- activate(env, name)
- return True
+ return False
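
To illustrate the subclassing contract described in the Dataset docstring above, a minimal sketch that implements size, benchmark_uris(), and benchmark(); the class name, dataset name, and programs are hypothetical:

    from compiler_gym.datasets import Benchmark, Dataset

    class InMemoryDataset(Dataset):
        # Illustrative subclass that serves a fixed set of in-memory programs.
        def __init__(self, site_data_base):
            super().__init__(
                name="benchmark://inmemory-v0",
                description="A tiny hand-rolled dataset",
                license="MIT",
                site_data_base=site_data_base,
            )
            self._programs = {
                "benchmark://inmemory-v0/a": b"int main() { return 0; }",
                "benchmark://inmemory-v0/b": b"int main() { return 1; }",
            }

        @property
        def size(self) -> int:
            return len(self._programs)

        def benchmark_uris(self):
            # Iteration order must be consistent across runs.
            yield from sorted(self._programs)

        def benchmark(self, uri: str) -> Benchmark:
            if uri not in self._programs:
                raise LookupError(uri)
            return Benchmark.from_file_contents(uri, self._programs[uri])

    ds = InMemoryDataset(site_data_base="/tmp/example_site_data")
    print(len(ds), list(ds.benchmark_uris()))
    # -> 2 ['benchmark://inmemory-v0/a', 'benchmark://inmemory-v0/b']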
diff --git a/compiler_gym/datasets/datasets.py b/compiler_gym/datasets/datasets.py
new file mode 100644
index 000000000..adc97818e
--- /dev/null
+++ b/compiler_gym/datasets/datasets.py
@@ -0,0 +1,260 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from collections import deque
+from typing import Dict, Iterable, Set, TypeVar
+
+from compiler_gym.datasets.benchmark import Benchmark
+from compiler_gym.datasets.dataset import Dataset
+from compiler_gym.datasets.uri import BENCHMARK_URI_RE, resolve_uri_protocol
+
+T = TypeVar("T")
+
+
+def round_robin_iterables(iters: Iterable[Iterable[T]]) -> Iterable[T]:
+ """Yield from the given iterators in round robin order."""
+ # Use a queue of iterators to iterate over. Repeatedly pop an iterator from
+ # the queue, yield the next value from it, then put it at the back of the
+ # queue. The iterator is discarded once exhausted.
+ iters = deque(iters)
+ while len(iters) > 1:
+ it = iters.popleft()
+ try:
+ yield next(it)
+ iters.append(it)
+ except StopIteration:
+ pass
+ # Once we have only a single iterator left, return it directly rather
+ # than continuing with the round robin.
+ if len(iters) == 1:
+ yield from iters.popleft()
+
+
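
A quick sketch of the interleaving behavior of the helper above, assuming the compiler_gym.datasets.datasets module path of this new file; the input iterators are arbitrary:

    from compiler_gym.datasets.datasets import round_robin_iterables

    merged = round_robin_iterables([iter("ab"), iter("cde"), iter("f")])
    # Values are drawn one at a time from each iterator in turn, and
    # exhausted iterators are dropped from the rotation.
    print(list(merged))  # -> ['a', 'c', 'f', 'b', 'd', 'e']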
+class Datasets(object):
+ """A collection of datasets.
+
+ This class provides a dictionary-like interface for indexing and iterating
+ over multiple :class:`Dataset ` objects.
+ Select a dataset by URI using:
+
+ >>> env.datasets["benchmark://cbench-v1"]
+
+ Check whether a dataset exists using:
+
+ >>> "benchmark://cbench-v1" in env.datasets
+ True
+
+ Or iterate over the datasets using:
+
+ >>> for dataset in env.datasets:
+ ... print(dataset.name)
+ benchmark://cbench-v1
+ benchmark://github-v0
+ benchmark://npb-v0
+
+ To select a benchmark from the datasets, use :meth:`benchmark()`:
+
+ >>> env.datasets.benchmark("benchmark://a-v0/a")
+
+ Use the :meth:`benchmarks()` method to iterate over every benchmark in the
+ datasets in a stable round robin order:
+
+ >>> for benchmark in env.datasets.benchmarks():
+ ... print(benchmark)
+ benchmark://cbench-v1/1
+ benchmark://github-v0/1
+ benchmark://npb-v0/1
+ benchmark://cbench-v1/2
+ ...
+
+ If you want to exclude a dataset, delete it:
+
+ >>> del env.datasets["benchmark://b-v0"]
+ """
+
+ def __init__(
+ self,
+ datasets: Iterable[Dataset],
+ ):
+ self._datasets: Dict[str, Dataset] = {d.name: d for d in datasets}
+ self._visible_datasets: Set[str] = set(
+ name for name, dataset in self._datasets.items() if not dataset.deprecated
+ )
+
+ def datasets(self, with_deprecated: bool = False) -> Iterable[Dataset]:
+ """Enumerate the datasets.
+
+ Dataset order is consistent across runs.
+
+ :param with_deprecated: If :code:`True`, include datasets that have been
+ marked as deprecated.
+
+ :return: An iterable sequence of :meth:`Dataset
+ ` instances.
+ """
+ datasets = self._datasets.values()
+ if not with_deprecated:
+ datasets = (d for d in datasets if not d.deprecated)
+ yield from sorted(datasets, key=lambda d: (d.sort_order, d.name))
+
+ def __iter__(self) -> Iterable[Dataset]:
+ """Iterate over the datasets.
+
+ Dataset order is consistent across runs.
+
+ Equivalent to :meth:`datasets.datasets()
+ `, but without the ability to
+ iterate over the deprecated datasets.
+
+ If the number of benchmarks in any of the datasets is infinite
+ (:code:`len(dataset) == math.inf`), the iterable returned by this method
+ will continue indefinitely.
+
+ :return: An iterable sequence of :meth:`Dataset
+ ` instances.
+ """
+ return self.datasets()
+
+ def dataset(self, dataset: str) -> Dataset:
+ """Get a dataset.
+
+ Return the corresponding :meth:`Dataset
+ `. Name lookup will succeed whether or
+ not the dataset is deprecated.
+
+ :param dataset: A dataset name.
+
+ :return: A :meth:`Dataset ` instance.
+
+ :raises LookupError: If :code:`dataset` is not found.
+ """
+ dataset_name = resolve_uri_protocol(dataset)
+
+ if dataset_name not in self._datasets:
+ raise LookupError(f"Dataset not found: {dataset_name}")
+
+ return self._datasets[dataset_name]
+
+ def __getitem__(self, dataset: str) -> Dataset:
+ """Lookup a dataset.
+
+ :param dataset: A dataset name.
+
+ :return: A :meth:`Dataset ` instance.
+
+ :raises LookupError: If :code:`dataset` is not found.
+ """
+ return self.dataset(dataset)
+
+ def __setitem__(self, key: str, dataset: Dataset):
+ """Add a dataset to the collection.
+
+ :param key: The name of the dataset.
+ :param dataset: The dataset to add.
+ """
+ dataset_name = resolve_uri_protocol(key)
+
+ self._datasets[dataset_name] = dataset
+ if not dataset.deprecated:
+ self._visible_datasets.add(dataset_name)
+
+ def __delitem__(self, dataset: str):
+ """Remove a dataset from the collection.
+
+ This does not affect any underlying storage used by the dataset. See
+ :meth:`uninstall() ` to clean
+ up.
+
+ :param dataset: The name of a dataset.
+
+ :return: :code:`True` if the dataset was removed, :code:`False` if it
+ was already removed.
+ """
+ dataset_name = resolve_uri_protocol(dataset)
+ if dataset_name in self._visible_datasets:
+ self._visible_datasets.remove(dataset_name)
+ del self._datasets[dataset_name]
+
+ def __contains__(self, dataset: str) -> bool:
+ """Returns whether the dataset is contained."""
+ try:
+ self.dataset(dataset)
+ return True
+ except LookupError:
+ return False
+
+ def benchmarks(self, with_deprecated: bool = False) -> Iterable[Benchmark]:
+ """Enumerate the (possibly infinite) benchmarks lazily.
+
+ Benchmark order is consistent across runs. One benchmark from each
+ dataset is returned in round robin order until all datasets have been
+ fully enumerated. The order of :meth:`benchmarks()
+ ` and :meth:`benchmark_uris()
+ ` is the same.
+
+ If the number of benchmarks in any of the datasets is infinite
+ (:code:`len(dataset) == math.inf`), the iterable returned by this method
+ will continue indefinitely.
+
+ :param with_deprecated: If :code:`True`, include benchmarks from
+ datasets that have been marked deprecated.
+
+ :return: An iterable sequence of :class:`Benchmark
+ ` instances.
+ """
+ return round_robin_iterables(
+ (d.benchmarks() for d in self.datasets(with_deprecated=with_deprecated))
+ )
+
+ def benchmark_uris(self, with_deprecated: bool = False) -> Iterable[str]:
+ """Enumerate the (possibly infinite) benchmark URIs.
+
+ Benchmark URI order is consistent across runs. URIs from datasets are
+ returned in round robin order. The order of :meth:`benchmarks()
+ ` and :meth:`benchmark_uris()
+ ` is the same.
+
+ If the number of benchmarks in any of the datasets is infinite
+ (:code:`len(dataset) == math.inf`), the iterable returned by this method
+ will continue indefinitely.
+
+ :param with_deprecated: If :code:`True`, include benchmarks from
+ datasets that have been marked deprecated.
+
+ :return: An iterable sequence of benchmark URI strings.
+ """
+ return round_robin_iterables(
+ (d.benchmark_uris() for d in self.datasets(with_deprecated=with_deprecated))
+ )
+
+ def benchmark(self, uri: str) -> Benchmark:
+ """Select a benchmark.
+
+ Returns the corresponding :class:`Benchmark
+ `, regardless of whether the containing
+ dataset is installed or deprecated.
+
+ :param uri: The URI of the benchmark to return.
+
+ :return: A :class:`Benchmark `
+ instance.
+ """
+ uri = resolve_uri_protocol(uri)
+
+ match = BENCHMARK_URI_RE.match(uri)
+ if not match:
+ raise ValueError(f"Invalid benchmark URI: '{uri}'")
+
+ dataset_name = match.group("dataset")
+ dataset = self._datasets[dataset_name]
+
+ return dataset.benchmark(uri)
+
+ @property
+ def size(self) -> int:
+ return len(self._visible_datasets)
+
+ def __len__(self) -> int:
+ """The number of datasets in the collection."""
+ return self.size
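+
+# Illustrative usage sketch, not in the original patch. `cbench` and `github`
+# below are placeholders for pre-constructed Dataset instances:
+#
+#   datasets = Datasets([cbench, github])
+#   "benchmark://cbench-v1" in datasets        # membership test
+#   datasets.benchmark("cbench-v1/qsort")      # protocol defaults to benchmark://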
diff --git a/compiler_gym/datasets/files_dataset.py b/compiler_gym/datasets/files_dataset.py
new file mode 100644
index 000000000..4c522912c
--- /dev/null
+++ b/compiler_gym/datasets/files_dataset.py
@@ -0,0 +1,119 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+from pathlib import Path
+from typing import Iterable, List
+
+from compiler_gym.datasets.dataset import Benchmark, Dataset
+from compiler_gym.util.decorators import memoized_property
+
+
+class FilesDataset(Dataset):
+ """A dataset comprising a directory tree of files.
+
+ A FilesDataset is a root directory that contains (a possibly nested tree of)
+ files, where each file represents a benchmark. The directory contents can be
+ filtered by specifying a filename suffix that files must match.
+
+ The URI of each benchmark is the relative path of its file, stripped of the
+ required filename suffix, if one is specified. For example, given the following
+ file tree:
+
+ .. code-block::
+
+ /tmp/dataset/a.txt
+ /tmp/dataset/LICENSE
+ /tmp/dataset/subdir/subdir/b.txt
+ /tmp/dataset/subdir/subdir/c.txt
+
+ a FilesDataset :code:`benchmark://ds-v0` rooted at :code:`/tmp/dataset` with
+ filename suffix :code:`.txt` will contain the following URIs:
+
+ >>> list(dataset.benchmark_uris())
+ [
+ "benchmark://ds-v0/a",
+ "benchmark://ds-v0/subdir/subdir/b",
+ "benchmark://ds-v0/subdir/subdir/c",
+ ]
+ """
+
+ def __init__(
+ self,
+ dataset_root: Path,
+ benchmark_file_suffix: str = "",
+ memoize_uris: bool = True,
+ **dataset_args,
+ ):
+ """Constructor.
+
+ :param dataset_root: The root directory to look for benchmark files.
+
+ :param benchmark_file_suffix: A file extension that must be matched for
+ a file to be used as a benchmark.
+
+ :param memoize_uris: Whether to memoize the list of URIs contained in
+ the dataset. Memoizing the URIs enables faster repeated iteration
+ over :meth:`dataset.benchmark_uris()
+ ` at the expense of
+ increased memory overhead as the file list must be kept in memory.
+
+ :param dataset_args: See :meth:`Dataset.__init__()
+ `.
+ """
+ super().__init__(**dataset_args)
+ self.dataset_root = dataset_root
+ self.benchmark_file_suffix = benchmark_file_suffix
+ self.memoize_uris = memoize_uris
+ self._memoized_uris = None
+
+ @memoized_property
+ def size(self) -> int: # pylint: disable=invalid-overridden-method
+ self.install()
+ return sum(
+ sum(1 for f in files if f.endswith(self.benchmark_file_suffix))
+ for (_, _, files) in os.walk(self.dataset_root)
+ )
+
+ @property
+ def _benchmark_uris_iter(self) -> Iterable[str]:
+ """Return an iterator over benchmark URIs that is consistent across runs."""
+ self.install()
+ for root, dirs, files in os.walk(self.dataset_root):
+ # Sort the subdirectories so that os.walk() order is stable between
+ # runs.
+ dirs.sort()
+ reldir = root[len(str(self.dataset_root)) + 1 :]
+ for filename in sorted(files):
+ # If we have an expected file suffix then ignore files that do
+ # not match, and strip the suffix from files that do match.
+ if self.benchmark_file_suffix:
+ if not filename.endswith(self.benchmark_file_suffix):
+ continue
+ filename = filename[: -len(self.benchmark_file_suffix)]
+ # Use os.path.join() rather than simple '/' concatenation as
+ # reldir may be empty.
+ yield os.path.join(self.name, reldir, filename)
+
+ @property
+ def _benchmark_uris(self) -> List[str]:
+ return list(self._benchmark_uris_iter)
+
+ def benchmark_uris(self) -> Iterable[str]:
+ if self._memoized_uris:
+ yield from self._memoized_uris
+ elif self.memoize_uris:
+ self._memoized_uris = self._benchmark_uris
+ yield from self._memoized_uris
+ else:
+ yield from self._benchmark_uris_iter
+
+ def benchmark(self, uri: str) -> Benchmark:
+ self.install()
+
+ relpath = f"{uri[len(self.name) + 1:]}{self.benchmark_file_suffix}"
+ abspath = self.dataset_root / relpath
+ if not abspath.is_file():
+ raise LookupError(f"Benchmark not found: {uri} (file not found: {abspath})")
+ return self.benchmark_class.from_file(uri, abspath)
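+
+# Hypothetical usage sketch, not in the original patch. The extra keyword
+# arguments are assumed to be whatever Dataset.__init__() requires (at minimum
+# a dataset name):
+#
+#   dataset = FilesDataset(
+#       dataset_root=Path("/tmp/dataset"),
+#       benchmark_file_suffix=".txt",
+#       name="benchmark://ds-v0",
+#   )
+#   list(dataset.benchmark_uris())
+#   # -> ['benchmark://ds-v0/a', 'benchmark://ds-v0/subdir/subdir/b', ...]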
diff --git a/compiler_gym/datasets/tar_dataset.py b/compiler_gym/datasets/tar_dataset.py
new file mode 100644
index 000000000..55b15c73b
--- /dev/null
+++ b/compiler_gym/datasets/tar_dataset.py
@@ -0,0 +1,219 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import bz2
+import gzip
+import io
+import shutil
+import tarfile
+from threading import Lock
+from typing import Iterable, List, Optional
+
+from fasteners import InterProcessLock
+
+from compiler_gym.datasets.files_dataset import FilesDataset
+from compiler_gym.util.decorators import memoized_property
+from compiler_gym.util.download import download
+from compiler_gym.util.filesystem import atomic_file_write
+
+
+class TarDataset(FilesDataset):
+ """A dataset comprising a files tree stored in a tar archive.
+
+ This extends the :class:`FilesDataset `
+ class by adding support for compressed archives of files. The archive is
+ downloaded and unpacked on-demand.
+ """
+
+ def __init__(
+ self,
+ tar_urls: List[str],
+ tar_sha256: Optional[str] = None,
+ tar_compression: str = "bz2",
+ strip_prefix: str = "",
+ **dataset_args,
+ ):
+ """Constructor.
+
+ :param tar_urls: A list of redundant URLs to download the tar archive from.
+
+ :param tar_sha256: The SHA256 checksum of the downloaded tar archive.
+
+ :param tar_compression: The tar archive compression type. One of
+ {"bz2", "gz"}.
+
+ :param strip_prefix: An optional path prefix to strip. Only files that
+ match this path prefix will be used as benchmarks.
+
+ :param dataset_args: See :meth:`FilesDataset.__init__()
+ `.
+ """
+ super().__init__(
+ dataset_root=None, # Set below once site_data_path is resolved.
+ **dataset_args,
+ )
+ self.dataset_root = self.site_data_path / "contents" / strip_prefix
+
+ self.tar_urls = tar_urls
+ self.tar_sha256 = tar_sha256
+ self.tar_compression = tar_compression
+ self.strip_prefix = strip_prefix
+
+ self._installed = False
+ self._tar_extracted_marker = self.site_data_path / ".extracted"
+ self._tar_lock = Lock()
+ self._tar_lockfile = self.site_data_path / ".install_lock"
+
+ @property
+ def installed(self) -> bool:
+ # Fast path for repeated checks to 'installed' without a disk op.
+ if not self._installed:
+ self._installed = self._tar_extracted_marker.is_file()
+ return self._installed
+
+ def install(self) -> None:
+ super().install()
+
+ if self.installed:
+ return
+
+ # Thread-level and process-level locks to prevent races.
+ with self._tar_lock, InterProcessLock(self._tar_lockfile):
+ # Repeat the check to see if we have already installed the
+ # dataset now that we have acquired the lock.
+ if self.installed:
+ return
+
+ # Remove any partially-completed prior extraction.
+ shutil.rmtree(self.site_data_path / "contents", ignore_errors=True)
+
+ self.logger.info("Downloading %s dataset", self.name)
+ tar_data = io.BytesIO(download(self.tar_urls, self.tar_sha256))
+ self.logger.info("Unpacking %s dataset", self.name)
+ with tarfile.open(
+ fileobj=tar_data, mode=f"r:{self.tar_compression}"
+ ) as arc:
+ arc.extractall(str(self.site_data_path / "contents"))
+
+ # We're done. The last thing we do is create the marker file to
+ # signal to any other install() invocations that the dataset is
+ # ready.
+ self._tar_extracted_marker.touch()
+
+ if self.strip_prefix and not self.dataset_root.is_dir():
+ raise FileNotFoundError(
+ f"Directory prefix '{self.strip_prefix}' not found in dataset '{self.name}'"
+ )
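+
+# Illustrative note, not in the original patch: with strip_prefix="pkg", an
+# archive member "pkg/foo.c" is unpacked to <site_data_path>/contents/pkg/foo.c
+# and exposed with its path relative to that prefix (any benchmark_file_suffix
+# is stripped, as in FilesDataset).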
+
+
+class TarDatasetWithManifest(TarDataset):
+ """A tarball-based dataset that reads the benchmark URIs from a separate
+ manifest file.
+
+ A manifest file is a plain text file containing a list of benchmark names,
+ one per line, and is shipped separately from the tar file. The idea is to
+ allow the list of benchmark URIs to be enumerated in a more lightweight
+ manner than downloading and unpacking the entire dataset. It does this by
+ downloading and unpacking only the manifest to iterate over the URIs.
+
+ The manifest file is assumed to be correct and is not validated.
+ """
+
+ def __init__(
+ self,
+ manifest_urls: List[str],
+ manifest_sha256: str,
+ manifest_compression: str = "bz2",
+ **dataset_args,
+ ):
+ """Constructor.
+
+ :param manifest_urls: A list of redundant URLs to download the
+ compressed text file containing a list of benchmark URI suffixes,
+ one per line.
+
+ :param manifest_sha256: The sha256 checksum of the compressed manifest
+ file.
+
+ :param manifest_compression: The manifest compression type. One of
+ {"bz2", "gz"}.
+
+ :param dataset_args: See :meth:`TarDataset.__init__()
+ `.
+ """
+ super().__init__(**dataset_args)
+ self.manifest_urls = manifest_urls
+ self.manifest_sha256 = manifest_sha256
+ self.manifest_compression = manifest_compression
+ self._manifest_path = self.site_data_path / f"manifest-{manifest_sha256}.txt"
+
+ self._manifest_lock = Lock()
+ self._manifest_lockfile = self.site_data_path / ".manifest_lock"
+
+ def _read_manifest(self, manifest_data: str) -> List[str]:
+ """Read the manifest data into a list of URIs. Does not validate the
+ manifest contents.
+ """
+ lines = manifest_data.rstrip().split("\n")
+ return [f"{self.name}/{line}" for line in lines]
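+
+ # Illustrative example, not in the original patch: for a dataset named
+ # "benchmark://foo-v0", a manifest body of "a\nb/c\n" maps to
+ # ["benchmark://foo-v0/a", "benchmark://foo-v0/b/c"].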
+
+ def _read_manifest_file(self) -> List[str]:
+ """Read the benchmark URIs from an on-disk manifest file.
+
+ Does not check that the manifest file exists.
+ """
+ with open(self._manifest_path, encoding="utf-8") as f:
+ uris = self._read_manifest(f.read())
+ self.logger.debug("Read %s manifest, %d entries", self.name, len(uris))
+ return uris
+
+ @memoized_property
+ def _benchmark_uris(self) -> List[str]:
+ """Fetch or download the URI list."""
+ if self._manifest_path.is_file():
+ return self._read_manifest_file()
+
+ # Thread-level and process-level locks to prevent races.
+ with self._manifest_lock, InterProcessLock(self._manifest_lockfile):
+ # Now that we have acquired the lock, repeat the check, since
+ # another thread may have downloaded the manifest.
+ if self._manifest_path.is_file():
+ return self._read_manifest_file()
+
+ # Determine how to decompress the manifest data.
+ decompressor = {
+ "bz2": lambda compressed_data: bz2.BZ2File(compressed_data),
+ "gz": lambda compressed_data: gzip.GzipFile(compressed_data),
+ }.get(self.manifest_compression, None)
+ if not decompressor:
+ raise TypeError(
+ f"Unknown manifest compression: {self.manifest_compression}"
+ )
+
+ # Decompress the manifest data.
+ self.logger.debug("Downloading %s manifest", self.name)
+ manifest_data = io.BytesIO(
+ download(self.manifest_urls, self.manifest_sha256)
+ )
+ with decompressor(manifest_data) as f:
+ manifest_data = f.read()
+
+ # Although we have exclusive-execution locks, we still need to
+ # create the manifest atomically to prevent calls to _benchmark_uris
+ # racing to read an incompletely written manifest.
+ with atomic_file_write(self._manifest_path, fileobj=True) as f:
+ f.write(manifest_data)
+
+ uris = self._read_manifest(manifest_data.decode("utf-8"))
+ self.logger.debug(
+ "Downloaded %s manifest, %d entries", self.name, len(uris)
+ )
+ return uris
+
+ @memoized_property
+ def size(self) -> int:
+ return len(self._benchmark_uris)
+
+ def benchmark_uris(self) -> Iterable[str]:
+ yield from iter(self._benchmark_uris)
diff --git a/compiler_gym/datasets/uri.py b/compiler_gym/datasets/uri.py
new file mode 100644
index 000000000..350e224df
--- /dev/null
+++ b/compiler_gym/datasets/uri.py
@@ -0,0 +1,30 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""This module contains utility code for working with URIs."""
+import re
+
+# Regular expression that matches the full two-part URI prefix of a dataset:
+# {{protocol}}://{{dataset}}
+#
+# An optional trailing slash is permitted.
+#
+# Example matches: "benchmark://foo-v0", "generator://bar-v0/".
+DATASET_NAME_PATTERN = r"(?P<dataset>(?P<dataset_protocol>[a-zA-Z0-9-_]+)://(?P<dataset_name>[a-zA-Z0-9-_]+-v(?P<dataset_version>[0-9]+)))/?"
+DATASET_NAME_RE = re.compile(DATASET_NAME_PATTERN)
+
+# Regular expression that matches the full three-part format of a benchmark URI:
+# {{protocol}}://{{dataset}}/{{id}}
+#
+# Example matches: "benchmark://foo-v0/foo" or "generator://bar-v1/foo/bar.txt".
+BENCHMARK_URI_PATTERN = r"(?P<dataset>(?P<dataset_protocol>[a-zA-Z0-9-_]+)://(?P<dataset_name>[a-zA-Z0-9-_]+-v(?P<dataset_version>[0-9]+)))/(?P<benchmark_name>[^\s]+)$"
+BENCHMARK_URI_RE = re.compile(BENCHMARK_URI_PATTERN)
+
+
+def resolve_uri_protocol(uri: str) -> str:
+ """Require that the URI has a protocol by applying a default "benchmark"
+ protocol if none is set."""
+ if "://" not in uri:
+ return f"benchmark://{uri}"
+ return uri
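+
+# Illustrative examples, not in the original patch:
+#
+#   resolve_uri_protocol("cbench-v1/qsort")
+#   # -> "benchmark://cbench-v1/qsort"
+#
+#   BENCHMARK_URI_RE.match("benchmark://cbench-v1/qsort").group("dataset")
+#   # -> "benchmark://cbench-v1"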
diff --git a/compiler_gym/envs/BUILD b/compiler_gym/envs/BUILD
index ba75aca0e..63519efbe 100644
--- a/compiler_gym/envs/BUILD
+++ b/compiler_gym/envs/BUILD
@@ -21,7 +21,7 @@ py_library(
deps = [
"//compiler_gym:compiler_env_state",
"//compiler_gym:validation_result",
- "//compiler_gym/datasets:dataset",
+ "//compiler_gym/datasets",
"//compiler_gym/service",
"//compiler_gym/service/proto",
"//compiler_gym/spaces",
diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py
index a514f5144..9db13b4f3 100644
--- a/compiler_gym/envs/compiler_env.py
+++ b/compiler_gym/envs/compiler_env.py
@@ -5,8 +5,6 @@
"""This module defines the OpenAI gym interface for compilers."""
import logging
import numbers
-import os
-import sys
import warnings
from collections.abc import Iterable as IterableType
from copy import deepcopy
@@ -15,38 +13,39 @@
from time import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
-import fasteners
import gym
import numpy as np
+from deprecated.sphinx import deprecated
from gym.spaces import Space
from compiler_gym.compiler_env_state import CompilerEnvState
-from compiler_gym.datasets.dataset import LegacyDataset, require
+from compiler_gym.datasets import Benchmark, Dataset, Datasets
from compiler_gym.service import (
CompilerGymServiceConnection,
ConnectionOpts,
ServiceError,
ServiceOSError,
ServiceTransportError,
+ SessionNotFound,
observation_t,
)
from compiler_gym.service.proto import (
AddBenchmarkRequest,
- Benchmark,
EndSessionReply,
EndSessionRequest,
ForkSessionReply,
ForkSessionRequest,
- GetBenchmarksRequest,
GetVersionReply,
GetVersionRequest,
StartSessionRequest,
+ StepReply,
StepRequest,
)
-from compiler_gym.spaces import NamedDiscrete, Reward
+from compiler_gym.spaces import DefaultRewardFromObservation, NamedDiscrete, Reward
from compiler_gym.util.debug_util import get_logging_level
from compiler_gym.util.timer import Timer
-from compiler_gym.validation_result import ValidationError, ValidationResult
+from compiler_gym.validation_error import ValidationError
+from compiler_gym.validation_result import ValidationResult
from compiler_gym.views import ObservationSpaceSpec, ObservationView, RewardView
# Type hints.
@@ -54,31 +53,16 @@
step_t = Tuple[Optional[observation_t], Optional[float], bool, info_t]
-class DefaultRewardFromObservation(Reward):
- def __init__(self, observation_name: str, **kwargs):
- super().__init__(
- observation_spaces=[observation_name], id=observation_name, **kwargs
- )
- self.previous_value: Optional[observation_t] = None
-
- def reset(self, benchmark: str) -> None:
- """Called on env.reset(). Reset incremental progress."""
- del benchmark # unused
- self.previous_value = None
-
- def update(
- self,
- action: int,
- observations: List[observation_t],
- observation_view: ObservationView,
- ) -> float:
- """Called on env.step(). Compute and return new reward."""
- value: float = observations[0]
- if self.previous_value is None:
- self.previous_value = 0
- reward = float(value - self.previous_value)
- self.previous_value = value
- return reward
+def _wrapped_step(
+ service: CompilerGymServiceConnection, request: StepRequest
+) -> StepReply:
+ """Call the Step() RPC endpoint."""
+ try:
+ return service(service.stub.Step, request)
+ except FileNotFoundError as e:
+ if str(e).startswith("Session not found"):
+ raise SessionNotFound(str(e))
+ raise
class CompilerEnv(gym.Env):
@@ -128,14 +112,6 @@ class CompilerEnv(gym.Env):
Default range is (-inf, +inf).
:vartype reward_range: Tuple[float, float]
- :ivar datasets_site_path: The filesystem path used by the service
- to store benchmarks.
- :vartype datasets_site_path: Optional[Path]
-
- :ivar available_datasets: A mapping from dataset name to :class:`LegacyDataset`
- objects that are available to download.
- :vartype available_datasets: Dict[str, LegacyDataset]
-
:ivar observation: A view of the available observation spaces that permits
on-demand computation of observations.
:vartype observation: compiler_gym.views.ObservationView
@@ -154,67 +130,77 @@ def __init__(
self,
service: Union[str, Path],
rewards: Optional[List[Reward]] = None,
+ datasets: Optional[Iterable[Dataset]] = None,
benchmark: Optional[Union[str, Benchmark]] = None,
observation_space: Optional[Union[str, ObservationSpaceSpec]] = None,
reward_space: Optional[Union[str, Reward]] = None,
action_space: Optional[str] = None,
connection_settings: Optional[ConnectionOpts] = None,
service_connection: Optional[CompilerGymServiceConnection] = None,
- logging_level: Optional[int] = None,
+ logger: Optional[logging.Logger] = None,
):
"""Construct and initialize a CompilerGym service environment.
:param service: The hostname and port of a service that implements the
- CompilerGym service interface, or the path of a binary file
- which provides the CompilerGym service interface when executed.
- See :doc:`/compiler_gym/service` for details.
+ CompilerGym service interface, or the path of a binary file which
+ provides the CompilerGym service interface when executed. See
+ :doc:`/compiler_gym/service` for details.
+
:param rewards: The reward spaces that this environment supports.
- Rewards are typically calculated based on observations generated
- by the service. See :class:`Reward ` for
+ Rewards are typically calculated based on observations generated by
+ the service. See :class:`Reward ` for
details.
- :param benchmark: The name of the benchmark to use for this environment.
- The choice of benchmark can be deferred by not providing this
- argument and instead passing by choosing from the
- :code:`CompilerEnv.benchmarks` attribute and passing it to
- :func:`reset()` when called.
+
+ :param benchmark: The benchmark to use for this environment. Either a
+ URI string, or a :class:`Benchmark
+ ` instance. If not provided, the
+ first benchmark as returned by
+ :code:`next(env.datasets.benchmarks())` will be used as the default.
+
:param observation_space: Compute and return observations at each
:func:`step()` from this space. Accepts a string name or an
- :class:`ObservationSpaceSpec `.
- If not provided, :func:`step()` returns :code:`None` for the
- observation value. Can be set later using
- :meth:`env.observation_space `.
- For available spaces, see
- :class:`env.observation.spaces `.
+ :class:`ObservationSpaceSpec
+ `. If not provided,
+ :func:`step()` returns :code:`None` for the observation value. Can
+ be set later using :meth:`env.observation_space
+ `. For available
+ spaces, see :class:`env.observation.spaces
+ `.
+
:param reward_space: Compute and return reward at each :func:`step()`
- from this space. Accepts a string name or a
- :class:`Reward `. If
- not provided, :func:`step()` returns :code:`None` for the reward
- value. Can be set later using
- :meth:`env.reward_space `.
- For available spaces, see
- :class:`env.reward.spaces `.
+ from this space. Accepts a string name or a :class:`Reward
+ `. If not provided, :func:`step()`
+ returns :code:`None` for the reward value. Can be set later using
+ :meth:`env.reward_space
+ `. For available spaces,
+ see :class:`env.reward.spaces `.
+
:param action_space: The name of the action space to use. If not
specified, the default action space for this compiler is used.
+
:param connection_settings: The settings used to establish a connection
with the remote service.
+
:param service_connection: An existing compiler gym service connection
to use.
- :param logging_level: The integer logging level to use for logging. By
- default, the value reported by
- :func:`get_logging_level() ` is
- used.
+
+ :param logger: The logger to use for this environment. If not provided,
+ a :code:`compiler_gym.envs` logger is used and assigned the
+ verbosity returned by :func:`get_logging_level()
+ `.
+
:raises FileNotFoundError: If service is a path to a file that is not
found.
- :raises TimeoutError: If the compiler service fails to initialize
- within the parameters provided in :code:`connection_settings`.
+
+ :raises TimeoutError: If the compiler service fails to initialize within
+ the parameters provided in :code:`connection_settings`.
"""
self.metadata = {"render.modes": ["human", "ansi"]}
- # Set up logging.
- self.logger = logging.getLogger("compiler_gym.envs")
- if logging_level is None:
- logging_level = get_logging_level()
- self.logger.setLevel(logging_level)
+ if logger is None:
+ logger = logging.getLogger("compiler_gym.envs")
+ logger.setLevel(get_logging_level())
+ self.logger = logger
# A compiler service supports multiple simultaneous environments. This
# session ID is used to identify this environment.
@@ -222,8 +208,6 @@ def __init__(
self._service_endpoint: Union[str, Path] = service
self._connection_settings = connection_settings or ConnectionOpts()
- self.datasets_site_path: Optional[Path] = None
- self.available_datasets: Dict[str, LegacyDataset] = {}
self.action_space_name = action_space
@@ -232,6 +216,7 @@ def __init__(
opts=self._connection_settings,
logger=self.logger,
)
+ self.datasets = Datasets(datasets or [])
# If no reward space is specified, generate some from numeric observation spaces
rewards = rewards or [
@@ -245,17 +230,14 @@ def __init__(
]
# The benchmark that is currently being used, and the benchmark that
- # the user requested. Those do not always correlate, since the user
- # could request a random benchmark.
- self._benchmark_in_use_uri: Optional[str] = None
- self._user_specified_benchmark_uri: Optional[str] = None
- # A map from benchmark URIs to Benchmark messages. We keep track of any
- # user-provided custom benchmarks so that we can register them with a
- # reset service.
- self._custom_benchmarks: Dict[str, Benchmark] = {}
+ # will be used on the next call to reset(). These are equal except in
+ # the gap between the user setting the env.benchmark property while in
+ # an episode and the next call to env.reset().
+ self._benchmark_in_use: Optional[Benchmark] = None
+ self._next_benchmark: Optional[Benchmark] = None
# Normally when the benchmark is changed the updated value is not
- # reflected until the next call to reset(). We make an exception for
- # constructor-time arguments as otherwise the behavior of the benchmark
+ # reflected until the next call to reset(). We make an exception for the
+ # constructor-time benchmark as otherwise the behavior of the benchmark
# property is counter-intuitive:
#
# >>> env = gym.make("example-v0", benchmark="foo")
@@ -265,10 +247,17 @@ def __init__(
# >>> env.benchmark
# "foo"
#
- # By forcing the benchmark-in-use URI at constructor time, the first
- # env.benchmark returns the name as expected.
- self.benchmark = benchmark
- self._benchmark_in_use_uri = self._user_specified_benchmark_uri
+ # By forcing the _benchmark_in_use URI at constructor time, the first
+ # env.benchmark above returns the benchmark as expected.
+ try:
+ self.benchmark = benchmark or next(self.datasets.benchmarks())
+ self._benchmark_in_use = self._next_benchmark
+ except StopIteration:
+ # StopIteration raised on next(self.datasets.benchmarks()) if there
+ # are no benchmarks available. This is to allow CompilerEnv to be
+ # used without any datasets by setting a benchmark before/during the
+ # first reset() call.
+ pass
# Process the available action, observation, and reward spaces.
self.action_spaces = [
@@ -276,7 +265,7 @@ def __init__(
for space in self.service.action_spaces
]
self.observation = self._observation_view_type(
- get_observation=lambda req: self.service(self.service.stub.Step, req),
+ get_observation=lambda req: _wrapped_step(self.service, req),
spaces=self.service.observation_spaces,
)
self.reward = self._reward_view_type(rewards, self.observation)
@@ -284,20 +273,33 @@ def __init__(
# Lazily evaluated version strings.
self._versions: Optional[GetVersionReply] = None
- # Mutable state initialized in reset().
self.action_space: Optional[Space] = None
self.observation_space: Optional[Space] = None
+
+ # Mutable state initialized in reset().
self.reward_range: Tuple[float, float] = (-np.inf, np.inf)
self.episode_reward: Optional[float] = None
self.episode_start_time: float = time()
self.actions: List[int] = []
# Initialize the default observation/reward spaces.
- self._default_observation_space: Optional[ObservationSpaceSpec] = None
- self._default_reward_space: Optional[Reward] = None
+ self.observation_space_spec: Optional[ObservationSpaceSpec] = None
+ self.reward_space_spec: Optional[Reward] = None
self.observation_space = observation_space
self.reward_space = reward_space
+ @property
+ @deprecated(
+ version="0.1.8",
+ reason=(
+ "Use :meth:`env.datasets.datasets() ` instead. "
+ "`More information `_."
+ ),
+ )
+ def available_datasets(self) -> Dict[str, Dataset]:
+ """A dictionary of datasets."""
+ return {d.name: d for d in self.datasets}
+
@property
def versions(self) -> GetVersionReply:
"""Get the version numbers from the compiler service."""
@@ -318,13 +320,15 @@ def compiler_version(self) -> str:
return self.versions.compiler_version
def commandline(self) -> str:
- """Interface for :class:`CompilerEnv` subclasses to provide an equivalent
- commandline invocation to the current environment state.
+ """Interface for :class:`CompilerEnv `
+ subclasses to provide an equivalent commandline invocation to the
+ current environment state.
- See also
- :meth:`commandline_to_actions() `.
+ See also :meth:`commandline_to_actions()
+ `.
- Calling this method on a :class:`CompilerEnv` instance raises
+ Calling this method on a :class:`CompilerEnv
+ ` instance raises
:code:`NotImplementedError`.
:return: A string commandline invocation.
@@ -332,13 +336,15 @@ def commandline(self) -> str:
raise NotImplementedError("abstract method")
def commandline_to_actions(self, commandline: str) -> List[int]:
- """Interface for :class:`CompilerEnv` subclasses to convert from a
- commandline invocation to a sequence of actions.
+ """Interface for :class:`CompilerEnv `
+ subclasses to convert from a commandline invocation to a sequence of
+ actions.
- See also
- :meth:`commandline() `.
+ See also :meth:`commandline()
+ `.
- Calling this method on a :class:`CompilerEnv` instance raises
+ Calling this method on a :class:`CompilerEnv
+ ` instance raises
:code:`NotImplementedError`.
:return: A list of actions.
@@ -356,23 +362,12 @@ def episode_walltime(self) -> float:
def state(self) -> CompilerEnvState:
"""The tuple representation of the current environment state."""
return CompilerEnvState(
- benchmark=self.benchmark,
+ benchmark=str(self.benchmark) if self.benchmark else None,
reward=self.episode_reward,
walltime=self.episode_walltime,
commandline=self.commandline(),
)
- @property
- def inactive_datasets_site_path(self) -> Optional[Path]:
- """The filesystem path used to store inactive benchmarks."""
- if self.datasets_site_path:
- return (
- self.datasets_site_path.parent
- / f"{self.datasets_site_path.name}.inactive"
- )
- else:
- return None
-
@property
def action_space(self) -> NamedDiscrete:
"""The current action space.
@@ -395,69 +390,41 @@ def action_space(self, action_space: Optional[str]):
self._action_space: NamedDiscrete = self.action_spaces[index]
@property
- def benchmark(self) -> Optional[str]:
- """Get or set the name of the benchmark to use.
-
- :getter: Get the name of the current benchmark. Returns :code:`None` if
- :func:`__init__` was not provided a benchmark and :func:`reset` has
- not yet been called.
- :setter: Set the benchmark to use. If :code:`None`, a random benchmark
- is selected by the service on each call to :func:`reset`. Else,
- the same benchmark is used for every episode.
-
- By default, a benchmark will be selected randomly by the service
- from the available :func:`benchmarks` on a call to :func:`reset`. To
- force a specific benchmark to be chosen, set this property (or pass
- the benchmark as an argument to :func:`reset`):
-
- >>> env.benchmark = "benchmark://foo"
- >>> env.reset()
- >>> env.benchmark
- "benchmark://foo"
+ def benchmark(self) -> Benchmark:
+ """Get or set the benchmark to use.
- Once set, all subsequent calls to :func:`reset` will select the same
- benchmark.
+ :getter: Get :class:`Benchmark ` that
+ is currently in use.
- >>> env.benchmark = None
- >>> env.reset() # random benchmark is chosen
+ :setter: Set the benchmark to use. Either a :class:`Benchmark
+ ` instance, or the URI of a
+ benchmark as in :meth:`env.datasets.benchmark_uris()
+ `.
.. note::
- Setting a new benchmark has no effect until :func:`~reset()` is
- called.
- To return to random benchmark selection, set this property to
- :code:`None`:
+ Setting a new benchmark has no effect until
+ :func:`env.reset() ` is called.
"""
- return self._benchmark_in_use_uri
+ return self._benchmark_in_use
@benchmark.setter
- def benchmark(self, benchmark: Optional[Union[str, Benchmark]]):
+ def benchmark(self, benchmark: Union[str, Benchmark]):
if self.in_episode:
warnings.warn(
- "Changing the benchmark has no effect until reset() is called."
+ "Changing the benchmark has no effect until reset() is called"
)
- if benchmark is None:
- self.logger.debug("Unsetting the forced benchmark")
- self._user_specified_benchmark_uri = None
- elif isinstance(benchmark, str):
- self.logger.debug("Setting benchmark by name: %s", benchmark)
- # If the user requested a benchmark by URI, e.g.
- # benchmark://cBench-v1/dijkstra, require the dataset (cBench-v1)
- # automatically.
- if self.datasets_site_path:
- components = benchmark.split("://")
- if len(components) == 1 or components[0] == "benchmark":
- components = components[-1].split("/")
- if len(components) > 1:
- self.logger.info("Requiring dataset %s", components[0])
- self.require_dataset(components[0])
- self._user_specified_benchmark_uri = benchmark
+ if isinstance(benchmark, str):
+ benchmark_object = self.datasets.benchmark(benchmark)
+ self.logger.debug("Setting benchmark by name: %s", benchmark_object)
+ self._next_benchmark = benchmark_object
elif isinstance(benchmark, Benchmark):
- self.logger.debug("Setting benchmark data: %s", benchmark.uri)
- self._user_specified_benchmark_uri = benchmark.uri
- self._add_custom_benchmarks([benchmark])
+ self.logger.debug("Setting benchmark: %s", benchmark.uri)
+ self._next_benchmark = benchmark
else:
- raise TypeError(f"Unsupported benchmark type: {type(benchmark).__name__}")
+ raise TypeError(
+ f"Expected a Benchmark or str, received: '{type(benchmark).__name__}'"
+ )
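+
+ # Illustrative usage, not in the original patch:
+ #
+ # >>> env.benchmark = "benchmark://cbench-v1/qsort"               # set by URI
+ # >>> env.benchmark = env.datasets.benchmark("cbench-v1/qsort")   # or by instance
+ # >>> env.reset()  # the new benchmark takes effect on the next reset()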
@property
def reward_space(self) -> Optional[Reward]:
@@ -468,26 +435,27 @@ def reward_space(self) -> Optional[Reward]:
or :code:`None` if not set.
:setter: Set the default reward space.
"""
- return self._default_reward_space
+ return self.reward_space_spec
@reward_space.setter
def reward_space(self, reward_space: Optional[Union[str, Reward]]) -> None:
- if isinstance(reward_space, str) and reward_space not in self.reward.spaces:
- raise LookupError(f"Reward space not found: {reward_space}")
-
- reward_space_name = (
+ # Coerce the observation space into a string.
+ reward_space: Optional[str] = (
reward_space.id if isinstance(reward_space, Reward) else reward_space
)
- self._default_reward: bool = reward_space is not None
- self._default_reward_space: Optional[Reward] = None
- if self._default_reward:
- self._default_reward_space = self.reward.spaces[reward_space_name]
+ if reward_space:
+ if reward_space not in self.reward.spaces:
+ raise LookupError(f"Reward space not found: {reward_space}")
+ self.reward_space_spec = self.reward.spaces[reward_space]
self.reward_range = (
- self._default_reward_space.min,
- self._default_reward_space.max,
+ self.reward_space_spec.min,
+ self.reward_space_spec.max,
)
else:
+ # If no reward space is being used then set the reward range to
+ # unbounded.
+ self.reward_space_spec = None
self.reward_range = (-np.inf, np.inf)
@property
@@ -508,30 +476,26 @@ def observation_space(self) -> Optional[ObservationSpaceSpec]:
:code:`None` if not set.
:setter: Set the default observation space.
"""
- return self._default_observation_space
+ if self.observation_space_spec:
+ return self.observation_space_spec.space
@observation_space.setter
def observation_space(
self, observation_space: Optional[Union[str, ObservationSpaceSpec]]
) -> None:
- if (
- isinstance(observation_space, str)
- and observation_space not in self.observation.spaces
- ):
- raise LookupError(f"Observation space not found: {observation_space}")
-
- observation_space_name = (
+ # Coerce the observation space into a string.
+ observation_space: Optional[str] = (
observation_space.id
if isinstance(observation_space, ObservationSpaceSpec)
else observation_space
)
- self._default_observation = observation_space is not None
- self._default_observation_space: Optional[ObservationSpaceSpec] = None
- if self._default_observation:
- self._default_observation_space = self.observation.spaces[
- observation_space_name
- ]
+ if observation_space:
+ if observation_space not in self.observation.spaces:
+ raise LookupError(f"Observation space not found: {observation_space}")
+ self.observation_space_spec = self.observation.spaces[observation_space]
+ else:
+ self.observation_space_spec = None
def fork(self) -> "CompilerEnv":
"""Fork a new environment with exactly the same state.
@@ -541,33 +505,30 @@ def fork(self) -> "CompilerEnv":
The user must call :meth:`close() `
on the original and new environments.
- :meth:`reset() ` must be called
- before :code:`fork()`.
+ If not already in an episode, :meth:`reset()
+ ` is called.
Example usage:
- >>> env = gym.make("llvm-v0")
- >>> env.reset()
- # ... use env
- >>> new_env = env.fork()
- >>> new_env.state == env.state
- True
- >>> new_env.step(1) == env.step(1)
- True
+ >>> env = gym.make("llvm-v0")
+ >>> env.reset()
+ # ... use env
+ >>> new_env = env.fork()
+ >>> new_env.state == env.state
+ True
+ >>> new_env.step(1) == env.step(1)
+ True
:return: A new environment instance.
"""
if not self.in_episode:
- if self.actions:
- state_to_replay = self.state
+ if self.actions and not self.in_episode:
self.logger.warning(
"Parent service of fork() has died, replaying state"
)
+ self.apply(self.state)
else:
- state_to_replay = None
- self.reset()
- if state_to_replay:
- self.apply(state_to_replay)
+ self.reset()
request = ForkSessionRequest(session_id=self._session_id)
reply: ForkSessionReply = self.service(self.service.stub.ForkSession, request)
@@ -584,23 +545,16 @@ def fork(self) -> "CompilerEnv":
new_env._session_id = reply.session_id # pylint: disable=protected-access
new_env.observation.session_id = reply.session_id
- # Re-register any custom benchmarks with the new environment.
- if self._custom_benchmarks:
- new_env._add_custom_benchmarks( # pylint: disable=protected-access
- list(self._custom_benchmarks.values()).copy()
- )
-
# Now that we have initialized the environment with the current state,
# set the benchmark so that calls to new_env.reset() will correctly
# revert the environment to the initial benchmark state.
- new_env._user_specified_benchmark_uri = ( # pylint: disable=protected-access
- self.benchmark
- )
+ #
+ # pylint: disable=protected-access
+ new_env._next_benchmark = self._benchmark_in_use
+
# Set the "visible" name of the current benchmark to hide the fact that
# we loaded from a custom bitcode file.
- new_env._benchmark_in_use_uri = ( # pylint: disable=protected-access
- self.benchmark
- )
+ new_env._benchmark_in_use = self._benchmark_in_use
# Create copies of the mutable reward and observation spaces. This
# is required to correctly calculate incremental updates.
@@ -610,7 +564,7 @@ def fork(self) -> "CompilerEnv":
# Set the default observation and reward types. Note the use of IDs here
# to prevent passing the spaces by reference.
if self.observation_space:
- new_env.observation_space = self.observation_space.id
+ new_env.observation_space = self.observation_space_spec.id
if self.reward_space:
new_env.reward_space = self.reward_space.id
@@ -625,7 +579,26 @@ def close(self):
"""Close the environment.
Once closed, :func:`reset` must be called before the environment is used
- again."""
+ again.
+
+ .. note::
+
+ Internally, CompilerGym environments may launch subprocesses and use
+ temporary files to communicate between the environment and the
+ underlying compiler (see :ref:`compiler_gym.service
+ ` for details). This
+ means it is important to call :meth:`env.close()
+ ` after use to free up
+ resources and prevent orphan subprocesses or files. We recommend
+ using the :code:`with`-statement pattern for creating environments:
+
+ >>> with gym.make("llvm-autophase-ic-v0") as env:
+ ... env.reset()
+ ... # use env how you like
+
+ This removes the need to call :meth:`env.close()
+ ` yourself.
+ """
# Try and close out the episode, but errors are okay.
close_service = True
if self.in_episode:
@@ -670,19 +643,32 @@ def reset( # pylint: disable=arguments-differ
If no benchmark is provided, and no benchmark was provided to
:func:`__init___`, the service will randomly select a benchmark to
use.
+
:param action_space: The name of the action space to use. If provided,
it overrides any value that set during :func:`__init__`, and
- subsequent calls to :code:`reset()` will use this action space.
- If no aciton space is provided, the default action space is used.
+ subsequent calls to :code:`reset()` will use this action space. If
+ no action space is provided, the default action space is used.
+
:return: The initial observation.
+
+ :raises BenchmarkInitError: If the benchmark is invalid. In this case,
+ another benchmark must be used.
+
+ :raises TypeError: If no benchmark has been set, and the environment
+ does not have a default benchmark to select from.
"""
+ if not self._next_benchmark:
+ raise TypeError(
+ "No benchmark set. Set a benchmark using "
+ "`env.reset(benchmark=benchmark)`. Use `env.datasets` to "
+ "access the available benchmarks."
+ )
+
# Start a new service if required.
if self.service is None:
self.service = CompilerGymServiceConnection(
self._service_endpoint, self._connection_settings
)
- # Re-register the custom benchmarks with the new service.
- self._add_custom_benchmarks(self._custom_benchmarks.values())
self.action_space_name = action_space or self.action_space_name
@@ -694,29 +680,38 @@ def reset( # pylint: disable=arguments-differ
)
self._session_id = None
- # Update the user requested benchmark, if provided. NOTE: This means
- # that env.reset(benchmark=None) does NOT unset a forced benchmark.
+ # Update the user requested benchmark, if provided.
if benchmark:
self.benchmark = benchmark
+ self._benchmark_in_use = self._next_benchmark
+
+ start_session_request = StartSessionRequest(
+ benchmark=self._benchmark_in_use.uri,
+ action_space=(
+ [a.name for a in self.action_spaces].index(self.action_space_name)
+ if self.action_space_name
+ else 0
+ ),
+ observation_space=(
+ [self.observation_space_spec.index] if self.observation_space else None
+ ),
+ )
try:
- reply = self.service(
- self.service.stub.StartSession,
- StartSessionRequest(
- benchmark=self._user_specified_benchmark_uri,
- action_space=(
- [a.name for a in self.action_spaces].index(
- self.action_space_name
- )
- if self.action_space_name
- else 0
- ),
- ),
+ reply = self.service(self.service.stub.StartSession, start_session_request)
+ except FileNotFoundError:
+ # The benchmark was not found, so try adding it and repeating the
+ # request.
+ self.service(
+ self.service.stub.AddBenchmark,
+ AddBenchmarkRequest(benchmark=[self._benchmark_in_use.proto]),
)
+ reply = self.service(self.service.stub.StartSession, start_session_request)
except (ServiceError, ServiceTransportError, TimeoutError) as e:
# Abort and retry on error.
self.logger.warning("%s on reset(): %s", type(e).__name__, e)
- self.service.close()
+ if self.service:
+ self.service.close()
self.service = None
if retry_count >= self._connection_settings.init_max_attempts:
@@ -731,7 +726,6 @@ def reset( # pylint: disable=arguments-differ
retry_count=retry_count + 1,
)
- self._benchmark_in_use_uri = reply.benchmark
self._session_id = reply.session_id
self.observation.session_id = reply.session_id
self.reward.get_cost = self.observation.__getitem__
@@ -749,32 +743,43 @@ def reset( # pylint: disable=arguments-differ
self.episode_reward = 0
if self.observation_space:
- return self.observation[self.observation_space.id]
+ if len(reply.observation) != 1:
+ raise OSError(
+ f"Expected one observation from service, received {len(reply.observation)}"
+ )
+ return self.observation.spaces[self.observation_space_spec.id].translate(
+ reply.observation[0]
+ )
def step(self, action: Union[int, Iterable[int]]) -> step_t:
"""Take a step.
:param action: An action, or a sequence of actions. When multiple
- actions are provided the observation and reward are returned
- after running all of the actions.
+ actions are provided the observation and reward are returned after
+ running all of the actions.
+
:return: A tuple of observation, reward, done, and info. Observation and
- reward are None if default observation/reward is not set. If done
- is True, observation and reward may also be None (e.g. because the
+ reward are None if default observation/reward is not set. If done is
+ True, observation and reward may also be None (e.g. because the
service failed).
+
+ :raises SessionNotFound: If :meth:`reset()
+ ` has not been called.
"""
- assert self.in_episode, "Must call reset() before step()"
+ if not self.in_episode:
+ raise SessionNotFound("Must call reset() before step()")
actions = action if isinstance(action, IterableType) else [action]
observation, reward = None, None
# Build the list of observations that must be computed by the backend
# service to generate the user-requested observation and reward.
- # TODO(cummins): We could de-duplicate this list to improve effiency
+ # TODO(cummins): We could de-duplicate this list to improve efficiency
# when multiple redundant copies of the same observation space are
# requested.
observation_indices, observation_spaces = [], []
if self.observation_space:
- observation_indices.append(self.observation_space.index)
- observation_spaces.append(self.observation_space.id)
+ observation_indices.append(self.observation_space_spec.index)
+ observation_spaces.append(self.observation_space_spec.id)
if self.reward_space:
observation_indices += [
self.observation.spaces[obs].index
@@ -792,14 +797,26 @@ def step(self, action: Union[int, Iterable[int]]) -> step_t:
observation_space=observation_indices,
)
try:
- reply = self.service(self.service.stub.Step, request)
- except (ServiceError, ServiceTransportError, ServiceOSError, TimeoutError) as e:
+ reply = _wrapped_step(self.service, request)
+ except (
+ ServiceError,
+ ServiceTransportError,
+ ServiceOSError,
+ TimeoutError,
+ SessionNotFound,
+ ) as e:
+ # Gracefully handle "expected" error types. These non-fatal errors
+ # end the current episode and provide some diagnostic information to
+ # the user through the `info` dict.
self.close()
- info = {"error_details": str(e)}
+ info = {
+ "error_type": type(e).__name__,
+ "error_details": str(e),
+ }
if self.reward_space:
reward = self.reward_space.reward_on_error(self.episode_reward)
if self.observation_space:
- observation = self.observation_space.default_value
+ observation = self.observation_space_spec.default_value
return observation, reward, True, info
# If the action space has changed, update it.
@@ -853,7 +870,7 @@ def render(
"""
if not self.observation_space:
raise ValueError("Cannot call render() when no observation space is used")
- observation = self.observation[self.observation_space.id]
+ observation = self.observation[self.observation_space_spec.id]
if mode == "human":
print(observation)
elif mode == "ansi":
@@ -862,10 +879,16 @@ def render(
raise ValueError(f"Invalid mode: {mode}")
@property
- def benchmarks(self) -> List[str]:
- """Enumerate the list of available benchmarks."""
- reply = self.service(self.service.stub.GetBenchmarks, GetBenchmarksRequest())
- return list(reply.benchmark)
+ @deprecated(
+ version="0.1.8",
+ reason=(
+ "Use :meth:`env.datasets.benchmarks() ` instead. "
+ "`More information `_."
+ ),
+ )
+ def benchmarks(self) -> Iterable[str]:
+ """Enumerate a (possible unbounded) list of available benchmarks."""
+ return self.datasets.benchmark_uris()
def _make_action_space(self, name: str, entries: List[str]) -> Space:
"""Create an action space from the given values.
@@ -895,131 +918,82 @@ def _reward_view_type(self):
"""
return RewardView
- def require_datasets(self, datasets: List[Union[str, LegacyDataset]]) -> bool:
- """Require that the given datasets are available to the environment.
-
- Example usage:
-
- >>> env = gym.make("llvm-v0")
- >>> env.require_dataset(["npb-v0"])
- >>> env.benchmarks
- ["npb-v0/1", "npb-v0/2", ...]
+ @deprecated(
+ version="0.1.8",
+ reason=(
+ "Datasets are now installed automatically, there is no need to call :code:`require()`. "
+ "`More information `_."
+ ),
+ )
+ def require_datasets(self, datasets: List[Union[str, Dataset]]) -> bool:
+ """Deprecated function for managing datasets.
- This is equivalent to calling
- :meth:`require(self, dataset) ` on
- the list of datasets.
+ Datasets are now installed automatically. See :class:`env.datasets
+ `.
:param datasets: A list of datasets to require. Each dataset is the name
of an available dataset, the URL of a dataset to download, or a
- :class:`LegacyDataset` instance.
+ :class:`Dataset ` instance.
+
:return: :code:`True` if one or more datasets were downloaded, or
:code:`False` if all datasets were already available.
"""
- self.logger.debug("Requiring datasets: %s", datasets)
- dataset_installed = False
- for dataset in datasets:
- dataset_installed |= require(self, dataset)
- if dataset_installed:
- # Signal to the compiler service that the contents of the site data
- # directory has changed.
- self.logger.debug("Initiating service-side scan of dataset directory")
- self.service(
- self.service.stub.AddBenchmark,
- AddBenchmarkRequest(
- benchmark=[Benchmark(uri="service://scan-site-data")]
- ),
- )
- self.make_manifest_file()
- return dataset_installed
+ return False
+
+ @deprecated(
+ version="0.1.8",
+ reason=(
+ "Use :meth:`env.datasets.require() ` instead. "
+ "`More information `_."
+ ),
+ )
+ def require_dataset(self, dataset: Union[str, Dataset]) -> bool:
+ """Deprecated function for managing datasets.
- def require_dataset(self, dataset: Union[str, LegacyDataset]) -> bool:
- """Require that the given dataset is available to the environment.
+ Datasets are now installed automatically. See :class:`env.datasets
+ `.
- Alias for
- :meth:`env.require_datasets([dataset]) `.
+ :param dataset: The name of the dataset to download, the URL of the
+ dataset, or a :class:`Dataset `
+ instance.
- :param dataset: The name of the dataset to download, the URL of the dataset, or a
- :class:`LegacyDataset` instance.
:return: :code:`True` if the dataset was downloaded, or :code:`False` if
the dataset was already available.
"""
- return self.require_datasets([dataset])
-
- def make_manifest_file(self) -> Path:
- """Create the MANIFEST file.
-
- :return: The path of the manifest file.
- """
- with fasteners.InterProcessLock(self.datasets_site_path / "LOCK"):
- manifest_path = (
- self.datasets_site_path.parent
- / f"{self.datasets_site_path.name}.MANIFEST"
- )
- with open(str(manifest_path), "w") as f:
- for root, _, files in os.walk(self.datasets_site_path):
- print(
- "\n".join(
- [
- f"{root[len(str(self.datasets_site_path)) + 1:]}/{f}"
- for f in files
- if not f.endswith(".json") and f != "LOCK"
- ]
- ),
- file=f,
- )
- return manifest_path
-
- def register_dataset(self, dataset: LegacyDataset) -> bool:
+ return False
+
+ @deprecated(
+ version="0.1.8",
+ reason=(
+ "Use :meth:`env.datasets.add() ` instead. "
+ "`More information `_."
+ ),
+ )
+ def register_dataset(self, dataset: Dataset) -> bool:
"""Register a new dataset.
- After registering, the dataset name may be used by
- :meth:`require_dataset() `
- to install and activate it.
-
Example usage:
- >>> my_dataset = LegacyDataset(name="my-dataset-v0", ...)
+ >>> my_dataset = Dataset(name="my-dataset-v0", ...)
>>> env = gym.make("llvm-v0")
>>> env.register_dataset(my_dataset)
- >>> env.require_dataset("my-dataset-v0")
>>> env.benchmark = "my-dataset-v0/1"
- :param dataset: A :class:`LegacyDataset` instance describing the new dataset.
+ :param dataset: A :class:`Dataset `
+ instance describing the new dataset.
+
:return: :code:`True` if the dataset was added, else :code:`False`.
+
:raises ValueError: If a dataset with this name is already registered.
"""
- platform = {"darwin": "macos"}.get(sys.platform, sys.platform)
- if platform not in dataset.platforms:
- return False
- if dataset.name in self.available_datasets:
- raise ValueError(f"Dataset already registered with name: {dataset.name}")
- self.available_datasets[dataset.name] = dataset
- return True
-
- def _add_custom_benchmarks(self, benchmarks: List[Benchmark]) -> None:
- """Register custom benchmarks with the compiler service.
-
- Benchmark registration occurs automatically using the
- :meth:`env.benchmark `
- property, there is usually no need to call this method yourself.
-
- :param benchmarks: The benchmarks to register.
- """
- if not benchmarks:
- return
-
- for benchmark in benchmarks:
- self._custom_benchmarks[benchmark.uri] = benchmark
-
- self.service(
- self.service.stub.AddBenchmark,
- AddBenchmarkRequest(benchmark=benchmarks),
- )
+ return self.datasets.add(dataset)
def apply(self, state: CompilerEnvState) -> None: # noqa
"""Replay this state on the given an environment.
- :param env: A :class:`CompilerEnv` instance.
+ :param env: A :class:`CompilerEnv `
+ instance.
+
:raises ValueError: If this state cannot be applied.
"""
if not self.in_episode:
@@ -1039,8 +1013,21 @@ def apply(self, state: CompilerEnvState) -> None: # noqa
)
def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult:
- in_place = state is not None
- state = state or self.state
+ """Validate an environment's state.
+
+    :param state: A state to validate. If not provided, the current state
+ is validated.
+
+ :returns: A :class:`ValidationResult `.
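+
+    Example usage (a minimal sketch; the benchmark name below is one of the
+    runnable cbench-v1 programs):
+
+        >>> env.reset(benchmark="benchmark://cbench-v1/crc32")
+        >>> env.step(env.action_space.sample())
+        >>> result = env.validate()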
+ """
+ if state:
+ self.reset(benchmark=state.benchmark)
+ in_place = False
+ else:
+ state = self.state
+ in_place = True
+
+ assert self.in_episode
errors: ValidationError = []
validation = {
@@ -1108,13 +1095,10 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult
)
)
- # TODO(https://github.com/facebookresearch/CompilerGym/issues/45):
- # Call the new self.benchmark.validation_callback() method
- # once implemented.
- validate_semantics = self.get_benchmark_validation_callback()
- if validate_semantics:
+ benchmark = replay_target.benchmark
+ if benchmark.is_validatable():
validation["benchmark_semantics_validated"] = True
- semantics_errors = list(validate_semantics(self))
+ semantics_errors = benchmark.validate(replay_target)
if semantics_errors:
validation["benchmark_semantics_validation_failed"] = True
errors += semantics_errors
@@ -1124,19 +1108,30 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult
finally:
fkd.close()
- return ValidationResult(
+ return ValidationResult.construct(
walltime=walltime.time,
errors=errors,
**validation,
)
+ @deprecated(
+ version="0.1.8",
+ reason=(
+ "Use :meth:`env.validate() "
+ "` instead. "
+ "`More information `_."
+ ),
+ )
def get_benchmark_validation_callback(
self,
) -> Optional[Callable[["CompilerEnv"], Iterable[ValidationError]]]:
- """Return a callback that validates benchmark semantics, if available.
+ """Return a callback that validates benchmark semantics, if available."""
- TODO(https://github.com/facebookresearch/CompilerGym/issues/45): This is
- a temporary placeholder for what will eventually become a method on a
- new Benchmark class.
- """
- return None
+ def composed(env):
+ for validation_cb in self.benchmark.validation_callbacks():
+ errors = validation_cb(env)
+ if errors:
+ yield from errors
+
+ if self.benchmark.validation_callbacks():
+ return composed
diff --git a/compiler_gym/envs/llvm/BUILD b/compiler_gym/envs/llvm/BUILD
index aff87007d..e819ac523 100644
--- a/compiler_gym/envs/llvm/BUILD
+++ b/compiler_gym/envs/llvm/BUILD
@@ -11,36 +11,26 @@ py_library(
":specs.py",
],
data = ["//compiler_gym/envs/llvm/service"],
- visibility = ["//compiler_gym:__subpackages__"],
+ visibility = ["//visibility:public"],
deps = [
- ":benchmarks",
+ ":llvm_benchmark",
":llvm_env",
"//compiler_gym/util",
],
)
py_library(
- name = "benchmarks",
- srcs = ["benchmarks.py"],
+ name = "llvm_benchmark",
+ srcs = ["llvm_benchmark.py"],
visibility = ["//compiler_gym:__subpackages__"],
deps = [
+ "//compiler_gym/datasets",
"//compiler_gym/service/proto",
"//compiler_gym/third_party/llvm",
"//compiler_gym/util",
],
)
-py_library(
- name = "legacy_datasets",
- srcs = ["legacy_datasets.py"],
- visibility = ["//tests:__subpackages__"],
- deps = [
- "//compiler_gym/datasets:dataset",
- "//compiler_gym/third_party/llvm",
- "//compiler_gym/util",
- ],
-)
-
py_library(
name = "llvm_env",
srcs = ["llvm_env.py"],
@@ -48,10 +38,11 @@ py_library(
"//compiler_gym/envs/llvm/service/passes:actions_genfiles",
],
deps = [
- ":benchmarks",
- ":legacy_datasets",
+ ":llvm_benchmark",
":llvm_rewards",
+ "//compiler_gym/datasets",
"//compiler_gym/envs:compiler_env",
+ "//compiler_gym/envs/llvm/datasets",
"//compiler_gym/spaces",
"//compiler_gym/third_party/autophase",
"//compiler_gym/third_party/inst2vec",
diff --git a/compiler_gym/envs/llvm/__init__.py b/compiler_gym/envs/llvm/__init__.py
index 3311282d2..451f8efc0 100644
--- a/compiler_gym/envs/llvm/__init__.py
+++ b/compiler_gym/envs/llvm/__init__.py
@@ -5,7 +5,7 @@
"""Register the LLVM environments."""
from itertools import product
-from compiler_gym.envs.llvm.benchmarks import (
+from compiler_gym.envs.llvm.llvm_benchmark import (
ClangInvocation,
get_system_includes,
make_benchmark,
diff --git a/compiler_gym/envs/llvm/datasets/BUILD b/compiler_gym/envs/llvm/datasets/BUILD
new file mode 100644
index 000000000..cd18149a0
--- /dev/null
+++ b/compiler_gym/envs/llvm/datasets/BUILD
@@ -0,0 +1,26 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+load("@rules_python//python:defs.bzl", "py_library")
+
+py_library(
+ name = "datasets",
+ srcs = [
+ "__init__.py",
+ "anghabench.py",
+ "cbench.py",
+ "clgen.py",
+ "csmith.py",
+ "llvm_stress.py",
+ "poj104.py",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//compiler_gym/datasets",
+ "//compiler_gym/envs/llvm:llvm_benchmark",
+ "//compiler_gym/service/proto",
+ "//compiler_gym/third_party/llvm",
+ "//compiler_gym/util",
+ ],
+)
diff --git a/compiler_gym/envs/llvm/datasets/__init__.py b/compiler_gym/envs/llvm/datasets/__init__.py
new file mode 100644
index 000000000..e83cf0d6a
--- /dev/null
+++ b/compiler_gym/envs/llvm/datasets/__init__.py
@@ -0,0 +1,262 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import sys
+from pathlib import Path
+from typing import Iterable, Optional
+
+from compiler_gym.datasets import Dataset, TarDatasetWithManifest
+from compiler_gym.envs.llvm.datasets.anghabench import AnghaBenchDataset
+from compiler_gym.envs.llvm.datasets.cbench import CBenchDataset, CBenchLegacyDataset
+from compiler_gym.envs.llvm.datasets.clgen import CLgenDataset
+from compiler_gym.envs.llvm.datasets.csmith import CsmithBenchmark, CsmithDataset
+from compiler_gym.envs.llvm.datasets.llvm_stress import LlvmStressDataset
+from compiler_gym.envs.llvm.datasets.poj104 import POJ104Dataset, POJ104LegacyDataset
+from compiler_gym.util.runfiles_path import site_data_path
+
+
+class BlasDataset(TarDatasetWithManifest):
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ super().__init__(
+ name="benchmark://blas-v0",
+ tar_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-blas-v0.tar.bz2"
+ ],
+ tar_sha256="e724a8114709f8480adeb9873d48e426e8d9444b00cddce48e342b9f0f2b096d",
+ manifest_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-blas-v0-manifest.bz2"
+ ],
+ manifest_sha256="6946437dcb0da5fad3ed8a7fd83eb4294964198391d5537b1310e22d7ceebff4",
+ references={
+ "Paper": "https://strum355.netsoc.co/books/PDF/Basic%20Linear%20Algebra%20Subprograms%20for%20Fortran%20Usage%20-%20BLAS%20(1979).pdf",
+ "Homepage": "http://www.netlib.org/blas/",
+ },
+ license="BSD 3-Clause",
+ strip_prefix="blas-v0",
+ description="Basic linear algebra kernels",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ )
+
+
+class GitHubDataset(TarDatasetWithManifest):
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ manifest_url, manifest_sha256 = {
+ "darwin": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-github-v0-macos-manifest.bz2",
+ "10d933a7d608248be286d756b27813794789f7b87d8561c241d0897fb3238503",
+ ),
+ "linux": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-github-v0-linux-manifest.bz2",
+ "aede9ca78657b4694ada9a4592d93f0bbeb3b3bd0fff3b537209850228480d3b",
+ ),
+ }[sys.platform]
+ super().__init__(
+ name="benchmark://github-v0",
+ tar_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-github-v0.tar.bz2"
+ ],
+ tar_sha256="880269dd7a5c2508ea222a2e54c318c38c8090eb105c0a87c595e9dd31720764",
+ manifest_urls=[manifest_url],
+ manifest_sha256=manifest_sha256,
+ license="CC BY 4.0",
+ references={
+ "Paper": "https://arxiv.org/pdf/2012.01470.pdf",
+ },
+ strip_prefix="github-v0",
+ description="Compile-only C/C++ objects from GitHub",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ )
+
+
+class LinuxDataset(TarDatasetWithManifest):
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ manifest_url, manifest_sha256 = {
+ "darwin": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-linux-v0-macos-manifest.bz2",
+ "dfc87b94c7a43e899e76507398a5af22178aebaebcb5d7e24e82088aeecb0690",
+ ),
+ "linux": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-linux-v0-linux-manifest.bz2",
+ "32ceb8576f683798010816ac605ee496f386ddbbe64be9e0796015d247a73f92",
+ ),
+ }[sys.platform]
+ super().__init__(
+ name="benchmark://linux-v0",
+ tar_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-linux-v0.tar.bz2"
+ ],
+ tar_sha256="a1ae5c376af30ab042c9e54dc432f89ce75f9ebaee953bc19c08aff070f12566",
+ manifest_urls=[manifest_url],
+ manifest_sha256=manifest_sha256,
+ references={"Homepage": "https://www.linux.org/"},
+ license="GPL-2.0",
+ strip_prefix="linux-v0",
+ description="Compile-only object files from C Linux kernel",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ )
+
+
+class MibenchDataset(TarDatasetWithManifest):
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ super().__init__(
+ name="benchmark://mibench-v0",
+ tar_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-mibench-v0.tar.bz2"
+ ],
+ tar_sha256="128c090c40b955b99fdf766da167a5f642018fb35c16a1d082f63be2e977eb13",
+ manifest_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-mibench-v0-manifest.bz2"
+ ],
+ manifest_sha256="8ed985d685b48f444a3312cd84ccc5debda4a839850e442a3cdc93910ba0dc5f",
+ references={
+ "Paper": "http://vhosts.eecs.umich.edu/mibench/Publications/MiBench.pdf"
+ },
+ license="BSD 3-Clause",
+ strip_prefix="mibench-v0",
+ description="C benchmarks",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ )
+
+
+class NPBDataset(TarDatasetWithManifest):
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ super().__init__(
+ name="benchmark://npb-v0",
+ tar_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-npb-v0.tar.bz2"
+ ],
+ tar_sha256="793ac2e7a4f4ed83709e8a270371e65b724da09eaa0095c52e7f4209f63bb1f2",
+ manifest_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-npb-v0-manifest.bz2"
+ ],
+ manifest_sha256="89eccb7f1b0b9e1f82b9b900b9f686ff5b189a2a67a4f8969a15901cd315dba2",
+ references={
+ "Paper": "http://optout.csc.ncsu.edu/~mueller/codeopt/codeopt05/projects/www4.ncsu.edu/~pgauria/csc791a/papers/NAS-95-020.pdf"
+ },
+ license="NASA Open Source Agreement v1.3",
+ strip_prefix="npb-v0",
+ description="NASA Parallel Benchmarks",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ )
+
+
+class OpenCVDataset(TarDatasetWithManifest):
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ super().__init__(
+ name="benchmark://opencv-v0",
+ tar_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-opencv-v0.tar.bz2"
+ ],
+ tar_sha256="003df853bd58df93572862ca2f934c7b129db2a3573bcae69a2e59431037205c",
+ manifest_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-opencv-v0-manifest.bz2"
+ ],
+ manifest_sha256="8de96f722fab18f3a2a74db74b4038c7947fe8b3da867c9260206fdf5338cd81",
+ references={
+ "Paper": "https://mipro-proceedings.com/sites/mipro-proceedings.com/files/upload/sp/sp_008.pdf",
+ "Homepage": "https://opencv.org/",
+ },
+ license="Apache 2.0",
+ strip_prefix="opencv-v0",
+ description="Compile-only object files from C++ OpenCV library",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ )
+
+
+class TensorFlowDataset(TarDatasetWithManifest):
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ super().__init__(
+ name="benchmark://tensorflow-v0",
+ tar_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-tensorflow-v0.tar.bz2"
+ ],
+ tar_sha256="f77dd1988c772e8359e1303cc9aba0d73d5eb27e0c98415ac3348076ab94efd1",
+ manifest_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-tensorflow-v0-manifest.bz2"
+ ],
+ manifest_sha256="cffc45cd10250d483cb093dec913c8a7da64026686284cccf404623bd1da6da8",
+ references={
+ "Paper": "https://www.usenix.org/system/files/conference/osdi16/osdi16-abadi.pdf",
+ "Homepage": "https://www.tensorflow.org/",
+ },
+ license="Apache 2.0",
+ strip_prefix="tensorflow-v0",
+ description="Compile-only object files from C++ TensorFlow library",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ )
+
+
+def get_llvm_datasets(site_data_base: Optional[Path] = None) -> Iterable[Dataset]:
+ """Instantiate the builtin LLVM datasets.
+
+ :param site_data_base: The root of the site data path.
+
+ :return: An iterable sequence of :class:`Dataset
+ ` instances.
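+
+    A minimal usage sketch (when no path is given, the default site data
+    location is used):
+
+        >>> for dataset in get_llvm_datasets():
+        ...     print(dataset.name)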
+ """
+ site_data_base = site_data_base or site_data_path("llvm-v0")
+
+ yield AnghaBenchDataset(site_data_base=site_data_base, sort_order=0)
+ yield BlasDataset(site_data_base=site_data_base, sort_order=0)
+ yield CLgenDataset(site_data_base=site_data_base, sort_order=0)
+ yield CBenchDataset(site_data_base=site_data_base, sort_order=-1)
+ # Add legacy version of cbench-v1 in which the 'b' was capitalized. This
+ # is deprecated and will be removed no earlier than v0.1.10.
+ yield CBenchDataset(
+ site_data_base=site_data_base,
+ name="benchmark://cBench-v1",
+ deprecated=(
+ "Please use 'benchmark://cbench-v1' (note the lowercase name). "
+ "The dataset is the same, only the name has changed"
+ ),
+ manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v1-manifest.bz2",
+ manifest_sha256="635b94eeb2784dfedb3b53fd8f84517c3b4b95d851ddb662d4c1058c72dc81e0",
+ sort_order=100,
+ )
+ yield CBenchLegacyDataset(site_data_base=site_data_base)
+ yield CsmithDataset(site_data_base=site_data_base, sort_order=0)
+ yield GitHubDataset(site_data_base=site_data_base, sort_order=0)
+ yield LinuxDataset(site_data_base=site_data_base, sort_order=0)
+ yield LlvmStressDataset(site_data_base=site_data_base, sort_order=0)
+ yield MibenchDataset(site_data_base=site_data_base, sort_order=0)
+ yield NPBDataset(site_data_base=site_data_base, sort_order=0)
+ yield OpenCVDataset(site_data_base=site_data_base, sort_order=0)
+ yield POJ104Dataset(site_data_base=site_data_base, sort_order=0)
+ yield POJ104LegacyDataset(site_data_base=site_data_base, sort_order=100)
+ yield TensorFlowDataset(site_data_base=site_data_base, sort_order=0)
+
+
+__all__ = [
+ "AnghaBenchDataset",
+ "BlasDataset",
+ "CBenchDataset",
+ "CBenchLegacyDataset",
+ "CLgenDataset",
+ "CsmithBenchmark",
+ "CsmithDataset",
+ "get_llvm_datasets",
+ "GitHubDataset",
+ "LinuxDataset",
+ "LlvmStressDataset",
+ "MibenchDataset",
+ "NPBDataset",
+ "OpenCVDataset",
+ "POJ104Dataset",
+ "POJ104LegacyDataset",
+ "TensorFlowDataset",
+]
diff --git a/compiler_gym/envs/llvm/datasets/anghabench.py b/compiler_gym/envs/llvm/datasets/anghabench.py
new file mode 100644
index 000000000..bfcb46a65
--- /dev/null
+++ b/compiler_gym/envs/llvm/datasets/anghabench.py
@@ -0,0 +1,121 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import subprocess
+import sys
+from concurrent.futures import as_completed
+from pathlib import Path
+
+from compiler_gym.datasets import Benchmark, TarDatasetWithManifest
+from compiler_gym.datasets.benchmark import BenchmarkWithSource
+from compiler_gym.envs.llvm.llvm_benchmark import ClangInvocation
+from compiler_gym.util import thread_pool
+from compiler_gym.util.filesystem import atomic_file_write
+
+
+class AnghaBenchDataset(TarDatasetWithManifest):
+ """A dataset of C programs curated from GitHub source code.
+
+ The dataset is from:
+
+ da Silva, Anderson Faustino, Bruno Conde Kind, José Wesley de Souza
+ Magalhaes, Jerônimo Nunes Rocha, Breno Campos Ferreira Guimaraes, and
+ Fernando Magno Quinão Pereira. "ANGHABENCH: A Suite with One Million
+ Compilable C Benchmarks for Code-Size Reduction." In 2021 IEEE/ACM
+ International Symposium on Code Generation and Optimization (CGO),
+ pp. 378-390. IEEE, 2021.
+
+ And is available at:
+
+ http://cuda.dcc.ufmg.br/angha/home
+
+ Installation
+ ------------
+
+ The AnghaBench dataset consists of C functions that are compiled to LLVM-IR
+ on-demand and cached. The first time each benchmark is used there is an
+ overhead of compiling it from C to bitcode. This is a one-off cost.
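+
+    A minimal usage sketch (assuming the :code:`env.datasets` API; the first
+    :code:`reset()` on each benchmark pays the one-off compilation cost
+    described above):
+
+        >>> env = gym.make("llvm-v0")
+        >>> dataset = env.datasets["benchmark://anghabench-v0"]
+        >>> env.reset(benchmark=next(iter(dataset.benchmark_uris())))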
+ """
+
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ manifest_url, manifest_sha256 = {
+ "darwin": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-anghabench-v0-macos-manifest.bz2",
+ "39464256405aacefdb7550a7f990c9c578264c132804eec3daac091fa3c21bd1",
+ ),
+ "linux": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-anghabench-v0-linux-manifest.bz2",
+ "a038d25d39ee9472662a9704dfff19c9e3512ff6a70f1067af85c5cb3784b477",
+ ),
+ }[sys.platform]
+ super().__init__(
+ name="benchmark://anghabench-v0",
+ description="Compile-only C/C++ functions extracted from GitHub",
+ references={
+ "Paper": "https://homepages.dcc.ufmg.br/~fernando/publications/papers/FaustinoCGO21.pdf",
+ "Homepage": "http://cuda.dcc.ufmg.br/angha/",
+ },
+ license="Unknown. See: https://github.com/brenocfg/AnghaBench/issues/1",
+ site_data_base=site_data_base,
+ manifest_urls=[manifest_url],
+ manifest_sha256=manifest_sha256,
+ tar_urls=[
+ "https://github.com/brenocfg/AnghaBench/archive/d8034ac8562b8c978376008f4b33df01b8887b19.tar.gz"
+ ],
+ tar_sha256="85d068e4ce44f2581e3355ee7a8f3ccb92568e9f5bd338bc3a918566f3aff42f",
+ strip_prefix="AnghaBench-d8034ac8562b8c978376008f4b33df01b8887b19",
+ tar_compression="gz",
+ benchmark_file_suffix=".bc",
+ sort_order=sort_order,
+ )
+
+ def benchmark(self, uri: str) -> Benchmark:
+ self.install()
+
+ benchmark_name = uri[len(self.name) + 1 :]
+ if not benchmark_name:
+ raise LookupError(f"No benchmark specified: {uri}")
+
+ # The absolute path of the file, without an extension.
+ path_stem = self.dataset_root / benchmark_name
+
+ bitcode_abspath = Path(f"{path_stem}.bc")
+ c_file_abspath = Path(f"{path_stem}.c")
+
+ # If the file does not exist, compile it on-demand.
+ if not bitcode_abspath.is_file():
+ if not c_file_abspath.is_file():
+ raise LookupError(
+ f"Benchmark not found: {uri} (file not found: {c_file_abspath})"
+ )
+
+ with atomic_file_write(bitcode_abspath) as tmp_path:
+ compile_cmd = ClangInvocation.from_c_file(
+ c_file_abspath,
+ copt=[
+ "-ferror-limit=1", # Stop on first error.
+ "-w", # No warnings.
+ ],
+ ).command(outpath=tmp_path)
+ subprocess.check_call(compile_cmd, timeout=300)
+
+ return BenchmarkWithSource.create(
+ uri, bitcode_abspath, "function.c", c_file_abspath
+ )
+
+ def compile_all(self):
+ n = self.size
+ executor = thread_pool.get_thread_pool_executor()
+ # Since the dataset is lazily compiled, simply iterating over the full
+ # set of URIs will compile everything. Do this in parallel.
+ futures = (
+ executor.submit(self.benchmark, uri) for uri in self.benchmark_uris()
+ )
+ for i, future in enumerate(as_completed(futures), start=1):
+ future.result()
+ print(
+ f"\r\033[KCompiled {i} of {n} programs ({i/n:.1%} complete)",
+ flush=True,
+ end="",
+ )
diff --git a/compiler_gym/envs/llvm/datasets/cbench.py b/compiler_gym/envs/llvm/datasets/cbench.py
new file mode 100644
index 000000000..f8587ae2d
--- /dev/null
+++ b/compiler_gym/envs/llvm/datasets/cbench.py
@@ -0,0 +1,846 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import enum
+import io
+import logging
+import os
+import re
+import shutil
+import subprocess
+import sys
+import tarfile
+import tempfile
+from collections import defaultdict
+from pathlib import Path
+from threading import Lock
+from typing import Callable, Dict, List, NamedTuple, Optional
+
+import fasteners
+
+from compiler_gym.datasets import Benchmark, TarDatasetWithManifest
+from compiler_gym.third_party import llvm
+from compiler_gym.util.download import download
+from compiler_gym.util.runfiles_path import cache_path, site_data_path
+from compiler_gym.util.timer import Timer
+from compiler_gym.validation_result import ValidationError
+
+_CBENCH_TARS = {
+ "macos": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v1-macos.tar.bz2",
+ "90b312b40317d9ee9ed09b4b57d378879f05e8970bb6de80dc8581ad0e36c84f",
+ ),
+ "linux": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v1-linux.tar.bz2",
+ "601fff3944c866f6617e653b6eb5c1521382c935f56ca1f36a9f5cf1a49f3de5",
+ ),
+}
+
+_CBENCH_RUNTIME_DATA = (
+ "https://dl.fbaipublicfiles.com/compiler_gym/cBench-v0-runtime-data.tar.bz2",
+ "a1b5b5d6b115e5809ccaefc2134434494271d184da67e2ee43d7f84d07329055",
+)
+
+
+if sys.platform == "darwin":
+ _COMPILE_ARGS = [
+ "-L",
+ "/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib",
+ ]
+else:
+ _COMPILE_ARGS = []
+
+
+class LlvmSanitizer(enum.IntEnum):
+ """The LLVM sanitizers."""
+
+ ASAN = 1
+ TSAN = 2
+ MSAN = 3
+ UBSAN = 4
+
+
+# Compiler flags that are enabled by sanitizers.
+_SANITIZER_FLAGS = {
+ LlvmSanitizer.ASAN: ["-O1", "-g", "-fsanitize=address", "-fno-omit-frame-pointer"],
+ LlvmSanitizer.TSAN: ["-O1", "-g", "-fsanitize=thread"],
+ LlvmSanitizer.MSAN: ["-O1", "-g", "-fsanitize=memory"],
+ LlvmSanitizer.UBSAN: ["-fsanitize=undefined"],
+}
+
+
+class BenchmarkExecutionResult(NamedTuple):
+ """The result of running a benchmark."""
+
+ walltime_seconds: float
+ """The execution time in seconds."""
+
+ error: Optional[ValidationError] = None
+ """An error."""
+
+ output: Optional[str] = None
+ """The output generated by the benchmark."""
+
+ def json(self):
+ return self._asdict() # pylint: disable=no-member
+
+
+def _compile_and_run_bitcode_file(
+ bitcode_file: Path,
+ cmd: str,
+ cwd: Path,
+ linkopts: List[str],
+ env: Dict[str, str],
+ num_runs: int,
+ logger: logging.Logger,
+ sanitizer: Optional[LlvmSanitizer] = None,
+ timeout_seconds: float = 300,
+ compilation_timeout_seconds: float = 60,
+) -> BenchmarkExecutionResult:
+ """Run the given cBench benchmark."""
+ # cBench benchmarks expect that a file _finfo_dataset exists in the
+ # current working directory and contains the number of benchmark
+ # iterations in it.
+ with open(cwd / "_finfo_dataset", "w") as f:
+ print(num_runs, file=f)
+
+ # Create a barebones execution environment for the benchmark.
+ run_env = {
+ "TMPDIR": os.environ.get("TMPDIR", ""),
+ "HOME": os.environ.get("HOME", ""),
+ "USER": os.environ.get("USER", ""),
+ # Disable all logging from GRPC. In the past I have had false-positive
+ # "Wrong output" errors caused by GRPC error messages being logged to
+ # stderr.
+ "GRPC_VERBOSITY": "NONE",
+ }
+ run_env.update(env)
+
+ error_data = {}
+
+ if sanitizer:
+ clang_path = llvm.clang_path()
+ binary = cwd / "a.out"
+ error_data["run_cmd"] = cmd.replace("$BIN", "./a.out")
+ # Generate the a.out binary file.
+ compile_cmd = (
+ [clang_path.name, str(bitcode_file), "-o", str(binary)]
+ + _COMPILE_ARGS
+ + list(linkopts)
+ + _SANITIZER_FLAGS.get(sanitizer, [])
+ )
+ error_data["compile_cmd"] = compile_cmd
+ logger.debug("compile: %s", compile_cmd)
+ assert not binary.is_file()
+ clang = subprocess.Popen(
+ compile_cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True,
+ env={"PATH": f"{clang_path.parent}:{os.environ.get('PATH', '')}"},
+ )
+ try:
+ output, _ = clang.communicate(timeout=compilation_timeout_seconds)
+ except subprocess.TimeoutExpired:
+ clang.kill()
+ error_data["timeout"] = compilation_timeout_seconds
+ return BenchmarkExecutionResult(
+ walltime_seconds=timeout_seconds,
+ error=ValidationError(
+ type="Compilation timeout",
+ data=error_data,
+ ),
+ )
+ if clang.returncode:
+ error_data["output"] = output
+ return BenchmarkExecutionResult(
+ walltime_seconds=timeout_seconds,
+ error=ValidationError(
+ type="Compilation failed",
+ data=error_data,
+ ),
+ )
+ assert binary.is_file()
+ else:
+ lli_path = llvm.lli_path()
+ error_data["run_cmd"] = cmd.replace("$BIN", f"{lli_path.name} benchmark.bc")
+ run_env["PATH"] = str(lli_path.parent)
+
+ try:
+ logger.debug("exec: %s", error_data["run_cmd"])
+ process = subprocess.Popen(
+ error_data["run_cmd"],
+ shell=True,
+ stderr=subprocess.STDOUT,
+ stdout=subprocess.PIPE,
+ env=run_env,
+ cwd=cwd,
+ )
+
+ with Timer() as timer:
+ stdout, _ = process.communicate(timeout=timeout_seconds)
+ except subprocess.TimeoutExpired:
+ process.kill()
+ error_data["timeout_seconds"] = timeout_seconds
+ return BenchmarkExecutionResult(
+ walltime_seconds=timeout_seconds,
+ error=ValidationError(
+ type="Execution timeout",
+ data=error_data,
+ ),
+ )
+ finally:
+ if sanitizer:
+ binary.unlink()
+
+ try:
+ output = stdout.decode("utf-8")
+ except UnicodeDecodeError:
+ output = ""
+
+ if process.returncode:
+ # Runtime error.
+ if sanitizer == LlvmSanitizer.ASAN and "LeakSanitizer" in output:
+ error_type = "Memory leak"
+ elif sanitizer == LlvmSanitizer.ASAN and "AddressSanitizer" in output:
+ error_type = "Memory error"
+ elif sanitizer == LlvmSanitizer.MSAN and "MemorySanitizer" in output:
+ error_type = "Memory error"
+ elif "Segmentation fault" in output:
+ error_type = "Segmentation fault"
+ elif "Illegal Instruction" in output:
+ error_type = "Illegal Instruction"
+ else:
+ error_type = f"Runtime error ({process.returncode})"
+
+ error_data["return_code"] = process.returncode
+ error_data["output"] = output
+ return BenchmarkExecutionResult(
+ walltime_seconds=timer.time,
+ error=ValidationError(
+ type=error_type,
+ data=error_data,
+ ),
+ )
+ return BenchmarkExecutionResult(walltime_seconds=timer.time, output=output)
+
+
+def download_cBench_runtime_data() -> bool:
+ """Download and unpack the cBench runtime dataset."""
+ cbench_data = site_data_path("llvm-v0/cbench-v1-runtime-data/runtime_data")
+ if (cbench_data / "unpacked").is_file():
+ return False
+ else:
+ # Clean up any partially-extracted data directory.
+ if cbench_data.is_dir():
+ shutil.rmtree(cbench_data)
+
+    url, sha256 = _CBENCH_RUNTIME_DATA
+ tar_contents = io.BytesIO(download(url, sha256))
+ with tarfile.open(fileobj=tar_contents, mode="r:bz2") as tar:
+ cbench_data.parent.mkdir(parents=True, exist_ok=True)
+ tar.extractall(cbench_data.parent)
+ assert cbench_data.is_dir()
+ # Create the marker file to indicate that the directory is unpacked
+ # and ready to go.
+ (cbench_data / "unpacked").touch()
+ return True
+
+
+# Thread lock to prevent race on download_cBench_runtime_data() from
+# multi-threading. This works in tandem with the inter-process file lock - both
+# are required.
+_CBENCH_DOWNLOAD_THREAD_LOCK = Lock()
+
+
+def _make_cBench_validator(
+ cmd: str,
+ linkopts: List[str],
+ os_env: Dict[str, str],
+ num_runs: int = 1,
+ compare_output: bool = True,
+ input_files: Optional[List[Path]] = None,
+ output_files: Optional[List[Path]] = None,
+ validate_result: Optional[
+ Callable[[BenchmarkExecutionResult], Optional[str]]
+ ] = None,
+ pre_execution_callback: Optional[Callable[[Path], None]] = None,
+ sanitizer: Optional[LlvmSanitizer] = None,
+ flakiness: int = 5,
+) -> Callable[["LlvmEnv"], Optional[ValidationError]]: # noqa: F821
+ """Construct a validation callback for a cBench benchmark. See validator() for usage."""
+ input_files = input_files or []
+ output_files = output_files or []
+
+ def validator_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
+ """The validation callback."""
+ with _CBENCH_DOWNLOAD_THREAD_LOCK:
+ with fasteners.InterProcessLock(cache_path("cbench-v1-runtime-data.LOCK")):
+ download_cBench_runtime_data()
+
+ cbench_data = site_data_path("llvm-v0/cbench-v1-runtime-data/runtime_data")
+ for input_file_name in input_files:
+ path = cbench_data / input_file_name
+ if not path.is_file():
+ raise FileNotFoundError(f"Required benchmark input not found: {path}")
+
+ # Create a temporary working directory to execute the benchmark in.
+ with tempfile.TemporaryDirectory(dir=env.service.connection.working_dir) as d:
+ cwd = Path(d)
+
+ # Expand shell variable substitutions in the benchmark command.
+ expanded_command = cmd.replace("$D", str(cbench_data))
+
+ # Translate the output file names into paths inside the working
+ # directory.
+ output_paths = [cwd / o for o in output_files]
+
+ if pre_execution_callback:
+ pre_execution_callback(cwd)
+
+ # Produce a gold-standard output using a reference version of
+ # the benchmark.
+ if compare_output or output_files:
+ gs_env = env.fork()
+ try:
+ # Reset to the original benchmark state and compile it.
+ gs_env.reset(benchmark=env.benchmark)
+ gs_env.write_bitcode(cwd / "benchmark.bc")
+ gold_standard = _compile_and_run_bitcode_file(
+ bitcode_file=cwd / "benchmark.bc",
+ cmd=expanded_command,
+ cwd=cwd,
+ num_runs=1,
+ # Use default optimizations for gold standard.
+ linkopts=linkopts + ["-O2"],
+ # Always assume safe.
+ sanitizer=None,
+ logger=env.logger,
+ env=os_env,
+ )
+ if gold_standard.error:
+ return ValidationError(
+ type=f"Gold standard: {gold_standard.error.type}",
+ data=gold_standard.error.data,
+ )
+ finally:
+ gs_env.close()
+
+ # Check that the reference run produced the expected output
+ # files.
+ for path in output_paths:
+ if not path.is_file():
+ try:
+ output = gold_standard.output
+ except UnicodeDecodeError:
+ output = ""
+ raise FileNotFoundError(
+ f"Expected file '{path.name}' not generated\n"
+ f"Benchmark: {env.benchmark}\n"
+ f"Command: {cmd}\n"
+ f"Output: {output}"
+ )
+ path.rename(f"{path}.gold_standard")
+
+ # Serialize the benchmark to a bitcode file that will then be
+ # compiled to a binary.
+ env.write_bitcode(cwd / "benchmark.bc")
+ outcome = _compile_and_run_bitcode_file(
+ bitcode_file=cwd / "benchmark.bc",
+ cmd=expanded_command,
+ cwd=cwd,
+ num_runs=num_runs,
+ linkopts=linkopts,
+ sanitizer=sanitizer,
+ logger=env.logger,
+ env=os_env,
+ )
+
+ if outcome.error:
+ return outcome.error
+
+ # Run a user-specified validation hook.
+ if validate_result:
+ validate_result(outcome)
+
+ # Difftest the console output.
+ if compare_output and gold_standard.output != outcome.output:
+ return ValidationError(
+ type="Wrong output",
+ data={"expected": gold_standard.output, "actual": outcome.output},
+ )
+
+ # Difftest the output files.
+ for path in output_paths:
+ if not path.is_file():
+ return ValidationError(
+ type="Output not generated",
+ data={"path": path.name, "command": cmd},
+ )
+ diff = subprocess.Popen(
+ ["diff", str(path), f"{path}.gold_standard"],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ )
+ stdout, _ = diff.communicate()
+ if diff.returncode:
+ try:
+ stdout = stdout.decode("utf-8")
+ return ValidationError(
+ type="Wrong output (file)",
+ data={"path": path.name, "diff": stdout},
+ )
+ except UnicodeDecodeError:
+ return ValidationError(
+ type="Wrong output (file)",
+ data={"path": path.name, "diff": ""},
+ )
+
+ def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
+ """Wrap the validation callback in a flakiness retry loop."""
+ for j in range(1, max(flakiness, 1) + 1):
+ try:
+ error = validator_cb(env)
+ if not error:
+ return
+ except TimeoutError:
+ # Timeout errors can be raised by the environment in case of a
+ # slow step / observation, and should be retried.
+ pass
+ env.logger.warning(
+ "Validation callback failed, attempt=%d/%d", j, flakiness
+ )
+ return error
+
+ return flaky_wrapped_cb
+
+
+def validator(
+ benchmark: str,
+ cmd: str,
+ data: Optional[List[str]] = None,
+ outs: Optional[List[str]] = None,
+ platforms: Optional[List[str]] = None,
+ compare_output: bool = True,
+ validate_result: Optional[
+ Callable[[BenchmarkExecutionResult], Optional[str]]
+ ] = None,
+ linkopts: Optional[List[str]] = None,
+ env: Optional[Dict[str, str]] = None,
+ pre_execution_callback: Optional[Callable[[], None]] = None,
+ sanitizers: Optional[List[LlvmSanitizer]] = None,
+) -> bool:
+ """Declare a new benchmark validator.
+
+ TODO(cummins): Pull this out into a public API.
+
+ :param benchmark: The name of the benchmark that this validator supports.
+    :param cmd: The shell command to run the validation. Variable substitution
+        is applied to this value as follows: :code:`$BIN` is replaced by the
+        path of the compiled binary and :code:`$D` is replaced with the path
+        to the benchmark's runtime data directory.
+    :param data: A list of paths to input files.
+    :param outs: A list of paths to output files.
+ :return: :code:`True` if the new validator was registered, else :code:`False`.
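+
+    For example, the crc32 validators registered at the bottom of this file
+    take the form:
+
+        validator(
+            benchmark="benchmark://cbench-v1/crc32",
+            cmd="$BIN $D/telecom_data/1.pcm",
+            data=["telecom_data/1.pcm"],
+        )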
+ """
+ platforms = platforms or ["linux", "macos"]
+ if {"darwin": "macos"}.get(sys.platform, sys.platform) not in platforms:
+ return False
+ infiles = data or []
+ outfiles = [Path(p) for p in outs or []]
+ linkopts = linkopts or []
+ env = env or {}
+ if sanitizers is None:
+ sanitizers = LlvmSanitizer
+
+ VALIDATORS[benchmark].append(
+ _make_cBench_validator(
+ cmd=cmd,
+ input_files=infiles,
+ output_files=outfiles,
+ compare_output=compare_output,
+ validate_result=validate_result,
+ linkopts=linkopts,
+ os_env=env,
+ pre_execution_callback=pre_execution_callback,
+ )
+ )
+
+ # Register additional validators using the sanitizers.
+ if sys.platform.startswith("linux"):
+ for sanitizer in sanitizers:
+ VALIDATORS[benchmark].append(
+ _make_cBench_validator(
+ cmd=cmd,
+ input_files=infiles,
+ output_files=outfiles,
+ compare_output=compare_output,
+ validate_result=validate_result,
+ linkopts=linkopts,
+ os_env=env,
+ pre_execution_callback=pre_execution_callback,
+ sanitizer=sanitizer,
+ )
+ )
+
+ return True
+
+
+class CBenchBenchmark(Benchmark):
+ """A cBench benchmmark."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for val in VALIDATORS.get(self.uri, []):
+ self.add_validation_callback(val)
+
+
+class CBenchDataset(TarDatasetWithManifest):
+ def __init__(
+ self,
+ site_data_base: Path,
+ sort_order: int = 0,
+ name="benchmark://cbench-v1",
+ manifest_url="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cbench-v1-manifest.bz2",
+ manifest_sha256="eeffd7593aeb696a160fd22e6b0c382198a65d0918b8440253ea458cfe927741",
+ deprecated=None,
+ ):
+ platform = {"darwin": "macos"}.get(sys.platform, sys.platform)
+ url, sha256 = _CBENCH_TARS[platform]
+ super().__init__(
+ name=name,
+ description="Runnable C benchmarks",
+ license="BSD 3-Clause",
+ references={
+ "Paper": "https://arxiv.org/pdf/1407.3487.pdf",
+ "Homepage": "https://ctuning.org/wiki/index.php/CTools:CBench",
+ },
+ tar_urls=[url],
+ tar_sha256=sha256,
+ manifest_urls=[manifest_url],
+ manifest_sha256=manifest_sha256,
+ strip_prefix="cBench-v1",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ benchmark_class=CBenchBenchmark,
+ deprecated=deprecated,
+ validatable="Partially",
+ )
+
+
+# URLs of the deprecated cBench datasets.
+_CBENCH_LEGACY_TARS = {
+ "macos": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v0-macos.tar.bz2",
+ "072a730c86144a07bba948c49afe543e4f06351f1cb17f7de77f91d5c1a1b120",
+ ),
+ "linux": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v0-linux.tar.bz2",
+ "9b5838a90895579aab3b9375e8eeb3ed2ae58e0ad354fec7eb4f8b31ecb4a360",
+ ),
+}
+
+
+class CBenchLegacyDataset(TarDatasetWithManifest):
+ # The difference between cbench-v0 and cbench-v1 is the arguments passed to
+ # clang when preparing the LLVM bitcodes:
+ #
+ # - v0: `-O0 -Xclang -disable-O0-optnone`.
+ # - v1: `-O1 -Xclang -Xclang -disable-llvm-passes`.
+ #
+    # The key difference is that in v0, the generated IR functions were
+    # annotated with a `noinline` attribute that prevented inlining. In v1 that
+    # is no longer the case.
+ def __init__(self, site_data_base: Path):
+ platform = {"darwin": "macos"}.get(sys.platform, sys.platform)
+ url, sha256 = _CBENCH_LEGACY_TARS[platform]
+ super().__init__(
+ name="benchmark://cBench-v0",
+ description="Runnable C benchmarks",
+ license="BSD 3-Clause",
+ references={
+ "Paper": "https://arxiv.org/pdf/1407.3487.pdf",
+ "Homepage": "https://ctuning.org/wiki/index.php/CTools:CBench",
+ },
+ tar_urls=[url],
+ tar_sha256=sha256,
+ manifest_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-cBench-v0-manifest.bz2"
+ ],
+ manifest_sha256="635b94eeb2784dfedb3b53fd8f84517c3b4b95d851ddb662d4c1058c72dc81e0",
+ strip_prefix="cBench-v0",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ deprecated="Please use 'benchmark://cbench-v1'",
+ )
+
+
+# ===============================
+# Definition of cBench validators
+# ===============================
+
+
+# A map from benchmark name to validation callbacks.
+VALIDATORS: Dict[
+ str, List[Callable[["LlvmEnv"], Optional[str]]] # noqa: F821
+] = defaultdict(list)
+
+
+def validate_sha_output(result: BenchmarkExecutionResult) -> Optional[str]:
+ """SHA benchmark prints 5 random hex strings. Normally these hex strings are
+ 16 characters but occasionally they are less (presumably because of a
+ leading zero being omitted).
+ """
+ try:
+ if not re.match(
+ r"[0-9a-f]{0,16} [0-9a-f]{0,16} [0-9a-f]{0,16} [0-9a-f]{0,16} [0-9a-f]{0,16}",
+ result.output.rstrip(),
+ ):
+ return "Failed to parse hex output"
+ except UnicodeDecodeError:
+ return "Failed to parse unicode output"
+
+
+def setup_ghostscript_library_files(dataset_id: int) -> Callable[[Path], None]:
+ """Make a pre-execution setup hook for ghostscript."""
+
+ def setup(cwd: Path):
+ cbench_data = site_data_path("llvm-v0/cbench-v1-runtime-data/runtime_data")
+ # Copy the input data file into the current directory since ghostscript
+ # doesn't like long input paths.
+ shutil.copyfile(
+ cbench_data / "office_data" / f"{dataset_id}.ps", cwd / "input.ps"
+ )
+ # Ghostscript doesn't like the library files being symlinks so copy them
+ # into the working directory as regular files.
+ for path in (cbench_data / "ghostscript").iterdir():
+ if path.name.endswith(".ps"):
+ shutil.copyfile(path, cwd / path.name)
+
+ return setup
+
+
+validator(
+ benchmark="benchmark://cbench-v1/bitcount",
+ cmd="$BIN 1125000",
+)
+
+validator(
+ benchmark="benchmark://cbench-v1/bitcount",
+ cmd="$BIN 512",
+)
+
+for i in range(1, 21):
+
+ # NOTE(cummins): Disabled due to timeout errors, further investigation
+ # needed.
+ #
+ # validator(
+ # benchmark="benchmark://cbench-v1/adpcm",
+ # cmd=f"$BIN $D/telecom_data/{i}.adpcm",
+ # data=[f"telecom_data/{i}.adpcm"],
+ # )
+ #
+ # validator(
+ # benchmark="benchmark://cbench-v1/adpcm",
+ # cmd=f"$BIN $D/telecom_data/{i}.pcm",
+ # data=[f"telecom_data/{i}.pcm"],
+ # )
+
+ validator(
+ benchmark="benchmark://cbench-v1/blowfish",
+ cmd=f"$BIN d $D/office_data/{i}.benc output.txt 1234567890abcdeffedcba0987654321",
+ data=[f"office_data/{i}.benc"],
+ outs=["output.txt"],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/bzip2",
+ cmd=f"$BIN -d -k -f -c $D/bzip2_data/{i}.bz2",
+ data=[f"bzip2_data/{i}.bz2"],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/crc32",
+ cmd=f"$BIN $D/telecom_data/{i}.pcm",
+ data=[f"telecom_data/{i}.pcm"],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/dijkstra",
+ cmd=f"$BIN $D/network_dijkstra_data/{i}.dat",
+ data=[f"network_dijkstra_data/{i}.dat"],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/gsm",
+ cmd=f"$BIN -fps -c $D/telecom_gsm_data/{i}.au",
+ data=[f"telecom_gsm_data/{i}.au"],
+ )
+
+ # NOTE(cummins): ispell fails with returncode 1 and no output when run
+ # under safe optimizations.
+ #
+ # validator(
+ # benchmark="benchmark://cbench-v1/ispell",
+ # cmd=f"$BIN -a -d americanmed+ $D/office_data/{i}.txt",
+ # data = [f"office_data/{i}.txt"],
+ # )
+
+ validator(
+ benchmark="benchmark://cbench-v1/jpeg-c",
+ cmd=f"$BIN -dct int -progressive -outfile output.jpeg $D/consumer_jpeg_data/{i}.ppm",
+ data=[f"consumer_jpeg_data/{i}.ppm"],
+ outs=["output.jpeg"],
+ # NOTE(cummins): AddressSanitizer disabled because of
+ # global-buffer-overflow in regular build.
+ sanitizers=[LlvmSanitizer.TSAN, LlvmSanitizer.UBSAN],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/jpeg-d",
+ cmd=f"$BIN -dct int -outfile output.ppm $D/consumer_jpeg_data/{i}.jpg",
+ data=[f"consumer_jpeg_data/{i}.jpg"],
+ outs=["output.ppm"],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/patricia",
+ cmd=f"$BIN $D/network_patricia_data/{i}.udp",
+ data=[f"network_patricia_data/{i}.udp"],
+ env={
+ # NOTE(cummins): Benchmark leaks when executed with safe optimizations.
+ "ASAN_OPTIONS": "detect_leaks=0",
+ },
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/qsort",
+ cmd=f"$BIN $D/automotive_qsort_data/{i}.dat",
+ data=[f"automotive_qsort_data/{i}.dat"],
+ outs=["sorted_output.dat"],
+ linkopts=["-lm"],
+ )
+
+ # NOTE(cummins): Rijndael benchmark disabled due to memory errors under
+ # basic optimizations.
+ #
+ # validator(benchmark="benchmark://cbench-v1/rijndael", cmd=f"$BIN
+ # $D/office_data/{i}.enc output.dec d
+ # 1234567890abcdeffedcba09876543211234567890abcdeffedcba0987654321",
+ # data=[f"office_data/{i}.enc"], outs=["output.dec"],
+ # )
+ #
+ # validator(benchmark="benchmark://cbench-v1/rijndael", cmd=f"$BIN
+ # $D/office_data/{i}.txt output.enc e
+ # 1234567890abcdeffedcba09876543211234567890abcdeffedcba0987654321",
+ # data=[f"office_data/{i}.txt"], outs=["output.enc"],
+ # )
+
+ validator(
+ benchmark="benchmark://cbench-v1/sha",
+ cmd=f"$BIN $D/office_data/{i}.txt",
+ data=[f"office_data/{i}.txt"],
+ compare_output=False,
+ validate_result=validate_sha_output,
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/stringsearch",
+ cmd=f"$BIN $D/office_data/{i}.txt $D/office_data/{i}.s.txt output.txt",
+ data=[f"office_data/{i}.txt"],
+ outs=["output.txt"],
+ env={
+ # NOTE(cummins): Benchmark leaks when executed with safe optimizations.
+ "ASAN_OPTIONS": "detect_leaks=0",
+ },
+ linkopts=["-lm"],
+ )
+
+ # NOTE(cummins): The stringsearch2 benchmark has a very long execution time.
+ # Use only a single input to keep the validation time reasonable. I have
+ # also observed Segmentation fault on gold standard using 4.txt and 6.txt.
+ if i == 1:
+ validator(
+ benchmark="benchmark://cbench-v1/stringsearch2",
+ cmd=f"$BIN $D/office_data/{i}.txt $D/office_data/{i}.s.txt output.txt",
+ data=[f"office_data/{i}.txt"],
+ outs=["output.txt"],
+ env={
+ # NOTE(cummins): Benchmark leaks when executed with safe optimizations.
+ "ASAN_OPTIONS": "detect_leaks=0",
+ },
+ # TSAN disabled because of extremely long execution leading to
+ # timeouts.
+ sanitizers=[LlvmSanitizer.ASAN, LlvmSanitizer.MSAN, LlvmSanitizer.UBSAN],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/susan",
+ cmd=f"$BIN $D/automotive_susan_data/{i}.pgm output_large.corners.pgm -c",
+ data=[f"automotive_susan_data/{i}.pgm"],
+ outs=["output_large.corners.pgm"],
+ linkopts=["-lm"],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/tiff2bw",
+ cmd=f"$BIN $D/consumer_tiff_data/{i}.tif output.tif",
+ data=[f"consumer_tiff_data/{i}.tif"],
+ outs=["output.tif"],
+ linkopts=["-lm"],
+ env={
+ # NOTE(cummins): Benchmark leaks when executed with safe optimizations.
+ "ASAN_OPTIONS": "detect_leaks=0",
+ },
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/tiff2rgba",
+ cmd=f"$BIN $D/consumer_tiff_data/{i}.tif output.tif",
+ data=[f"consumer_tiff_data/{i}.tif"],
+ outs=["output.tif"],
+ linkopts=["-lm"],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/tiffdither",
+ cmd=f"$BIN $D/consumer_tiff_data/{i}.bw.tif out.tif",
+ data=[f"consumer_tiff_data/{i}.bw.tif"],
+ outs=["out.tif"],
+ linkopts=["-lm"],
+ )
+
+ validator(
+ benchmark="benchmark://cbench-v1/tiffmedian",
+ cmd=f"$BIN $D/consumer_tiff_data/{i}.nocomp.tif output.tif",
+ data=[f"consumer_tiff_data/{i}.nocomp.tif"],
+ outs=["output.tif"],
+ linkopts=["-lm"],
+ )
+
+ # NOTE(cummins): On macOS the following benchmarks abort with an illegal
+ # hardware instruction error.
+ # if sys.platform != "darwin":
+ # validator(
+ # benchmark="benchmark://cbench-v1/lame",
+ # cmd=f"$BIN $D/consumer_data/{i}.wav output.mp3",
+ # data=[f"consumer_data/{i}.wav"],
+ # outs=["output.mp3"],
+ # compare_output=False,
+ # linkopts=["-lm"],
+ # )
+
+ # NOTE(cummins): Segfault on gold standard.
+ #
+ # validator(
+ # benchmark="benchmark://cbench-v1/ghostscript",
+ # cmd="$BIN -sDEVICE=ppm -dNOPAUSE -dQUIET -sOutputFile=output.ppm -- input.ps",
+ # data=[f"office_data/{i}.ps"],
+ # outs=["output.ppm"],
+ # linkopts=["-lm", "-lz"],
+ # pre_execution_callback=setup_ghostscript_library_files(i),
+ # )
diff --git a/compiler_gym/envs/llvm/datasets/clgen.py b/compiler_gym/envs/llvm/datasets/clgen.py
new file mode 100644
index 000000000..cb000b262
--- /dev/null
+++ b/compiler_gym/envs/llvm/datasets/clgen.py
@@ -0,0 +1,168 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import io
+import shutil
+import subprocess
+import tarfile
+from pathlib import Path
+from typing import List
+
+from fasteners import InterProcessLock
+
+from compiler_gym.datasets import Benchmark, BenchmarkInitError, TarDatasetWithManifest
+from compiler_gym.datasets.benchmark import BenchmarkWithSource
+from compiler_gym.envs.llvm.llvm_benchmark import ClangInvocation
+from compiler_gym.util.download import download
+from compiler_gym.util.filesystem import atomic_file_write
+from compiler_gym.util.truncate import truncate
+
+
+class CLgenDataset(TarDatasetWithManifest):
+ """The CLgen dataset contains 1000 synthetically generated OpenCL kernels.
+
+ The dataset is from:
+
+ Cummins, Chris, Pavlos Petoumenos, Zheng Wang, and Hugh Leather.
+ "Synthesizing benchmarks for predictive modeling." In 2017 IEEE/ACM
+ International Symposium on Code Generation and Optimization (CGO),
+ pp. 86-99. IEEE, 2017.
+
+ And is available at:
+
+ https://github.com/ChrisCummins/paper-synthesizing-benchmarks
+
+ Installation
+ ------------
+
+ The CLgen dataset consists of OpenCL kernels that are compiled to LLVM-IR
+ on-demand and cached. The first time each benchmark is used there is an
+ overhead of compiling it from OpenCL to bitcode. This is a one-off cost.
+ Compiling OpenCL to bitcode requires third party headers that are downloaded
+ on the first call to :code:`install()`.
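+
+    A minimal usage sketch (assuming :code:`site_data_path` from
+    :code:`compiler_gym.util.runfiles_path`; the headers are fetched by the
+    explicit :code:`install()` call):
+
+        >>> dataset = CLgenDataset(site_data_base=site_data_path("llvm-v0"))
+        >>> dataset.install()
+        >>> benchmark = dataset.benchmark(next(iter(dataset.benchmark_uris())))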
+ """
+
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ super().__init__(
+ name="benchmark://clgen-v0",
+ description="Synthetically generated OpenCL kernels",
+ references={
+ "Paper": "https://chriscummins.cc/pub/2017-cgo.pdf",
+ "Homepage": "https://github.com/ChrisCummins/clgen",
+ },
+ license="GNU General Public License v3.0",
+ site_data_base=site_data_base,
+ manifest_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-clgen-v0-manifest.bz2"
+ ],
+ manifest_sha256="d2bbc1da5a24a8cb03b604d1d8e59227b33bdfcd964ebe741ca8339f1c8d65cc",
+ tar_urls=[
+ "https://github.com/ChrisCummins/paper-synthesizing-benchmarks/raw/e45b6dffe9998f612624f05a6c4878ab4bcc84ec/data/clgen-1000.tar.bz2"
+ ],
+ tar_sha256="0bbd1b737f2537305e4db09b2971a5fa848b7c3a978bff6b570f45d1a488a72c",
+ strip_prefix="clgen-1000/kernels",
+ tar_compression="bz2",
+ benchmark_file_suffix=".bc",
+ sort_order=sort_order,
+ )
+
+ self._opencl_installed = False
+ self._opencl_headers_installed_marker = (
+ self._site_data_path / ".opencl-installed"
+ )
+ self.libclc_dir = self.site_data_path / "libclc"
+ self.opencl_h_path = self.site_data_path / "opencl.h"
+
+ def install(self):
+ super().install()
+
+ if not self._opencl_installed:
+ self._opencl_installed = self._opencl_headers_installed_marker.is_file()
+
+ if self._opencl_installed:
+ return
+
+ with self._tar_lock, InterProcessLock(self._tar_lockfile):
+ # Repeat install check now that we are in the locked region.
+ if self._opencl_headers_installed_marker.is_file():
+ return
+
+ # Download the libclc headers.
+ shutil.rmtree(self.libclc_dir, ignore_errors=True)
+ self.logger.info("Downloading OpenCL headers")
+ tar_data = io.BytesIO(
+ download(
+ "https://dl.fbaipublicfiles.com/compiler_gym/libclc-v0.tar.bz2",
+ sha256="f1c511f2ac12adf98dcc0fbfc4e09d0f755fa403c18f1fb1ffa5547e1fa1a499",
+ )
+ )
+ with tarfile.open(fileobj=tar_data, mode="r:bz2") as arc:
+ arc.extractall(str(self.site_data_path / "libclc"))
+
+ # Download the OpenCL header.
+ with open(self.opencl_h_path, "wb") as f:
+ f.write(
+ download(
+ "https://github.com/ChrisCummins/clgen/raw/463c0adcd8abcf2432b24df0aca594b77a69e9d3/deeplearning/clgen/data/include/opencl.h",
+ sha256="f95b9f4c8b1d09114e491846d0d41425d24930ac167e024f45dab8071d19f3f7",
+ )
+ )
+
+ self._opencl_headers_installed_marker.touch()
+
+ def benchmark(self, uri: str) -> Benchmark:
+ self.install()
+
+ benchmark_name = uri[len(self.name) + 1 :]
+ if not benchmark_name:
+ raise LookupError(f"No benchmark specified: {uri}")
+
+ # The absolute path of the file, without an extension.
+ path_stem = self.dataset_root / uri[len(self.name) + 1 :]
+
+ bc_path, cl_path = Path(f"{path_stem}.bc"), Path(f"{path_stem}.cl")
+
+ # If the file does not exist, compile it on-demand.
+ if not bc_path.is_file():
+ if not cl_path.is_file():
+ raise LookupError(
+ f"Benchmark not found: {uri} (file not found: {cl_path}, path_stem {path_stem})"
+ )
+
+ # Compile the OpenCL kernel into a bitcode file.
+ with atomic_file_write(bc_path) as tmp_bc_path:
+ compile_command: List[str] = ClangInvocation.from_c_file(
+ cl_path,
+ copt=[
+ "-isystem",
+ str(self.libclc_dir),
+ "-include",
+ str(self.opencl_h_path),
+ "-target",
+ "nvptx64-nvidia-nvcl",
+ "-ferror-limit=1", # Stop on first error.
+ "-w", # No warnings.
+ ],
+ ).command(outpath=tmp_bc_path)
+ self.logger.debug("Exec %s", compile_command)
+ clang = subprocess.Popen(
+ compile_command,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ _, stderr = clang.communicate(timeout=300)
+
+ if clang.returncode:
+ compile_command = " ".join(compile_command)
+ error = truncate(
+ stderr.decode("utf-8"), max_lines=20, max_line_len=20000
+ )
+ raise BenchmarkInitError(
+ f"Compilation job failed!\n"
+ f"Command: {compile_command}\n"
+ f"Error: {error}"
+ )
+
+ return BenchmarkWithSource.create(uri, bc_path, "kernel.cl", cl_path)
diff --git a/compiler_gym/envs/llvm/datasets/csmith.py b/compiler_gym/envs/llvm/datasets/csmith.py
new file mode 100644
index 000000000..dfe51e435
--- /dev/null
+++ b/compiler_gym/envs/llvm/datasets/csmith.py
@@ -0,0 +1,275 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import io
+import logging
+import subprocess
+import sys
+import tarfile
+import tempfile
+from pathlib import Path
+from threading import Lock
+from typing import Iterable, List
+
+from fasteners import InterProcessLock
+
+from compiler_gym.datasets import Benchmark, BenchmarkSource, Dataset
+from compiler_gym.datasets.benchmark import BenchmarkInitError, BenchmarkWithSource
+from compiler_gym.datasets.dataset import DatasetInitError
+from compiler_gym.envs.llvm.llvm_benchmark import ClangInvocation
+from compiler_gym.util.decorators import memoized_property
+from compiler_gym.util.download import download
+from compiler_gym.util.runfiles_path import transient_cache_path
+from compiler_gym.util.truncate import truncate
+
+# The maximum value for the --seed argument to csmith.
+UINT_MAX = (2 ** 32) - 1
+
+
+class CsmithBenchmark(BenchmarkWithSource):
+ """A CSmith benchmark."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._src = None
+
+ @classmethod
+ def create(cls, uri: str, bitcode: bytes, src: bytes) -> Benchmark:
+ """Create a benchmark from paths."""
+ benchmark = cls.from_file_contents(uri, bitcode)
+ benchmark._src = src # pylint: disable=protected-access
+ return benchmark
+
+ @memoized_property
+ def sources(self) -> Iterable[BenchmarkSource]:
+ return [
+ BenchmarkSource(filename="source.c", contents=self._src),
+ ]
+
+ @property
+ def source(self) -> str:
+ """Return the single source file contents as a string."""
+ return self._src.decode("utf-8")
+
+
+class CsmithBuildError(DatasetInitError):
+ """Error raised if :meth:`CsmithDataset.install()
+ ` fails."""
+
+ def __init__(self, failing_stage: str, stdout: str, stderr: str):
+ install_instructions = {
+ "linux": "sudo apt install g++ m4",
+ "darwin": "brew install m4",
+ }[sys.platform]
+
+ super().__init__(
+ "\n".join(
+ [
+ f"Failed to build Csmith from source, `{failing_stage}` failed.",
+ "You may be missing installation dependencies. Install them using:",
+ f" {install_instructions}",
+ "See https://github.com/csmith-project/csmith#install-csmith for more details",
+ f"--- Start `{failing_stage}` logs: ---\n",
+ stdout,
+ stderr,
+ ]
+ )
+ )
+
+
+class CsmithDataset(Dataset):
+ """A dataset which uses Csmith to generate programs.
+
+ Csmith is a tool that can generate random conformant C99 programs. It is
+ described in the publication:
+
+ Yang, Xuejun, Yang Chen, Eric Eide, and John Regehr. "Finding and
+ understanding bugs in C compilers." In Proceedings of the 32nd ACM
+ SIGPLAN conference on Programming Language Design and Implementation
+ (PLDI), pp. 283-294. 2011.
+
+ For up-to-date information about Csmith, see:
+
+ https://embed.cs.utah.edu/csmith/
+
+ Note that Csmith is a tool that is used to find errors in compilers. As
+ such, there is a higher likelihood that the benchmark cannot be used for an
+ environment and that :meth:`env.reset()
+ ` will raise :class:`BenchmarkInitError
+ `.
+
+ Installation
+ ------------
+
+ Using the CsmithDataset requires building the Csmith binary from source.
+ This is done automatically on the first call to :code:`install()`. Building
+ Csmith requires a working C++ toolchain. Install the required dependencies
+ using: :code:`sudo apt install -y g++ m4` on Linux, or :code:`brew install
+ m4` on macOS. :class:`DatasetInitError
+ ` is raised if compilation fails.
+ See the `Csmith repo
+ `_ for further
+ details.
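+
+    A minimal usage sketch (the seed in the URI is illustrative; any value in
+    :code:`0 <= seed < 2**32` selects a program):
+
+        >>> env = gym.make("llvm-v0")
+        >>> env.reset(benchmark="generator://csmith-v0/0")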
+ """
+
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ super().__init__(
+ name="generator://csmith-v0",
+ description="Random conformant C99 programs",
+ references={
+ "Paper": "http://web.cse.ohio-state.edu/~rountev.1/5343/pdf/pldi11.pdf",
+ "Homepage": "https://embed.cs.utah.edu/csmith/",
+ },
+ license="BSD",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ benchmark_class=CsmithBenchmark,
+ )
+ self.csmith_path = self.site_data_path / "bin" / "csmith"
+ csmith_include_dir = self.site_data_path / "include" / "csmith-2.3.0"
+
+ self._installed = False
+ self._build_lock = Lock()
+ self._build_lockfile = self.site_data_path / ".build.LOCK"
+ self._build_markerfile = self.site_data_path / ".built"
+
+ # The command that is used to compile an LLVM-IR bitcode file from a
+ # Csmith input. Reads from stdin, writes to stdout.
+ self.clang_compile_command: List[str] = ClangInvocation.from_c_file(
+ "-", # Read from stdin.
+ copt=[
+ "-xc",
+ "-ferror-limit=1", # Stop on first error.
+ "-w", # No warnings.
+ f"-I{csmith_include_dir}", # Include the Csmith headers.
+ ],
+ ).command(
+ outpath="-"
+ ) # Write to stdout.
+
+ @property
+ def installed(self) -> bool:
+ # Fast path for repeated checks to 'installed' without a disk op.
+ if not self._installed:
+ self._installed = self._build_markerfile.is_file()
+ return self._installed
+
+ def install(self) -> None:
+ """Download and build the Csmith binary."""
+ super().install()
+
+ if self.installed:
+ return
+
+ with self._build_lock, InterProcessLock(self._build_lockfile):
+ # Repeat the check to see if we have already installed the dataset
+ # now that we have acquired the lock.
+ if not self.installed:
+ self.logger.info("Downloading and building Csmith")
+ self._build_csmith(self.site_data_path, self.logger)
+ self._build_markerfile.touch()
+
+ @staticmethod
+ def _build_csmith(install_root: Path, logger: logging.Logger):
+ """Download, build, and install Csmith to the given directory."""
+ tar_data = io.BytesIO(
+ download(
+ urls=[
+ "https://github.com/csmith-project/csmith/archive/refs/tags/csmith-2.3.0.tar.gz",
+ ],
+ sha256="ba871c1e5a05a71ecd1af514fedba30561b16ee80b8dd5ba8f884eaded47009f",
+ )
+ )
+ # Csmith uses a standard `configure` + `make install` build process.
+ with tempfile.TemporaryDirectory(
+ dir=transient_cache_path("."), prefix="csmith-"
+ ) as d:
+ with tarfile.open(fileobj=tar_data, mode="r:gz") as arc:
+ arc.extractall(d)
+
+ # The path of the extracted sources.
+ src_dir = Path(d) / "csmith-csmith-2.3.0"
+
+ logger.debug("Configuring Csmith at %s", d)
+ configure = subprocess.Popen(
+ ["./configure", f"--prefix={install_root}"],
+ cwd=src_dir,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ universal_newlines=True,
+ )
+ stdout, stderr = configure.communicate(timeout=600)
+ if configure.returncode:
+ raise CsmithBuildError("./configure", stdout, stderr)
+
+ logger.debug("Installing Csmith to %s", install_root)
+ make = subprocess.Popen(
+ ["make", "-j", "install"],
+ cwd=src_dir,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ universal_newlines=True,
+ )
+ stdout, stderr = make.communicate(timeout=600)
+ if make.returncode:
+ raise CsmithBuildError("make install", stdout, stderr)
+
+ @property
+ def size(self) -> int:
+ # Actually 2^32 - 1, but practically infinite for all intents and
+ # purposes.
+ return float("inf")
+
+ def benchmark_uris(self) -> Iterable[str]:
+ return (f"{self.name}/{i}" for i in range(UINT_MAX))
+
+ def benchmark(self, uri: str) -> CsmithBenchmark:
+ return self.benchmark_from_seed(int(uri.split("/")[-1]))
+
+ def benchmark_from_seed(self, seed: int) -> CsmithBenchmark:
+ """Get a benchmark from a uint32 seed.
+
+ :param seed: A number in the range 0 <= n < 2^32.
+
+ :return: A benchmark instance.
+ """
+ self.install()
+
+ # Run csmith with the given seed and pipe the output to clang to
+ # assemble a bitcode.
+ self.logger.debug("Exec csmith --seed %d", seed)
+ csmith = subprocess.Popen(
+ [str(self.csmith_path), "--seed", str(seed)],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,  # Captured so the error message below can include it.
+ )
+
+ # Generate the C source.
+ src, stderr = csmith.communicate(timeout=300)
+ if csmith.returncode:
+ error = truncate(stderr.decode("utf-8"), max_lines=20, max_line_len=100)
+ raise OSError(f"Csmith failed with seed {seed}\nError: {error}")
+
+ # Compile to IR.
+ clang = subprocess.Popen(
+ self.clang_compile_command,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ stdout, stderr = clang.communicate(src, timeout=300)
+
+ if csmith.returncode:
+ raise OSError(f"Csmith failed with seed {seed}")
+ if clang.returncode:
+ compile_cmd = " ".join(self.clang_compile_command)
+ error = truncate(stderr.decode("utf-8"), max_lines=20, max_line_len=100)
+ raise BenchmarkInitError(
+ f"Compilation job failed!\n"
+ f"Csmith seed: {seed}\n"
+ f"Command: {compile_cmd}\n"
+ f"Error: {error}"
+ )
+
+ return self.benchmark_class.create(f"{self.name}/{seed}", stdout, src)
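
The dataset above builds Csmith on first use and then generates each benchmark deterministically from a uint32 seed. A minimal usage sketch, assuming the module is importable as `compiler_gym.envs.llvm.datasets.csmith` (consistent with the neighbouring files in this diff) and that the `/tmp` site-data path below, which is purely illustrative, is writable:

```python
from pathlib import Path

from compiler_gym.envs.llvm.datasets.csmith import CsmithDataset

# Illustrative site-data location; any writable directory works for this sketch.
dataset = CsmithDataset(site_data_base=Path("/tmp/compiler_gym_site_data"))
dataset.install()  # First call downloads and builds the Csmith binary (needs g++ and m4).

# Benchmarks are generated deterministically from a uint32 seed, so
# "generator://csmith-v0/2021" always refers to the same program.
benchmark = dataset.benchmark_from_seed(2021)
print(benchmark.uri)
```
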
diff --git a/compiler_gym/envs/llvm/datasets/llvm_stress.py b/compiler_gym/envs/llvm/datasets/llvm_stress.py
new file mode 100644
index 000000000..02d948bb2
--- /dev/null
+++ b/compiler_gym/envs/llvm/datasets/llvm_stress.py
@@ -0,0 +1,86 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import subprocess
+from pathlib import Path
+from typing import Iterable
+
+from compiler_gym.datasets import Benchmark, Dataset
+from compiler_gym.datasets.benchmark import BenchmarkInitError
+from compiler_gym.third_party import llvm
+
+# The maximum value for the --seed argument to llvm-stress.
+UINT_MAX = (2 ** 32) - 1
+
+
+class LlvmStressDataset(Dataset):
+ """A dataset which uses llvm-stress to generate programs.
+
+ `llvm-stress <https://llvm.org/docs/CommandGuide/llvm-stress.html>`_ is a
+ tool for generating random LLVM-IR files.
+
+ This dataset forces reproducible results by setting the input seed to the
+ generator. The benchmark's URI is the seed, e.g.
+ "generator://llvm-stress-v0/10" is the benchmark generated by llvm-stress
+ using seed 10. The total number of unique seeds is 2^32 - 1.
+
+ Note that llvm-stress is a tool that is used to find errors in LLVM. As
+ such, there is a higher likelihood that the benchmark cannot be used for an
+ environment and that :meth:`env.reset()
+ <compiler_gym.envs.CompilerEnv.reset>` will raise
+ :class:`BenchmarkInitError <compiler_gym.datasets.BenchmarkInitError>`.
+ """
+
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ super().__init__(
+ name="generator://llvm-stress-v0",
+ description="Randomly generated LLVM-IR",
+ references={
+ "Documentation": "https://llvm.org/docs/CommandGuide/llvm-stress.html"
+ },
+ license="Apache License v2.0 with LLVM Exceptions",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ )
+
+ @property
+ def size(self) -> int:
+ # Actually 2^32 - 1, but practically infinite for all intents and
+ # purposes.
+ return float("inf")
+
+ def benchmark_uris(self) -> Iterable[str]:
+ return (f"{self.name}/{i}" for i in range(UINT_MAX))
+
+ def benchmark(self, uri: str) -> Benchmark:
+ return self.benchmark_from_seed(int(uri.split("/")[-1]))
+
+ def benchmark_from_seed(self, seed: int) -> Benchmark:
+ """Get a benchmark from a uint32 seed.
+
+ :param seed: A number in the range 0 <= n < 2^32.
+
+ :return: A benchmark instance.
+ """
+ self.install()
+
+ # Run llvm-stress with the given seed and pipe the output to llvm-as to
+ # assemble a bitcode.
+ llvm_stress = subprocess.Popen(
+ [str(llvm.llvm_stress_path()), f"--seed={seed}"],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ llvm_as = subprocess.Popen(
+ [str(llvm.llvm_as_path()), "-"],
+ stdin=llvm_stress.stdout,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+
+ stdout, _ = llvm_as.communicate(timeout=60)
+ if llvm_stress.returncode or llvm_as.returncode:
+ raise BenchmarkInitError("Failed to generate benchmark")
+
+ return Benchmark.from_file_contents(f"{self.name}/{seed}", stdout)
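
Because each llvm-stress benchmark is derived entirely from its seed, the environment can be pointed at one directly by URI. A hedged sketch, using the seed and environment id from the docstring above and assuming CompilerGym is installed so that importing it registers `llvm-v0` with gym:

```python
import gym

import compiler_gym  # noqa: F401  Importing registers the llvm-v0 environment.

env = gym.make("llvm-v0")
try:
    # Seed 10 of the generator, as in the docstring example above.
    env.reset(benchmark="generator://llvm-stress-v0/10")
finally:
    env.close()
```
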
diff --git a/compiler_gym/envs/llvm/datasets/poj104.py b/compiler_gym/envs/llvm/datasets/poj104.py
new file mode 100644
index 000000000..39fc8e9c1
--- /dev/null
+++ b/compiler_gym/envs/llvm/datasets/poj104.py
@@ -0,0 +1,193 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import subprocess
+import sys
+from concurrent.futures import as_completed
+from pathlib import Path
+from typing import Optional
+
+from compiler_gym.datasets import Benchmark, BenchmarkInitError, TarDatasetWithManifest
+from compiler_gym.datasets.benchmark import BenchmarkWithSource
+from compiler_gym.envs.llvm.llvm_benchmark import ClangInvocation
+from compiler_gym.util import thread_pool
+from compiler_gym.util.download import download
+from compiler_gym.util.filesystem import atomic_file_write
+from compiler_gym.util.truncate import truncate
+
+
+class POJ104Dataset(TarDatasetWithManifest):
+ """The POJ-104 dataset contains 52000 C++ programs implementing 104
+ different algorithms with 500 examples of each.
+
+ The dataset is from:
+
+ Lili Mou, Ge Li, Lu Zhang, Tao Wang, Zhi Jin. "Convolutional neural
+ networks over tree structures for programming language processing." To
+ appear in Proceedings of 30th AAAI Conference on Artificial
+ Intelligence, 2016.
+
+ And is available at:
+
+ https://sites.google.com/site/treebasedcnn/
+ """
+
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ manifest_url, manifest_sha256 = {
+ "darwin": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-poj104-v1-macos-manifest.bz2",
+ "74db443f225478933dd0adf3f821fd4e615089eeaa90596c19d9d1af7006a801",
+ ),
+ "linux": (
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-poj104-v1-linux-manifest.bz2",
+ "ee6253ee826e171816105e76fa78c0d3cbd319ef66e10da4bcf9cf8a78e12ab9",
+ ),
+ }[sys.platform]
+ super().__init__(
+ name="benchmark://poj104-v1",
+ tar_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-poj104-v1.tar.gz",
+ "https://drive.google.com/u/0/uc?id=0B2i-vWnOu7MxVlJwQXN6eVNONUU&export=download",
+ ],
+ tar_sha256="c0b8ef3ee9c9159c882dc9337cb46da0e612a28e24852a83f8a1cd68c838f390",
+ tar_compression="gz",
+ manifest_urls=[manifest_url],
+ manifest_sha256=manifest_sha256,
+ references={
+ "Paper": "https://ojs.aaai.org/index.php/AAAI/article/download/10139/9998",
+ "Homepage": "https://sites.google.com/site/treebasedcnn/",
+ },
+ license="BSD 3-Clause",
+ strip_prefix="ProgramData",
+ description="Solutions to programming programs",
+ benchmark_file_suffix=".txt",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ )
+
+ def benchmark(self, uri: Optional[str] = None) -> Benchmark:
+ self.install()
+ if uri is None or len(uri) <= len(self.name) + 1:
+ return self._get_benchmark_by_index(self.random.integers(self.size))
+
+ # The absolute path of the file, without an extension.
+ path_stem = self.dataset_root / uri[len(self.name) + 1 :]
+
+ # If the file does not exist, compile it on-demand.
+ bitcode_path = Path(f"{path_stem}.bc")
+ cc_file_path = Path(f"{path_stem}.txt")
+
+ if not bitcode_path.is_file():
+ if not cc_file_path.is_file():
+ raise LookupError(
+ f"Benchmark not found: {uri} (file not found: {cc_file_path})"
+ )
+
+ # Load the C++ source into memory and pre-process it.
+ with open(cc_file_path) as f:
+ src = self.preprocess_poj104_source(f.read())
+
+ # Compile the C++ source into a bitcode file.
+ with atomic_file_write(bitcode_path) as tmp_bitcode_path:
+ compile_cmd = ClangInvocation.from_c_file(
+ "-",
+ copt=[
+ "-xc++",
+ "-ferror-limit=1", # Stop on first error.
+ "-w", # No warnings.
+ # Some of the programs use the gets() function that was
+ # deprecated in C++11 and removed in C++14.
+ "-std=c++11",
+ ],
+ ).command(outpath=tmp_bitcode_path)
+ self.logger.debug("Exec %s", compile_cmd)
+ clang = subprocess.Popen(
+ compile_cmd,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ _, stderr = clang.communicate(src.encode("utf-8"), timeout=300)
+
+ if clang.returncode:
+ compile_cmd = " ".join(compile_cmd)
+ error = truncate(stderr.decode("utf-8"), max_lines=20, max_line_len=100)
+ raise BenchmarkInitError(
+ f"Compilation job failed!\n"
+ f"Command: {compile_cmd}\n"
+ f"Error: {error}"
+ )
+ if not bitcode_path.is_file():
+ raise BenchmarkInitError(
+ f"Compilation job failed to produce output file!\nCommand: {compile_cmd}"
+ )
+
+ return BenchmarkWithSource.create(uri, bitcode_path, "source.cc", cc_file_path)
+
+ @staticmethod
+ def preprocess_poj104_source(src: str) -> str:
+ """Pre-process a POJ-104 C++ source file for compilation."""
+ # Clean up declaration of main function. Many are missing a return type
+ # declaration, or use an incorrect void return type.
+ src = src.replace("void main", "int main")
+ src = src.replace("\nmain", "\nint main")
+ if src.startswith("main"):
+ src = f"int {src}"
+
+ # Pull in the standard library.
+ if sys.platform == "linux":
+ header = "#include \n" "using namespace std;\n"
+ else:
+ # Download a bits/stdc++ implementation for macOS.
+ header = download(
+ "https://github.com/raw/tekfyl/bits-stdc-.h-for-mac/e1193f4470514d82ea19c3cc1357116fadaa2a4e/stdc%2B%2B.h",
+ sha256="b4d9b031d56d89a2b58b5ed80fa9943aa92420d6aed0835747c9a5584469afeb",
+ ).decode("utf-8")
+
+ # These defines provide values for commonly undefined symbols. Defining
+ # these macros increases the number of POJ-104 programs that compile
+ # from 49,302 to 49,821 (+519) on linux.
+ defines = "#define LEN 128\n" "#define MAX_LENGTH 1024\n" "#define MAX 1024\n"
+
+ return header + defines + src
+
+ def compile_all(self):
+ n = self.size
+ executor = thread_pool.get_thread_pool_executor()
+ # Since the dataset is lazily compiled, simply iterating over the full
+ # set of URIs will compile everything. Do this in parallel.
+ futures = (
+ executor.submit(self.benchmark, uri) for uri in self.benchmark_uris()
+ )
+ for i, future in enumerate(as_completed(futures), start=1):
+ future.result()
+ print(
+ f"\r\033[KCompiled {i} of {n} programs ({i/n:.2%} complete)",
+ flush=True,
+ end="",
+ )
+
+
+class POJ104LegacyDataset(TarDatasetWithManifest):
+ def __init__(self, site_data_base: Path, sort_order: int = 0):
+ super().__init__(
+ name="benchmark://poj104-v0",
+ tar_urls="https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-poj104-v0.tar.bz2",
+ tar_sha256="6254d629887f6b51efc1177788b0ce37339d5f3456fb8784415ed3b8c25cce27",
+ manifest_urls=[
+ "https://dl.fbaipublicfiles.com/compiler_gym/llvm_bitcodes-10.0.0-poj104-v0-manifest.bz2"
+ ],
+ manifest_sha256="ac3eaaad7d2878d871ed2b5c72a3f39c058ea6694989af5c86cd162414db750b",
+ references={
+ "Paper": "https://ojs.aaai.org/index.php/AAAI/article/download/10139/9998",
+ "Homepage": "https://sites.google.com/site/treebasedcnn/",
+ },
+ license="BSD 3-Clause",
+ strip_prefix="poj104-v0",
+ description="Solutions to programming programs",
+ benchmark_file_suffix=".bc",
+ site_data_base=site_data_base,
+ sort_order=sort_order,
+ deprecated="Please update to benchmark://poj104-v1.",
+ )
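
The POJ-104 sources are compiled to bitcode on demand, after the pre-processing step rewrites common declaration problems. A small sketch of that step in isolation; the input snippet is made up, and on Linux the injected header is the standard-library include shown above (on macOS a replacement header is downloaded, so this would touch the network):

```python
from compiler_gym.envs.llvm.datasets.poj104 import POJ104Dataset

# A typical POJ-104 style submission: main() is declared without a return type.
raw_source = 'main(){int a; scanf("%d", &a); printf("%d", a * 2);}'

fixed = POJ104Dataset.preprocess_poj104_source(raw_source)

# The pre-processor prepends a header plus a few fallback #defines and
# rewrites the declaration to "int main" before handing the file to clang.
print(fixed)
```
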
diff --git a/compiler_gym/envs/llvm/legacy_datasets.py b/compiler_gym/envs/llvm/legacy_datasets.py
index fb12dc868..fe1090b88 100644
--- a/compiler_gym/envs/llvm/legacy_datasets.py
+++ b/compiler_gym/envs/llvm/legacy_datasets.py
@@ -27,7 +27,7 @@
from compiler_gym.util.download import download
from compiler_gym.util.runfiles_path import cache_path, site_data_path
from compiler_gym.util.timer import Timer
-from compiler_gym.validation_result import ValidationError
+from compiler_gym.validation_error import ValidationError
_CBENCH_DATA_URL = (
"https://dl.fbaipublicfiles.com/compiler_gym/cBench-v0-runtime-data.tar.bz2"
@@ -419,7 +419,7 @@ def _make_cBench_validator(
def validator_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
"""The validation callback."""
with _CBENCH_DOWNLOAD_THREAD_LOCK:
- with fasteners.InterProcessLock(cache_path("cBench-v1-runtime-data.LOCK")):
+ with fasteners.InterProcessLock(cache_path(".cBench-v1-runtime-data.lock")):
download_cBench_runtime_data()
cbench_data = site_data_path("llvm/cBench-v1-runtime-data/runtime_data")
@@ -643,10 +643,11 @@ def get_llvm_benchmark_validation_callback(
If there is no valid callback, returns :code:`None`.
- :param env: An :class:`LlvmEnv` instance.
- :return: An optional callback that takes an :class:`LlvmEnv` instance as
- argument and returns an optional string containing a validation error
- message.
+ :param env: An :class:`LlvmEnv <compiler_gym.envs.LlvmEnv>` instance.
+
+ :return: An optional callback that takes an :class:`LlvmEnv
+ <compiler_gym.envs.LlvmEnv>` instance as argument and returns an
+ optional string containing a validation error message.
"""
validators = VALIDATORS.get(env.benchmark)
diff --git a/compiler_gym/envs/llvm/benchmarks.py b/compiler_gym/envs/llvm/llvm_benchmark.py
similarity index 73%
rename from compiler_gym/envs/llvm/benchmarks.py
rename to compiler_gym/envs/llvm/llvm_benchmark.py
index 0416a85b8..31d510c3e 100644
--- a/compiler_gym/envs/llvm/benchmarks.py
+++ b/compiler_gym/envs/llvm/llvm_benchmark.py
@@ -8,16 +8,16 @@
import subprocess
import sys
import tempfile
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
from datetime import datetime
-from multiprocessing import cpu_count
from pathlib import Path
from signal import Signals
from typing import Iterable, List, Optional, Union
-from compiler_gym.service.proto import Benchmark, File
+from compiler_gym.datasets import Benchmark, BenchmarkInitError
from compiler_gym.third_party import llvm
-from compiler_gym.util.runfiles_path import cache_path
+from compiler_gym.util.runfiles_path import transient_cache_path
+from compiler_gym.util.thread_pool import get_thread_pool_executor
def _communicate(process, input=None, timeout=None):
@@ -137,6 +137,42 @@ def command(self, outpath: Path) -> List[str]:
return cmd
+ @classmethod
+ def from_c_file(
+ cls,
+ path: Path,
+ copt: Optional[List[str]] = None,
+ system_includes: bool = True,
+ timeout: int = 600,
+ ) -> "ClangInvocation":
+ copt = copt or []
+ # NOTE(cummins): There is some discussion about the best way to create a
+ # bitcode that is unoptimized yet does not hinder downstream
+ # optimization opportunities. Here we are using a configuration based on
+ # -O1 in which we prevent the -O1 optimization passes from running. This
+ # is because LLVM produces different function attributes depending on
+ # the optimization level. E.g. "-O0 -Xclang -disable-llvm-optzns -Xclang
+ # -disable-O0-optnone" will generate code with "noinline" attributes set
+ # on the functions, whereas "-Oz -Xclang -disable-llvm-optzns" will
+ # generate functions with "minsize" and "optsize" attributes set.
+ #
+ # See also:
+ #
+ #
+ DEFAULT_COPT = [
+ "-O1",
+ "-Xclang",
+ "-disable-llvm-passes",
+ "-Xclang",
+ "-disable-llvm-optzns",
+ ]
+
+ return cls(
+ DEFAULT_COPT + copt + [str(path)],
+ system_includes=system_includes,
+ timeout=timeout,
+ )
+
def _run_command(cmd: List[str], timeout: int):
process = subprocess.Popen(
@@ -151,7 +187,7 @@ def _run_command(cmd: List[str], timeout: int):
returncode = f"{returncode} ({Signals(abs(returncode)).name})"
except ValueError:
pass
- raise OSError(
+ raise BenchmarkInitError(
f"Compilation job failed with returncode {returncode}\n"
f"Command: {' '.join(cmd)}\n"
f"Stderr: {stderr.strip()}"
@@ -171,9 +207,9 @@ def make_benchmark(
For single-source C/C++ programs, you can pass the path of the source file:
- >>> benchmark = make_benchmark('my_app.c')
- >>> env = gym.make("llvm-v0")
- >>> env.reset(benchmark=benchmark)
+ >>> benchmark = make_benchmark('my_app.c')
+ >>> env = gym.make("llvm-v0")
+ >>> env.reset(benchmark=benchmark)
The clang invocation used is roughly equivalent to:
@@ -184,81 +220,76 @@ def make_benchmark(
Additional compile-time arguments to clang can be provided using the
:code:`copt` argument:
- >>> benchmark = make_benchmark('/path/to/my_app.cpp', copt=['-O2'])
+ >>> benchmark = make_benchmark('/path/to/my_app.cpp', copt=['-O2'])
If you need more fine-grained control over the options, you can directly
- construct a :class:`ClangInvocation `
- to pass a list of arguments to clang:
+ construct a :class:`ClangInvocation
+ ` to pass a list of arguments to
+ clang:
- >>> benchmark = make_benchmark(
- ClangInvocation(['/path/to/my_app.c'], timeout=10)
- )
+ >>> benchmark = make_benchmark(
+ ClangInvocation(['/path/to/my_app.c'], timeout=10)
+ )
For multi-file programs, pass a list of inputs that will be compiled
separately and then linked to a single module:
- >>> benchmark = make_benchmark([
- 'main.c',
- 'lib.cpp',
- 'lib2.bc',
- ])
+ >>> benchmark = make_benchmark([
+ 'main.c',
+ 'lib.cpp',
+ 'lib2.bc',
+ ])
If you already have prepared bitcode files, those can be linked and used
directly:
- >>> benchmark = make_benchmark([
- 'bitcode1.bc',
- 'bitcode2.bc',
- ])
+ >>> benchmark = make_benchmark([
+ 'bitcode1.bc',
+ 'bitcode2.bc',
+ ])
+
+ Text-format LLVM assembly can also be used:
+
+ >>> benchmark = make_benchmark('module.ll')
.. note::
+
LLVM bitcode compatibility is
`not guaranteed `_,
so you must ensure that any precompiled bitcodes are compatible with the
LLVM version used by CompilerGym, which can be queried using
- :func:`LlvmEnv.compiler_version `.
+ :func:`env.compiler_version `.
:param inputs: An input, or list of inputs.
+
:param copt: A list of command line options to pass to clang when compiling
source files.
+
:param system_includes: Whether to include the system standard libraries
during compilation jobs. This requires a system toolchain. See
:func:`get_system_includes`.
+
:param timeout: The maximum number of seconds to allow clang to run before
terminating.
- :return: A :code:`Benchmark` message.
+
+ :return: A :code:`Benchmark` instance.
+
:raises FileNotFoundError: If any input sources are not found.
+
:raises TypeError: If the inputs are of unsupported types.
+
:raises OSError: If a compilation job fails.
- :raises TimeoutExpired: If a compilation job exceeds :code:`timeout` seconds.
+
+ :raises TimeoutExpired: If a compilation job exceeds :code:`timeout`
+ seconds.
"""
copt = copt or []
bitcodes: List[Path] = []
clang_jobs: List[ClangInvocation] = []
+ ll_paths: List[Path] = []
def _add_path(path: Path):
- # NOTE(cummins): There is some discussion about the best way to create a
- # bitcode that is unoptimized yet does not hinder downstream
- # optimization opportunities. Here we are using a configuration based on
- # -O1 in which we prevent the -O1 optimization passes from running. This
- # is because LLVM produces different function attributes dependening on
- # the optimization level. E.g. "-O0 -Xclang -disable-llvm-optzns -Xclang
- # -disable-O0-optnone" will generate code with "noinline" attributes set
- # on the functions, wheras "-Oz -Xclang -disable-llvm-optzns" will
- # generate functions with "minsize" and "optsize" attributes set.
- #
- # See also:
- #
- #
- DEFAULT_COPT = [
- "-O1",
- "-Xclang",
- "-disable-llvm-passes",
- "-Xclang",
- "-disable-llvm-optzns",
- ]
-
if not path.is_file():
raise FileNotFoundError(path)
@@ -266,12 +297,12 @@ def _add_path(path: Path):
bitcodes.append(path)
elif path.suffix in {".c", ".cxx", ".cpp", ".cc"}:
clang_jobs.append(
- ClangInvocation(
- [str(path)] + DEFAULT_COPT + copt,
- system_includes=system_includes,
- timeout=timeout,
+ ClangInvocation.from_c_file(
+ path, copt=copt, system_includes=system_includes, timeout=timeout
)
)
+ elif path.suffix == ".ll":
+ ll_paths.append(path)
else:
raise ValueError(f"Unrecognized file type: {path.name}")
@@ -290,40 +321,66 @@ def _add_path(path: Path):
else:
raise TypeError(f"Invalid input type: {type(input).__name__}")
- if not bitcodes and not clang_jobs:
- raise ValueError("No inputs")
-
# Shortcut if we only have a single pre-compiled bitcode.
if len(bitcodes) == 1 and not clang_jobs:
bitcode = bitcodes[0]
- return Benchmark(
- uri=f"file:///{bitcode}", program=File(uri=f"file:///{bitcode}")
- )
+ return Benchmark.from_file(uri=f"file:///{bitcode}", path=bitcode)
- tmpdir_root = cache_path(".")
+ tmpdir_root = transient_cache_path(".")
tmpdir_root.mkdir(exist_ok=True, parents=True)
- with tempfile.TemporaryDirectory(dir=tmpdir_root) as d:
+ with tempfile.TemporaryDirectory(
+ dir=tmpdir_root, prefix="llvm-make_benchmark-"
+ ) as d:
working_dir = Path(d)
- # Run the clang invocations in parallel.
clang_outs = [
- working_dir / f"out-{i}.bc" for i in range(1, len(clang_jobs) + 1)
+ working_dir / f"clang-out-{i}.bc" for i in range(1, len(clang_jobs) + 1)
+ ]
+ llvm_as_outs = [
+ working_dir / f"llvm-as-out-{i}.bc" for i in range(1, len(ll_paths) + 1)
]
- with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
- futures = (
+
+ # Run the clang and llvm-as invocations in parallel. Avoid running this
+ # code path if possible as get_thread_pool_executor() requires locking.
+ if clang_jobs or ll_paths:
+ llvm_as_path = str(llvm.llvm_as_path())
+ executor = get_thread_pool_executor()
+
+ llvm_as_commands = [
+ [llvm_as_path, str(ll_path), "-o", bc_path]
+ for ll_path, bc_path in zip(ll_paths, llvm_as_outs)
+ ]
+
+ # Fire off the clang and llvm-as jobs.
+ futures = [
executor.submit(_run_command, job.command(out), job.timeout)
for job, out in zip(clang_jobs, clang_outs)
- )
- list(future.result() for future in as_completed(futures))
+ ] + [
+ executor.submit(_run_command, command, timeout)
+ for command in llvm_as_commands
+ ]
- # Check that the expected files were generated.
- for i, b in enumerate(clang_outs):
- if not b.is_file():
- raise OSError(
- f"Clang invocation failed to produce a file: {' '.join(clang_jobs[i].command(clang_outs[i]))}"
- )
+ # Block until finished.
+ list(future.result() for future in as_completed(futures))
- if len(bitcodes + clang_outs) > 1:
+ # Check that the expected files were generated.
+ for clang_job, bc_path in zip(clang_jobs, clang_outs):
+ if not bc_path.is_file():
+ raise BenchmarkInitError(
+ f"clang failed: {' '.join(clang_job.command(bc_path))}"
+ )
+ for command, bc_path in zip(llvm_as_commands, llvm_as_outs):
+ if not bc_path.is_file():
+ raise BenchmarkInitError(f"llvm-as failed: {command}")
+
+ all_outs = bitcodes + clang_outs + llvm_as_outs
+ if not all_outs:
+ raise ValueError("No inputs")
+ elif len(all_outs) == 1:
+ # We only have a single bitcode so read it.
+ with open(str(all_outs[0]), "rb") as f:
+ bitcode = f.read()
+ else:
# Link all of the bitcodes into a single module.
llvm_link_cmd = [str(llvm.llvm_link_path()), "-o", "-"] + [
str(path) for path in bitcodes + clang_outs
@@ -333,15 +390,10 @@ def _add_path(path: Path):
)
bitcode, stderr = _communicate(llvm_link, timeout=timeout)
if llvm_link.returncode:
- raise OSError(
+ raise BenchmarkInitError(
f"Failed to link LLVM bitcodes with error: {stderr.decode('utf-8')}"
)
- else:
- # We only have a single bitcode so read it.
- with open(str(list(bitcodes + clang_outs)[0]), "rb") as f:
- bitcode = f.read()
- timestamp = datetime.now().strftime(f"%Y%m%HT%H%M%S-{random.randrange(16**4):04x}")
- return Benchmark(
- uri=f"benchmark://user/{timestamp}", program=File(contents=bitcode)
- )
+ timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
+ uri = f"benchmark://user/{timestamp}-{random.randrange(16**4):04x}"
+ return Benchmark.from_file_contents(uri, bitcode)
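
With the changes above, `make_benchmark()` also accepts text-format LLVM assembly, assembling it with llvm-as before linking. A hedged sketch of mixed inputs; `main.c` and `module.ll` are placeholder paths for user-provided files, not files shipped with CompilerGym:

```python
import gym

import compiler_gym  # noqa: F401  Importing registers the llvm-v0 environment.
from compiler_gym.envs.llvm.llvm_benchmark import make_benchmark

# Mixed inputs: a C file (compiled with the -O1 -disable-llvm-passes flags
# described in the note above) and a text-format .ll file (assembled with
# llvm-as), linked into a single module.
benchmark = make_benchmark(["main.c", "module.ll"])

env = gym.make("llvm-v0")
env.reset(benchmark=benchmark)
env.close()
```
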
diff --git a/compiler_gym/envs/llvm/llvm_env.py b/compiler_gym/envs/llvm/llvm_env.py
index 284cf7aa3..02cab0459 100644
--- a/compiler_gym/envs/llvm/llvm_env.py
+++ b/compiler_gym/envs/llvm/llvm_env.py
@@ -7,18 +7,16 @@
import os
import shutil
from pathlib import Path
-from typing import Callable, Iterable, List, Optional, Union, cast
+from typing import Iterable, List, Optional, Union, cast
import numpy as np
from gym.spaces import Box
from gym.spaces import Dict as DictSpace
+from compiler_gym.datasets import Benchmark, BenchmarkInitError, Dataset
from compiler_gym.envs.compiler_env import CompilerEnv
-from compiler_gym.envs.llvm.benchmarks import make_benchmark
-from compiler_gym.envs.llvm.legacy_datasets import (
- LLVM_DATASETS,
- get_llvm_benchmark_validation_callback,
-)
+from compiler_gym.envs.llvm.datasets import get_llvm_datasets
+from compiler_gym.envs.llvm.llvm_benchmark import ClangInvocation, make_benchmark
from compiler_gym.envs.llvm.llvm_rewards import (
BaselineImprovementNormalizedReward,
CostFunctionReward,
@@ -29,8 +27,7 @@
from compiler_gym.third_party.inst2vec import Inst2vecEncoder
from compiler_gym.third_party.llvm import download_llvm_files
from compiler_gym.third_party.llvm.instcount import INST_COUNT_FEATURE_NAMES
-from compiler_gym.util.runfiles_path import runfiles_path, site_data_path
-from compiler_gym.validation_result import ValidationError
+from compiler_gym.util.runfiles_path import runfiles_path
_ACTIONS_LIST = Path(
runfiles_path("compiler_gym/envs/llvm/service/passes/actions_list.txt")
@@ -58,35 +55,58 @@ def _read_list_file(path: Path) -> Iterable[str]:
_FLAGS = dict(zip(_ACTIONS, _read_list_file(_FLAGS_LIST)))
_DESCRIPTIONS = dict(zip(_ACTIONS, _read_list_file(_DESCRIPTIONS_LIST)))
-# TODO(github.com/facebookresearch/CompilerGym/issues/122): Lazily instantiate
-# inst2vec encoder.
_INST2VEC_ENCODER = Inst2vecEncoder()
+_LLVM_DATASETS: Optional[List[Dataset]] = None
+
+
+def _get_llvm_datasets(site_data_base: Optional[Path] = None) -> Iterable[Dataset]:
+ """Get the LLVM datasets. Use a singleton value when site_data_base is the
+ default value.
+ """
+ global _LLVM_DATASETS
+ if site_data_base is None:
+ if _LLVM_DATASETS is None:
+ _LLVM_DATASETS = list(get_llvm_datasets(site_data_base=site_data_base))
+ return _LLVM_DATASETS
+ return get_llvm_datasets(site_data_base=site_data_base)
+
+
class LlvmEnv(CompilerEnv):
"""A specialized CompilerEnv for LLVM.
- This extends the default :class:`CompilerEnv` environment, adding extra LLVM
- functionality. Specifically, the actions use the
- :class:`CommandlineFlag ` space, which
- is a type of :code:`Discrete` space that provides additional documentation
- about each action, and the
- :meth:`LlvmEnv.commandline() ` method
- can be used to produce an equivalent LLVM opt invocation for the current
- environment state.
+ This extends the default :class:`CompilerEnv
+ <compiler_gym.envs.CompilerEnv>` environment, adding extra LLVM
+ functionality. Specifically, the actions use the :class:`CommandlineFlag
+ <compiler_gym.spaces.CommandlineFlag>` space, which is a type of
+ :code:`Discrete` space that provides additional documentation about each
+ action, and the :meth:`LlvmEnv.commandline()
+ <compiler_gym.envs.LlvmEnv.commandline>` method can be used to produce an
+ equivalent LLVM opt invocation for the current environment state.
:ivar actions: The list of actions that have been performed since the
previous call to :func:`reset`.
+
:vartype actions: List[int]
"""
- def __init__(self, *args, **kwargs):
+ def __init__(
+ self,
+ *args,
+ benchmark: Optional[Union[str, Benchmark]] = None,
+ datasets_site_path: Optional[Path] = None,
+ **kwargs,
+ ):
# First perform a one-time download of LLVM binaries that are needed by
# the LLVM service and are not included by the pip-installed package.
download_llvm_files()
super().__init__(
*args,
**kwargs,
+ # Set a default benchmark for use.
+ benchmark=benchmark or "cbench-v1/qsort",
+ datasets=_get_llvm_datasets(site_data_base=datasets_site_path),
rewards=[
CostFunctionReward(
id="IrInstructionCount",
@@ -164,13 +184,6 @@ def __init__(self, *args, **kwargs):
),
],
)
- self.datasets_site_path = site_data_path("llvm/10.0.0/bitcode_benchmarks")
-
- # Register the LLVM datasets.
- self.datasets_site_path.mkdir(parents=True, exist_ok=True)
- self.inactive_datasets_site_path.mkdir(parents=True, exist_ok=True)
- for dataset in LLVM_DATASETS:
- self.register_dataset(dataset)
self.inst2vec = _INST2VEC_ENCODER
@@ -279,27 +292,112 @@ def __init__(self, *args, **kwargs):
)
def reset(self, *args, **kwargs):
- # The BenchmarkFactory::getBenchmark() method raises an error if there
- # are no benchmarks to select from. Install the cBench dataset as a
- # fallback.
- #
- # TODO(github.com/facebookresearch/CompilerGym/issues/45): Remove this
- # once the dataset API has been refactored so that service-side datasets
- # are no longer an issue.
try:
return super().reset(*args, **kwargs)
- except FileNotFoundError:
- self.logger.warning(
- "reset() called on servie with no benchmarks available. "
- "Installing cBench-v1"
+ except ValueError as e:
+ # Catch and re-raise a compilation error with a more informative
+ # error type.
+ if "Failed to compute .text size cost" in str(e):
+ raise BenchmarkInitError(
+ f"Failed to initialize benchmark {self._benchmark_in_use.uri}: {e}"
+ ) from e
+ raise
+
+ def make_benchmark(
+ self,
+ inputs: Union[
+ str, Path, ClangInvocation, List[Union[str, Path, ClangInvocation]]
+ ],
+ copt: Optional[List[str]] = None,
+ system_includes: bool = True,
+ timeout: int = 600,
+ ) -> Benchmark:
+ """Create a benchmark for use with this environment.
+
+ This function takes one or more inputs and uses them to create a
+ benchmark that can be passed to :meth:`compiler_gym.envs.LlvmEnv.reset`.
+
+ For single-source C/C++ programs, you can pass the path of the source
+ file:
+
+ >>> benchmark = make_benchmark('my_app.c')
+ >>> env = gym.make("llvm-v0")
+ >>> env.reset(benchmark=benchmark)
+
+ The clang invocation used is roughly equivalent to:
+
+ .. code-block::
+
+ $ clang my_app.c -O0 -c -emit-llvm -o benchmark.bc
+
+ Additional compile-time arguments to clang can be provided using the
+ :code:`copt` argument:
+
+ >>> benchmark = make_benchmark('/path/to/my_app.cpp', copt=['-O2'])
+
+ If you need more fine-grained control over the options, you can directly
+ construct a :class:`ClangInvocation
+ ` to pass a list of arguments to
+ clang:
+
+ >>> benchmark = make_benchmark(
+ ClangInvocation(['/path/to/my_app.c'], timeout=10)
)
- self.require_dataset("cBench-v1")
- super().reset(*args, **kwargs)
- @staticmethod
- def make_benchmark(*args, **kwargs):
- """Alias to :func:`llvm.make_benchmark() `."""
- return make_benchmark(*args, **kwargs)
+ For multi-file programs, pass a list of inputs that will be compiled
+ separately and then linked to a single module:
+
+ >>> benchmark = make_benchmark([
+ 'main.c',
+ 'lib.cpp',
+ 'lib2.bc',
+ ])
+
+ If you already have prepared bitcode files, those can be linked and used
+ directly:
+
+ >>> benchmark = make_benchmark([
+ 'bitcode1.bc',
+ 'bitcode2.bc',
+ ])
+
+ .. note::
+
+ LLVM bitcode compatibility is
+ `not guaranteed `_,
+ so you must ensure that any precompiled bitcodes are compatible with the
+ LLVM version used by CompilerGym, which can be queried using
+ :func:`LlvmEnv.compiler_version `.
+
+ :param inputs: An input, or list of inputs.
+
+ :param copt: A list of command line options to pass to clang when
+ compiling source files.
+
+ :param system_includes: Whether to include the system standard libraries
+ during compilation jobs. This requires a system toolchain. See
+ :func:`get_system_includes`.
+
+ :param timeout: The maximum number of seconds to allow clang to run
+ before terminating.
+
+ :return: A :code:`Benchmark` instance.
+
+ :raises FileNotFoundError: If any input sources are not found.
+
+ :raises TypeError: If the inputs are of unsupported types.
+
+ :raises OSError: If a compilation job fails.
+
+ :raises TimeoutExpired: If a compilation job exceeds :code:`timeout`
+ seconds.
+ """
+ return make_benchmark(
+ inputs=inputs,
+ copt=copt,
+ system_includes=system_includes,
+ timeout=timeout,
+ )
def _make_action_space(self, name: str, entries: List[str]) -> Commandline:
flags = [
@@ -361,7 +459,7 @@ def ir_sha1(self) -> str:
Equivalent to: :code:`hashlib.sha1(env.ir.encode("utf-8")).hexdigest()`.
- :return: A 40-character hexademical sha1 string.
+ :return: A 40-character hexadecimal sha1 string.
"""
# TODO(cummins): Compute this on the service-side and add it as an
# observation space.
@@ -400,17 +498,3 @@ def render(
print(self.ir)
else:
return super().render(mode)
-
- def get_benchmark_validation_callback(
- self,
- ) -> Optional[Callable[[CompilerEnv], Iterable[ValidationError]]]:
- """Return a callback for validating a given environment state.
-
- If there is no valid callback, returns :code:`None`.
-
- :param env: An :class:`LlvmEnv` instance.
- :return: An optional callback that takes an :class:`LlvmEnv` instance as
- argument and returns an optional string containing a validation error
- message.
- """
- return get_llvm_benchmark_validation_callback(self)
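
With the `reset()` change above, a benchmark whose `.text` size cost cannot be computed now surfaces as a `BenchmarkInitError` rather than a bare `ValueError`, so callers iterating over generated benchmarks can skip failures. A hedged sketch, using csmith seeds purely as an example of benchmarks that occasionally fail to initialize:

```python
import gym

import compiler_gym  # noqa: F401  Importing registers the llvm-v0 environment.
from compiler_gym.datasets import BenchmarkInitError

env = gym.make("llvm-v0")
for seed in range(3):
    try:
        env.reset(benchmark=f"generator://csmith-v0/{seed}")
    except BenchmarkInitError as err:
        # Randomly generated programs occasionally fail to build or cost;
        # skip them and move on to the next seed.
        print(f"skipping seed {seed}: {err}")
env.close()
```
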
diff --git a/compiler_gym/envs/llvm/llvm_rewards.py b/compiler_gym/envs/llvm/llvm_rewards.py
index 121474847..d746b092a 100644
--- a/compiler_gym/envs/llvm/llvm_rewards.py
+++ b/compiler_gym/envs/llvm/llvm_rewards.py
@@ -5,6 +5,7 @@
"""This module defines reward spaces used by the LLVM environment."""
from typing import List, Optional
+from compiler_gym.datasets import Benchmark
from compiler_gym.service import observation_t
from compiler_gym.spaces.reward import Reward
from compiler_gym.views.observation import ObservationView
@@ -35,7 +36,7 @@ def __init__(self, cost_function: str, init_cost_function: str, **kwargs):
self.init_cost_function: str = init_cost_function
self.previous_cost: Optional[observation_t] = None
- def reset(self, benchmark: str) -> None:
+ def reset(self, benchmark: Benchmark) -> None:
"""Called on env.reset(). Reset incremental progress."""
del benchmark # unused
self.previous_cost = None
@@ -64,13 +65,13 @@ def __init__(self, **kwargs):
"""Constructor."""
super().__init__(**kwargs)
self.cost_norm: Optional[observation_t] = None
- self.benchmark = None
+ self.benchmark: Benchmark = None
def reset(self, benchmark: str) -> None:
"""Called on env.reset(). Reset incremental progress."""
super().reset(benchmark)
# The benchmark has changed so we must compute a new cost normalization
- # value. If the benchamrk has not changed then the previously computed
+ # value. If the benchmark has not changed then the previously computed
# value is still valid.
if self.benchmark != benchmark:
self.cost_norm = None
diff --git a/compiler_gym/envs/llvm/service/BUILD b/compiler_gym/envs/llvm/service/BUILD
index 8e42a1668..4f08d1dd5 100644
--- a/compiler_gym/envs/llvm/service/BUILD
+++ b/compiler_gym/envs/llvm/service/BUILD
@@ -104,6 +104,7 @@ cc_library(
],
deps = [
":Benchmark",
+ ":Cost",
"//compiler_gym/util:GrpcStatusMacros",
"//compiler_gym/util:RunfilesPath",
"//compiler_gym/util:StrLenConstexpr",
diff --git a/compiler_gym/envs/llvm/service/Benchmark.cc b/compiler_gym/envs/llvm/service/Benchmark.cc
index c889eb63d..9f136dffc 100644
--- a/compiler_gym/envs/llvm/service/Benchmark.cc
+++ b/compiler_gym/envs/llvm/service/Benchmark.cc
@@ -22,13 +22,6 @@ namespace compiler_gym::llvm_service {
namespace {
-BaselineCosts getBaselineCosts(const llvm::Module& unoptimizedModule,
- const fs::path& workingDirectory) {
- BaselineCosts baselineCosts;
- setbaselineCosts(unoptimizedModule, &baselineCosts, workingDirectory);
- return baselineCosts;
-}
-
BenchmarkHash getModuleHash(const llvm::Module& module) {
BenchmarkHash hash;
llvm::SmallVector buffer;
@@ -51,8 +44,10 @@ std::unique_ptr makeModuleOrDie(llvm::LLVMContext& context, const
} // anonymous namespace
Status readBitcodeFile(const fs::path& path, Bitcode* bitcode) {
- std::ifstream ifs;
- ifs.open(path.string());
+ std::ifstream ifs(path.string());
+ if (ifs.fail()) {
+ return Status(StatusCode::NOT_FOUND, fmt::format("File not found: \"{}\"", path.string()));
+ }
ifs.seekg(0, std::ios::end);
if (ifs.fail()) {
@@ -83,7 +78,15 @@ std::unique_ptr makeModule(llvm::LLVMContext& context, const Bitco
llvm::parseBitcodeFile(buffer, context);
if (moduleOrError) {
*status = Status::OK;
- return std::move(moduleOrError.get());
+ std::unique_ptr<llvm::Module> module = std::move(moduleOrError.get());
+
+ // Strip the module identifiers and source file names from the module to
+ // anonymize them. This is to deter learning algorithms from overfitting to
+ // benchmarks by their name.
+ module->setModuleIdentifier("-");
+ module->setSourceFileName("-");
+
+ return module;
} else {
*status = Status(StatusCode::INVALID_ARGUMENT,
fmt::format("Failed to parse LLVM bitcode: \"{}\"", name));
@@ -93,35 +96,30 @@ std::unique_ptr makeModule(llvm::LLVMContext& context, const Bitco
// A benchmark is an LLVM module and the LLVM context that owns it.
Benchmark::Benchmark(const std::string& name, const Bitcode& bitcode,
- const fs::path& workingDirectory, std::optional bitcodePath,
- const BaselineCosts* baselineCosts)
+ const fs::path& workingDirectory, const BaselineCosts& baselineCosts)
: context_(std::make_unique()),
module_(makeModuleOrDie(*context_, bitcode, name)),
- baselineCosts_(baselineCosts ? *baselineCosts : getBaselineCosts(*module_, workingDirectory)),
+ baselineCosts_(baselineCosts),
hash_(getModuleHash(*module_)),
name_(name),
- bitcodeSize_(bitcode.size()),
- bitcodePath_(bitcodePath) {}
+ bitcodeSize_(bitcode.size()) {}
Benchmark::Benchmark(const std::string& name, std::unique_ptr context,
std::unique_ptr module, size_t bitcodeSize,
- const fs::path& workingDirectory, std::optional bitcodePath,
- const BaselineCosts* baselineCosts)
+ const fs::path& workingDirectory, const BaselineCosts& baselineCosts)
: context_(std::move(context)),
module_(std::move(module)),
- baselineCosts_(baselineCosts ? *baselineCosts : getBaselineCosts(*module_, workingDirectory)),
+ baselineCosts_(baselineCosts),
hash_(getModuleHash(*module_)),
name_(name),
- bitcodeSize_(bitcodeSize),
- bitcodePath_(bitcodePath) {}
+ bitcodeSize_(bitcodeSize) {}
std::unique_ptr Benchmark::clone(const fs::path& workingDirectory) const {
Bitcode bitcode;
llvm::raw_svector_ostream ostream(bitcode);
llvm::WriteBitcodeToFile(module(), ostream);
- return std::make_unique(name(), bitcode, workingDirectory, bitcodePath(),
- &baselineCosts());
+ return std::make_unique<Benchmark>(name(), bitcode, workingDirectory, baselineCosts());
}
} // namespace compiler_gym::llvm_service
diff --git a/compiler_gym/envs/llvm/service/Benchmark.h b/compiler_gym/envs/llvm/service/Benchmark.h
index 377063e8f..713c65b99 100644
--- a/compiler_gym/envs/llvm/service/Benchmark.h
+++ b/compiler_gym/envs/llvm/service/Benchmark.h
@@ -28,7 +28,8 @@ using Bitcode = llvm::SmallString<0>;
grpc::Status readBitcodeFile(const boost::filesystem::path& path, Bitcode* bitcode);
-// Returns nullptr on error and sets status.
+// Parses the given bitcode into a module and strips the identifying ModuleID
+// and source_filename attributes. Returns nullptr on error and sets status.
std::unique_ptr makeModule(llvm::LLVMContext& context, const Bitcode& bitcode,
const std::string& name, grpc::Status* status);
@@ -37,23 +38,17 @@ std::unique_ptr makeModule(llvm::LLVMContext& context, const Bitco
class Benchmark {
public:
Benchmark(const std::string& name, const Bitcode& bitcode,
- const boost::filesystem::path& workingDirectory,
- std::optional bitcodePath = std::nullopt,
- const BaselineCosts* baselineCosts = nullptr);
+ const boost::filesystem::path& workingDirectory, const BaselineCosts& baselineCosts);
Benchmark(const std::string& name, std::unique_ptr context,
std::unique_ptr module, size_t bitcodeSize,
- const boost::filesystem::path& workingDirectory,
- std::optional bitcodePath = std::nullopt,
- const BaselineCosts* baselineCosts = nullptr);
+ const boost::filesystem::path& workingDirectory, const BaselineCosts& baselineCosts);
// Make a copy of the benchmark.
std::unique_ptr clone(const boost::filesystem::path& workingDirectory) const;
inline const std::string& name() const { return name_; }
- inline const std::optional bitcodePath() const { return bitcodePath_; }
-
inline const size_t bitcodeSize() const { return bitcodeSize_; }
inline llvm::Module& module() { return *module_; }
@@ -90,9 +85,6 @@ class Benchmark {
const std::string name_;
// The length of the bitcode string for this benchmark.
const size_t bitcodeSize_;
- // The path of the bitcode file for this benchmark. This is optional -
- // benchmarks do not have to be backed by a file.
- const std::optional bitcodePath_;
};
} // namespace compiler_gym::llvm_service
diff --git a/compiler_gym/envs/llvm/service/BenchmarkFactory.cc b/compiler_gym/envs/llvm/service/BenchmarkFactory.cc
index 0ef0584c8..73509654c 100644
--- a/compiler_gym/envs/llvm/service/BenchmarkFactory.cc
+++ b/compiler_gym/envs/llvm/service/BenchmarkFactory.cc
@@ -11,6 +11,7 @@
#include
#include
+#include "compiler_gym/envs/llvm/service/Cost.h"
#include "compiler_gym/util/GrpcStatusMacros.h"
#include "compiler_gym/util/RunfilesPath.h"
#include "compiler_gym/util/StrLenConstexpr.h"
@@ -24,10 +25,6 @@ using grpc::StatusCode;
namespace compiler_gym::llvm_service {
-static const std::string kExpectedExtension = ".bc";
-
-static const fs::path kSiteBenchmarksDir = util::getSiteDataPath("llvm/10.0.0/bitcode_benchmarks");
-
BenchmarkFactory::BenchmarkFactory(const boost::filesystem::path& workingDirectory,
std::optional rand,
size_t maxLoadedBenchmarkSize)
@@ -35,18 +32,22 @@ BenchmarkFactory::BenchmarkFactory(const boost::filesystem::path& workingDirecto
rand_(rand.has_value() ? *rand : std::mt19937_64(std::random_device()())),
loadedBenchmarksSize_(0),
maxLoadedBenchmarkSize_(maxLoadedBenchmarkSize) {
- // Register all benchmarks from the site data directory.
- if (fs::is_directory(kSiteBenchmarksDir)) {
- CRASH_IF_ERROR(scanSiteDataDirectory());
- } else {
- LOG(INFO) << "LLVM site benchmark directory not found: " << kSiteBenchmarksDir.string();
+ VLOG(2) << "BenchmarkFactory initialized";
+}
+
+Status BenchmarkFactory::getBenchmark(const std::string& uri,
+ std::unique_ptr<Benchmark>* benchmark) {
+ // Check if the benchmark has already been loaded into memory.
+ auto loaded = benchmarks_.find(uri);
+ if (loaded != benchmarks_.end()) {
+ *benchmark = loaded->second.clone(workingDirectory_);
+ return Status::OK;
}
- VLOG(2) << "BenchmarkFactory initialized with " << numBenchmarks() << " benchmarks";
+ return Status(StatusCode::NOT_FOUND, "Benchmark not found");
}
-Status BenchmarkFactory::addBitcode(const std::string& uri, const Bitcode& bitcode,
- std::optional bitcodePath) {
+Status BenchmarkFactory::addBitcode(const std::string& uri, const Bitcode& bitcode) {
Status status;
std::unique_ptr context = std::make_unique();
std::unique_ptr module = makeModule(*context, bitcode, uri, &status);
@@ -59,34 +60,17 @@ Status BenchmarkFactory::addBitcode(const std::string& uri, const Bitcode& bitco
<< " exceeds maximum in-memory cache capacity " << maxLoadedBenchmarkSize_ << ", "
<< benchmarks_.size() << " bitcodes";
int evicted = 0;
- // Evict benchmarks until we have reduced capacity below 50%. Use a
- // bounded for loop to prevent infinite loop if we get "unlucky" and
- // have no valid candidates to unload.
+ // Evict benchmarks until we have reduced capacity below 50%.
const size_t targetCapacity = maxLoadedBenchmarkSize_ / 2;
- for (size_t i = 0; i < benchmarks_.size() * 2; ++i) {
- // We have run out of benchmarks to evict, or have freed up
- // enough capacity.
- if (!benchmarks_.size() || loadedBenchmarksSize_ < targetCapacity) {
- break;
- }
-
+ while (benchmarks_.size() && loadedBenchmarksSize_ > targetCapacity) {
// Select a cached benchmark randomly.
std::uniform_int_distribution distribution(0, benchmarks_.size() - 1);
size_t index = distribution(rand_);
auto iterator = std::next(std::begin(benchmarks_), index);
- // Check that the benchmark has an on-disk bitcode file which
- // can be loaded to re-cache this bitcode. If not, we cannot
- // evict it.
- if (!iterator->second.bitcodePath().has_value()) {
- continue;
- }
-
- // Evict the benchmark: add it to the pool of unloaded benchmarks and
- // delete it from the pool of loaded benchmarks.
+ // Evict the benchmark from the pool of loaded benchmarks.
++evicted;
loadedBenchmarksSize_ -= iterator->second.bitcodeSize();
- unloadedBitcodePaths_.insert({iterator->first, *iterator->second.bitcodePath()});
benchmarks_.erase(iterator);
}
@@ -94,241 +78,22 @@ Status BenchmarkFactory::addBitcode(const std::string& uri, const Bitcode& bitco
<< loadedBenchmarksSize_ << ", " << benchmarks_.size() << " bitcodes";
}
+ BaselineCosts baselineCosts;
+ RETURN_IF_ERROR(setBaselineCosts(*module, &baselineCosts, workingDirectory_));
+
benchmarks_.insert({uri, Benchmark(uri, std::move(context), std::move(module), bitcodeSize,
- workingDirectory_, bitcodePath)});
+ workingDirectory_, baselineCosts)});
loadedBenchmarksSize_ += bitcodeSize;
return Status::OK;
}
-Status BenchmarkFactory::addBitcodeFile(const std::string& uri,
- const boost::filesystem::path& path) {
- if (!fs::exists(path)) {
- return Status(StatusCode::NOT_FOUND, fmt::format("File not found: \"{}\"", path.string()));
- }
- unloadedBitcodePaths_[uri] = path;
- return Status::OK;
-}
-
-Status BenchmarkFactory::addBitcodeUriAlias(const std::string& src, const std::string& dst) {
- // TODO(github.com/facebookresearch/CompilerGym/issues/2): Add support
- // for additional protocols, e.g. http://.
- if (dst.rfind("file:////", 0) != 0) {
- return Status(StatusCode::INVALID_ARGUMENT,
- fmt::format("Unsupported benchmark URI protocol: \"{}\"", dst));
- }
-
- // Resolve path from file:/// protocol URI.
- const boost::filesystem::path path{dst.substr(util::strLen("file:///"))};
- return addBitcodeFile(src, path);
-}
-
-namespace {
-
-bool endsWith(const std::string& str, const std::string& suffix) {
- return str.size() >= suffix.size() &&
- str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
-}
-
-} // anonymous namespace
-
-Status BenchmarkFactory::addDirectoryOfBitcodes(const boost::filesystem::path& root) {
- VLOG(3) << "addDirectoryOfBitcodes(" << root.string() << ")";
- if (!fs::is_directory(root)) {
- return Status(StatusCode::INVALID_ARGUMENT,
- fmt::format("Directory not found: \"{}\"", root.string()));
- }
-
- // Check if there is a manifest file that we can read, rather than having to
- // enumerate the directory ourselves.
- const auto manifestPath = fs::path(root.string() + ".MANIFEST");
- if (fs::is_regular_file(manifestPath)) {
- VLOG(3) << "Reading manifest file: " << manifestPath;
- return addDirectoryOfBitcodes(root, manifestPath);
- }
-
- const auto rootPathSize = root.string().size();
- for (auto it : fs::recursive_directory_iterator(root, fs::symlink_option::recurse)) {
- if (!fs::is_regular_file(it)) {
- continue;
- }
- const std::string& path = it.path().string();
-
- if (!endsWith(path, kExpectedExtension)) {
- continue;
- }
-
- // The name of the benchmark is path, relative to the root, without the
- // file extension.
- const std::string name =
- path.substr(rootPathSize + 1, path.size() - rootPathSize - kExpectedExtension.size() - 1);
- const std::string uri = fmt::format("benchmark://{}", name);
-
- RETURN_IF_ERROR(addBitcodeFile(uri, path));
- }
-
- return Status::OK;
-}
-
-Status BenchmarkFactory::addDirectoryOfBitcodes(const boost::filesystem::path& root,
- const boost::filesystem::path& manifestPath) {
- std::ifstream infile(manifestPath.string());
- std::string relPath;
- while (std::getline(infile, relPath)) {
- if (!endsWith(relPath, kExpectedExtension)) {
- continue;
- }
-
- const fs::path path = root / relPath;
- const std::string name = relPath.substr(0, relPath.size() - kExpectedExtension.size());
- const std::string uri = fmt::format("benchmark://{}", name);
-
- RETURN_IF_ERROR(addBitcodeFile(uri, path));
- }
-
- return Status::OK;
-}
-
-Status BenchmarkFactory::getBenchmark(std::unique_ptr