Skip to content

Commit

Permalink
feat: improve autotune speed metrics measurement for better accuracy (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shjwudp authored Jul 2, 2021
1 parent a858504 commit e4ee5ee
Show file tree
Hide file tree
Showing 17 changed files with 485 additions and 278 deletions.
16 changes: 15 additions & 1 deletion bagua/bagua_define.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
from typing import List
import sys

if sys.version_info >= (3, 9):
from typing import TypedDict # pytype: disable=not-supported-yet
else:
Expand All @@ -20,6 +21,16 @@ class TensorDeclaration(TypedDict):
dtype: TensorDtype


def get_tensor_declaration_bytes(td: TensorDeclaration) -> int:
dtype_unit_size = {
TensorDtype.F32.value: 4,
TensorDtype.F16.value: 2,
TensorDtype.U8.value: 1,
}

return td["num_elements"] * dtype_unit_size[td["dtype"]]


class BaguaHyperparameter(BaseModel):
"""
Structured all bagua hyperparameters
Expand All @@ -31,5 +42,8 @@ class BaguaHyperparameter(BaseModel):
def update(self, param_dict: dict):
tmp = self.dict()
tmp.update(param_dict)
self.parse_obj(tmp)
for key, value in param_dict.items():
if key in tmp:
self.__dict__[key] = value

return self
8 changes: 7 additions & 1 deletion bagua/distributed/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def parse_args():
help="Bagua automatic super parameter search level. The higher the level, the better the theoretical effect, and the longer it takes",
)
parser.add_argument(
"--autotune_log_file", type=str, default="/tmp/bagua.autotune.log"
"--autotune_logfile", type=str, default="/tmp/bagua_autotune.log"
)
parser.add_argument(
"--report_metrics",
Expand Down Expand Up @@ -147,6 +147,12 @@ def set_bagua_env(args, current_env):
current_env["BAGUA_SERVICE_PORT"] = str(args.bagua_service_port)
current_env["BAGUA_DEFAULT_BUCKET_SIZE"] = str(args.default_bucket_size)
current_env["BAGUA_AUTOTUNE"] = str(args.autotune_level)
current_env["BAGUA_AUTOTUNE_MAX_SAMPLES"] = str(args.autotune_max_samples)
current_env["BAGUA_AUTOTUNE_SAMPLING_CONFIDENCE_TIME_S"] = str(
args.autotune_sampling_confidence_time
)
current_env["BAGUA_AUTOTUNE_WARMUP_TIME_S"] = str(args.autotune_warmup_time)
current_env["BAGUA_AUTOTUNE_LOGFILE_PATH"] = args.autotune_logfile


def main():
Expand Down
14 changes: 13 additions & 1 deletion bagua/distributed/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,13 @@ def get_args_parser() -> ArgumentParser:
default=0,
type=int,
help="Bagua automatic hyperparameters search level. The higher the level, the larger the "
"hyperparameter search space, and the longer time it takes.",
"hyperparameter search space, and the longer time it takes. Currently supported levels are 0 and 1.",
)
parser.add_argument("--autotune_max_samples", type=int, default=60)
parser.add_argument("--autotune_sampling_confidence_time", type=float, default=5.0)
parser.add_argument("--autotune_warmup_time", type=float, default=30.0)
parser.add_argument(
"--autotune_logfile", type=str, default="/tmp/bagua_autotune.log"
)

#
Expand Down Expand Up @@ -562,6 +568,12 @@ def set_bagua_env(args, current_env):
current_env["BAGUA_SERVICE_PORT"] = str(args.bagua_service_port)
current_env["BAGUA_DEFAULT_BUCKET_SIZE"] = str(args.default_bucket_size)
current_env["BAGUA_AUTOTUNE"] = str(args.autotune_level)
current_env["BAGUA_AUTOTUNE_MAX_SAMPLES"] = str(args.autotune_max_samples)
current_env["BAGUA_AUTOTUNE_SAMPLING_CONFIDENCE_TIME_S"] = str(
args.autotune_sampling_confidence_time
)
current_env["BAGUA_AUTOTUNE_WARMUP_TIME_S"] = str(args.autotune_warmup_time)
current_env["BAGUA_AUTOTUNE_LOGFILE_PATH"] = args.autotune_logfile


def run(args):
Expand Down
File renamed without changes.
Loading

0 comments on commit e4ee5ee

Please sign in to comment.