diff --git a/test/local_test_ddp.py b/test/local_test_ddp.py deleted file mode 100644 index 1b8e60b08..000000000 --- a/test/local_test_ddp.py +++ /dev/null @@ -1,249 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import copy -import os -import unittest - -import pippy.fx - -import torch -import torch.distributed.rpc as rpc -from pippy import run_pippy -from pippy.IR import ( - MultiUseParameterConfig, - Pipe, - pipe_split, - TrivialLossWrapper, -) -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) - -# TODOs for implementing forward/backward/loss with schedules: -# * ability to switch between full-batch loss vs. per-microbatch loss. shen mentioned -# this might change numerics. So we should have the ability to compute loss over -# the whole minibatch rather than doing it for each micro-batch - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - - -def get_grad_from_executor(executor, qualname): - mod = executor.local_value().mod - if isinstance(mod, torch.nn.parallel.DistributedDataParallel): - return mod.module.get_parameter(qualname).grad - else: - return mod.get_parameter(qualname).grad - - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(pp_ranks, args): - torch.manual_seed(42) - - d_hid = 50 - bs = 503 - CHUNKS = 5 - DEBUG_MASK_MINIBATCHES = True - check_numeric = True if args.cuda == 0 else False # TODO - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - def rand_zeros_or_ones(shape): - return torch.randint(0, 2, shape).float() - - class ZeroOneLinear(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.w = torch.nn.Parameter(rand_zeros_or_ones((in_dim, out_dim))) - - def forward(self, x): - return x @ self.w - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter( - rand_zeros_or_ones((d_hid, d_hid)) - ) - self.mm_param2 = torch.nn.Parameter( - rand_zeros_or_ones((d_hid, d_hid)) - ) - self.lin = ZeroOneLinear(d_hid, d_hid) - self.register_buffer( - "buffer", 0.00001 * rand_zeros_or_ones((bs + 100, d_hid)) - ) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - pipe_split() - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - return x - - ec = ExampleCode() - ec.to(args.device) - ec(torch.randn(bs, d_hid, device=args.device)) - ec.train() - - # TODO: works with sum, need to define semantics for e.g. 
mean - mse_loss = torch.nn.MSELoss(reduction="sum") - wrapper = TrivialLossWrapper(ec, mse_loss) - ec_pipe = Pipe.from_tracing(wrapper, MULTI_USE_PARAM_CONFIG) - if args.rank == 0: - print(ec_pipe.split_gm) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - CHUNKS, - args.pp_group_size, - all_ranks=pp_ranks, - _debug_mask_minibatches=DEBUG_MASK_MINIBATCHES, - checkpoint=bool(args.checkpoint), - ) - print(f"Rank {args.rank} Instantiated pipe with ranks {pp_ranks}") - - pipe_driver.init_data_parallel(dp_group_size=args.dp_group_size) - - torch.manual_seed(args.rank) - input = torch.randn(bs, d_hid, device=args.device) - target = torch.randn(bs, d_hid, device=args.device) - - # TODO: distributed optimizer - out = pipe_driver(input, target) - - print(f"Rank {args.rank} got loss value {out}") - - if not check_numeric: - print("DDP + PP API test passed") - return - - all_grad_qualnames = {k: None for k, v in ec_pipe.named_parameters()} - - pipe_grads = {} - - for name in all_grad_qualnames: - assert "split_gm." in name - _, module_name, param_qualname = name.split(".", maxsplit=2) - - assert module_name in pipe_driver.remote_stage_executor_rrefs - rank, module_rref = pipe_driver.remote_stage_executor_rrefs[module_name] - grad_value = rpc.rpc_sync( - module_rref.owner(), - get_grad_from_executor, - (module_rref, param_qualname), - ) - pipe_grads[name] = copy.deepcopy(grad_value) - - # User driver group as the DDP reference group - wrapper_ddp = torch.nn.parallel.DistributedDataParallel( - wrapper, process_group=args.driver_group - ) - - wrapper_out = wrapper_ddp(input, target) - wrapper_out.backward() - - not_close_grads = [] - ref_grads = {} - - for name in all_grad_qualnames: - remapped_qualname = ec_pipe.remap_qualname(name) - param = wrapper_ddp.module.get_parameter(remapped_qualname) - assert ( - name in pipe_grads - ), f"{name} not in pipe_grads keys {pipe_grads.keys()}" - ref_grads[name] = copy.deepcopy(param.grad) - if not torch.allclose(pipe_grads[name], ref_grads[name]): - not_close_grads.append(name) - - if len(not_close_grads): - raise AssertionError(f"Gradients not close: {not_close_grads}") - - print("Gradient equivalence test passed") - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - # in row-major - # DP ranks are contiguous rows of size `args.dp_group_size` - # PP ranks are non-contiguous columns of size `args.pp_group_size` - # - # if dp_group_size = 4 and pp_group_size = 3 - # - # 0 1 2 3 - # 4 5 6 7 - # 8 9 10 11 - # - # DP ranks are [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11] - # PP ranks are [0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11] - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # ExampleCode has two stages - args.pp_group_size = 2 - assert args.world_size % args.pp_group_size == 0 - - # Use world 
size to determine DDP size - args.dp_group_size = args.world_size // args.pp_group_size - print(f"Using data parallel group size: {args.dp_group_size}") - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestDDP(unittest.TestCase): - def test_ddp(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_forward_auto_parallel.py b/test/local_test_forward_auto_parallel.py deleted file mode 100644 index 19f0edd01..000000000 --- a/test/local_test_forward_auto_parallel.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import os -import unittest - -import pippy.fx - -import torch -import torch.autograd.profiler_legacy -from pippy import run_pippy -from pippy.auto_parallelization import AutoParallelConfig, dp_auto_parallel -from pippy.IR import MultiUseParameterConfig, Pipe -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - d_hid = 512 - bs = 503 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) - self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer("buffer", torch.randn(bs + 100, d_hid)) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - x = torch.relu(x) - return {"out": x} - - ec = ExampleCode() - ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - ec(ec_input) - - auto_parallel_ctx = AutoParallelConfig( - n_compute_nodes=args.world_size, n_devices_per_node=1, n_microbatches=5 - ) - ec_pipe = Pipe.from_tracing( - ec, - MULTI_USE_PARAM_CONFIG, - split_policy=dp_auto_parallel(auto_parallel_ctx), - ) - print(ec_pipe.split_gm) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - 5, - args.world_size, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - # # Warm up and correctness runs - out = pipe_driver(ec_input) - ref_out = ec_pipe(ec_input) - - # run with different chunk size to exercise microbatch and scheduling components - pipe_driver.chunks = 1 - pipe_driver(ec_input) - pipe_driver.chunks = 100 - pipe_driver(ec_input) - - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close(out["out"], ref_out["out"]) - print( - f'equivalence test passed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - - # # Profiling runs - with torch.autograd.profiler_legacy.profile( - enabled=PROFILING_ENABLED - ) as prof: - pipe_driver.chunks = 5 - out = pipe_driver(ec_input) - ref_out = 
ec_pipe(ec_input) - print( - f'profiling run completed {torch.sum(out["out"])} ref {torch.sum(ref_out["out"])}' - ) - if PROFILING_ENABLED: - prof.export_chrome_trace( - f"{os.path.splitext(os.path.basename(__file__))[0]}.json" - ) - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # Interleaved 1F1B uses less ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestForwardAutoParallelTest(unittest.TestCase): - def test_forward_auto_parallel(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_forward_backward.py b/test/local_test_forward_backward.py deleted file mode 100644 index 43c177a9b..000000000 --- a/test/local_test_forward_backward.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import copy -import os -import unittest - -import pippy.fx - -import torch -import torch.distributed.rpc as rpc -from pippy import run_pippy -from pippy.IR import ( - MultiUseParameterConfig, - Pipe, - pipe_split, - TrivialLossWrapper, -) -from pippy.microbatch import split_args_kwargs_into_chunks -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, - PipelineDriverInterleaved1F1B, -) - -# TODOs for implementing forward/backward/loss with schedules: -# * ability to switch between full-batch loss vs. per-microbatch loss. shen mentioned -# this might change numerics. 
So we should have the ability to compute loss over -# the whole minibatch rather than doing it for each micro-batch - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, - "Interleaved1F1B": PipelineDriverInterleaved1F1B, -} - - -# import ctypes -# libc = ctypes.cdll.LoadLibrary("libc.so.6") -# libc.prctl.argtypes = [ -# ctypes.c_int, -# ctypes.c_ulong, -# ctypes.c_ulong, -# ctypes.c_ulong, -# ctypes.c_ulong, -# ] -# libc.prctl.restype = ctypes.c_int -# libc.prctl(0x59616D61, -1, 0, 0, 0) - - -def get_grad_from_executor(executor, qualname): - return executor.local_value().mod.get_parameter(qualname).grad - - -def set_grad_in_executor(executor, qualname, value): - param = executor.local_value().mod.get_parameter(qualname) - param.grad = value - - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - torch.manual_seed(42) - - d_hid = 50 - bs = 503 - CHUNKS = 5 - DEBUG_MASK_MINIBATCHES = True - REF_USE_MICROBATCHES = True - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - def rand_zeros_or_ones(shape): - return torch.randint(0, 2, shape).float() - - class ZeroOneLinear(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.w = torch.nn.Parameter(rand_zeros_or_ones((in_dim, out_dim))) - - def forward(self, x): - return x @ self.w - - class ExampleCode(torch.nn.Module): - def __init__(self): - super().__init__() - self.mm_param = torch.nn.Parameter( - rand_zeros_or_ones((d_hid, d_hid)) - ) - self.mm_param2 = torch.nn.Parameter( - rand_zeros_or_ones((d_hid, d_hid)) - ) - self.lin = ZeroOneLinear(d_hid, d_hid) - self.register_buffer( - "buffer", 0.00001 * rand_zeros_or_ones((bs + 100, d_hid)) - ) - - def forward(self, x): - x = torch.mm(x, self.mm_param) - skip_connection = x - x = torch.relu(x) - size = x.size() # for https://github.com/pytorch/PiPPy/issues/256 - pipe_split() - x.reshape( - size[0], size[1] - ) # for https://github.com/pytorch/PiPPy/issues/256 - x = torch.mm(x, self.mm_param) + self.buffer[: x.shape[0]] - x = self.lin(x) - size = x.size() # for https://github.com/pytorch/PiPPy/issues/256 - pipe_split() - x.reshape( - size[0], size[1] - ) # for https://github.com/pytorch/PiPPy/issues/256 - x = torch.relu(x) - x = x + skip_connection - x = torch.mm(x, self.mm_param2) - x = self.lin(x) - size = x.size() # for https://github.com/pytorch/PiPPy/issues/256 - pipe_split() - x.reshape( - size[0], size[1] - ) # for https://github.com/pytorch/PiPPy/issues/256 - x = torch.relu(x) - return x - - ec = ExampleCode() - ec.to(args.device) - ec_input = torch.randn(bs, d_hid, device=args.device) - ec(ec_input) - ec.train() - - # TODO: works with sum, need to define semantics for e.g. 
mean - mse_loss = torch.nn.MSELoss(reduction="sum") - wrapper = TrivialLossWrapper(ec, mse_loss) - ec_pipe = Pipe.from_tracing(wrapper, MULTI_USE_PARAM_CONFIG) - print(ec_pipe.split_gm) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - ec_pipe, - CHUNKS, - args.world_size, - _debug_mask_minibatches=DEBUG_MASK_MINIBATCHES, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - target = torch.randn(bs, d_hid, device=args.device) - - # TODO: distributed optimizer - out = pipe_driver(ec_input, target) - - all_grad_qualnames = {k: None for k, v in ec_pipe.named_parameters()} - - pipe_grads = {} - - for name in all_grad_qualnames: - assert "split_gm." in name - _, module_name, param_qualname = name.split(".", maxsplit=2) - - assert module_name in pipe_driver.remote_stage_executor_rrefs - stage_id, module_rref = pipe_driver.remote_stage_executor_rrefs[ - module_name - ] - grad_value = rpc.rpc_sync( - module_rref.owner(), - get_grad_from_executor, - (module_rref, param_qualname), - ) - pipe_grads[name] = copy.deepcopy(grad_value) - - optim = torch.optim.SGD(ec_pipe.split_gm.parameters(), lr=0.05) - optim.zero_grad() - if REF_USE_MICROBATCHES: - args_split, kwargs_split = split_args_kwargs_into_chunks( - (ec_input, target), - {}, - CHUNKS, - args_chunk_spec=None, - kwargs_chunk_spec=None, - _debug_mask_minibatches=DEBUG_MASK_MINIBATCHES, - ) - ref_outs = [] - for chunk in range(CHUNKS): - ref_outs.append(ec_pipe(*args_split[chunk])) - ref_out = torch.sum(torch.stack(ref_outs)) - else: - ref_out = ec_pipe(ec_input, target) - - # Shared parameter sync for reference. TODO: move this to actual runtime - for param_set in ec_pipe.replicated_params: - grad_values = [] - for module_name, param_qualname in param_set.items(): - grad_values.append( - ec_pipe.get_parameter( - f"split_gm.{module_name}.{param_qualname}" - ).grad - ) - - synced_value = torch.sum(torch.stack(grad_values), dim=0) - - for module_name, param_qualname in param_set.items(): - ec_pipe.get_parameter( - f"split_gm.{module_name}.{param_qualname}" - ).grad = synced_value - - # TODO: scale output - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close(out, ref_out) - print( - f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}" - ) - - not_close_grads = [] - ref_grads = {} - for name in all_grad_qualnames: - param = ec_pipe.get_parameter(name) - assert ( - name in pipe_grads - ), f"{name} not in pipe_grads keys {pipe_grads.keys()}" - ref_grads[name] = param.grad - if not torch.allclose(pipe_grads[name], param.grad): - not_close_grads.append(name) - - for name in not_close_grads: - pipe_grad = pipe_grads[name] - ref_grad = ref_grads[name] - - relative_delta = torch.abs(pipe_grad - ref_grad) / ref_grad - assert False, ( - f"Gradient for parameter {name} is not numerically close! 
Relative diff mean " - f"{torch.mean(relative_delta)} std {torch.std(relative_delta)} max {torch.max(relative_delta)}" - ) - - print("Gradient equivalence test passed") - - # Test equivalence with initial code as well - orig_optim = torch.optim.SGD(ec.parameters(), lr=0.05) - orig_optim.zero_grad() - orig_loss = mse_loss(ec(ec_input), target) - orig_loss.backward() - torch.testing.assert_close(out, orig_loss) - - not_close_orig_grads = [] - not_found_mappings = [] - - for name in all_grad_qualnames: - try: - remapped_qualname = ec_pipe.remap_qualname(name) - except KeyError: - not_found_mappings.append(name) - else: - orig_grad = wrapper.get_parameter(remapped_qualname).grad - pipe_grad = pipe_grads[name] - if not torch.allclose(pipe_grad, orig_grad): - not_close_orig_grads.append(name) - print(name, torch.abs(pipe_grad - orig_grad) / orig_grad) - print( - name, - torch.max(torch.abs(pipe_grad - orig_grad) / orig_grad), - ) - - assert len(not_found_mappings) == 0, ( - f"No qualname mapping found between pipelined and original " - f"model: {not_found_mappings}" - ) - - assert len(not_close_orig_grads) == 0, ( - f"Grads not close between pipelined and original " - f"model: {not_close_orig_grads}" - ) - - print("correctness checks with original module passed") - - # # # Profiling runs - # with torch.autograd.profiler_legacy.profile(enabled=PROFILING_ENABLED) as prof: - # pipe_driver._debug_mask_minibatches = False - # pipe_driver.chunks = CHUNKS - # out = pipe_driver(ec_input, target) - # ref_out = ec_pipe.split_gm(ec_input, target) - # print(f'profiling run completed {torch.sum(ref_out)} ref {torch.sum(ref_out)}') - # if PROFILING_ENABLED: - # prof.export_chrome_trace(f'{os.path.splitext(os.path.basename(__file__))[0]}.json') - - -def main(args=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 4)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args(args) - - # Interleaved 1F1B uses less ranks than number of stages - if args.schedule == "Interleaved1F1B": - args.world_size = 2 - - run_pippy(run_master, args) - - -if __name__ == "__main__": - main() - - -class LocalTestForwardAutoParallelTest(unittest.TestCase): - def test_forward_backward(self): - import random - - port = random.randint(29500, 30000) - args = [ - "--master_port", - str(port), - ] - main(args) diff --git a/test/local_test_forward_hf_bert.py b/test/local_test_forward_hf_bert.py deleted file mode 100644 index 3b5d99563..000000000 --- a/test/local_test_forward_hf_bert.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates -import argparse -import inspect -import os - -import pippy.fx - -import torch -import torch.autograd.profiler_legacy -from pippy import run_pippy -from pippy.hf import PiPPyHFTracer -from pippy.IR import ( - annotate_split_points, - MultiUseParameterConfig, - Pipe, - PipeSplitWrapper, -) -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, -) -from transformers import BertConfig, BertModel - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - bs = 20 - seq_length = 32 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - bert = BertModel(BertConfig()) - bert.to(args.device) - bert.eval() - bert_input = torch.zeros( - bs, seq_length, dtype=torch.long, device=args.device - ).random_(bert.config.vocab_size) - bert(bert_input) - - for i in range(bert.config.num_hidden_layers): - annotate_split_points( - bert, {f"encoder.layer.{i}": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - annotate_split_points( - bert, {"pooler": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - - input_names = bert.dummy_inputs.keys() - sig = inspect.signature(bert.forward) - concrete_args = { - p.name: p.default - for p in sig.parameters.values() - if p.name not in input_names - } - - print("Instantiating BERT Pipeline") - bert_pipe = Pipe.from_tracing( - bert, - MULTI_USE_PARAM_CONFIG, - tracer=PiPPyHFTracer(), - concrete_args=concrete_args, - ) - - assert bert.config.num_hidden_layers + 2 == len( - list(bert_pipe.split_gm.children()) - ) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - bert_pipe, - 5, - args.world_size, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - # # Warm up and correctness runs - out = pipe_driver(bert_input) - ref_out = bert_pipe(bert_input) - - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close( - out["last_hidden_state"], ref_out["last_hidden_state"] - ) - torch.testing.assert_close( - out["pooler_output"], ref_out["pooler_output"] - ) - print( - f'equivalence test passed {torch.sum(out["last_hidden_state"])} ref {torch.sum(ref_out["last_hidden_state"])}' - ) - - # # Profiling runs - with torch.autograd.profiler_legacy.profile( - enabled=PROFILING_ENABLED - ) as prof: - pipe_driver._debug_mask_minibatches = False - pipe_driver.chunks = 5 - out = pipe_driver(bert_input) - ref_out = bert_pipe(bert_input) - print( - f'profiling run completed {torch.sum(out["last_hidden_state"])} ref {torch.sum(ref_out["last_hidden_state"])}' - ) - if PROFILING_ENABLED: - prof.export_chrome_trace( - f"{os.path.splitext(os.path.basename(__file__))[0]}.json" - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 14)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - 
default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args() - - run_pippy(run_master, args) diff --git a/test/local_test_forward_hf_gpt2.py b/test/local_test_forward_hf_gpt2.py deleted file mode 100644 index a8beb677e..000000000 --- a/test/local_test_forward_hf_gpt2.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -import argparse -import inspect -import os - -import pippy.fx - -import torch -import torch.autograd.profiler_legacy -from pippy import run_pippy -from pippy.hf import PiPPyHFTracer -from pippy.IR import ( - annotate_split_points, - MultiUseParameterConfig, - Pipe, - PipeSplitWrapper, -) -from pippy.PipelineDriver import ( - PipelineDriver1F1B, - PipelineDriverBase, - PipelineDriverFillDrain, -) -from transformers import GPT2Config, GPT2Model - -PROFILING_ENABLED = True -CHECK_NUMERIC_EQUIVALENCE = True - -schedules = { - "FillDrain": PipelineDriverFillDrain, - "1F1B": PipelineDriver1F1B, -} - -pippy.fx.Tracer.proxy_buffer_attributes = True - - -def run_master(_, args): - bs = 20 - seq_length = 32 - - MULTI_USE_PARAM_CONFIG = ( - MultiUseParameterConfig.REPLICATE - if args.replicate - else MultiUseParameterConfig.TRANSMIT - ) - print(f"REPLICATE config: {args.replicate} -> {MULTI_USE_PARAM_CONFIG}") - - print("Using schedule:", args.schedule) - - gpt2 = GPT2Model(GPT2Config(use_cache=False)) - gpt2.to(args.device) - gpt2.eval() - gpt2_input = torch.zeros( - bs, seq_length, dtype=torch.long, device=args.device - ).random_(gpt2.config.vocab_size) - - for i in range(gpt2.config.n_layer): - annotate_split_points( - gpt2, {f"h.{i}": PipeSplitWrapper.SplitPoint.BEGINNING} - ) - annotate_split_points(gpt2, {"ln_f": PipeSplitWrapper.SplitPoint.BEGINNING}) - - input_names = gpt2.dummy_inputs.keys() - sig = inspect.signature(gpt2.forward) - concrete_args = { - p.name: p.default - for p in sig.parameters.values() - if p.name not in input_names - } - - print("Instantiating GPT2 Pipeline") - gpt2_pipe = Pipe.from_tracing( - gpt2, - MULTI_USE_PARAM_CONFIG, - tracer=PiPPyHFTracer(), - concrete_args=concrete_args, - ) - - assert gpt2.config.n_layer + 2 == len(list(gpt2_pipe.split_gm.children())) - - pipe_driver: PipelineDriverBase = schedules[args.schedule]( - gpt2_pipe, - 5, - args.world_size, - _debug_mask_minibatches=True, - _record_mem_dumps=bool(args.record_mem_dumps), - checkpoint=bool(args.checkpoint), - ) - - # # Warm up and correctness runs - print( - "Running GPT2 pipeline. 
NB: if this is too slow, set OMP_NUM_THREADS to a higher value" - ) - out = pipe_driver(gpt2_input) - print("Running reference pipeline") - ref_out = gpt2_pipe(gpt2_input) - - if CHECK_NUMERIC_EQUIVALENCE: - torch.testing.assert_close( - out["last_hidden_state"], ref_out["last_hidden_state"] - ) - print( - f'equivalence test passed {torch.sum(out["last_hidden_state"])} ref {torch.sum(ref_out["last_hidden_state"])}' - ) - - # # Profiling runs - with torch.autograd.profiler_legacy.profile( - enabled=PROFILING_ENABLED - ) as prof: - pipe_driver._debug_mask_minibatches = False - pipe_driver.chunks = 5 - out = pipe_driver(gpt2_input) - ref_out = gpt2_pipe(gpt2_input) - print( - f'profiling run completed {torch.sum(out["last_hidden_state"])} ref {torch.sum(ref_out["last_hidden_state"])}' - ) - if PROFILING_ENABLED: - prof.export_chrome_trace( - f"{os.path.splitext(os.path.basename(__file__))[0]}.json" - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 14)) - ) - parser.add_argument("--rank", type=int, default=int(os.getenv("RANK", -1))) - parser.add_argument( - "--master_addr", type=str, default=os.getenv("MASTER_ADDR", "localhost") - ) - parser.add_argument( - "--master_port", type=str, default=os.getenv("MASTER_PORT", "29500") - ) - parser.add_argument( - "-s", - "--schedule", - type=str, - default=list(schedules.keys())[0], - choices=schedules.keys(), - ) - parser.add_argument( - "--replicate", type=int, default=int(os.getenv("REPLICATE", "0")) - ) - parser.add_argument( - "--cuda", type=int, default=int(torch.cuda.is_available()) - ) - parser.add_argument( - "--record_mem_dumps", type=int, default=0, choices=[0, 1] - ) - parser.add_argument("--checkpoint", type=int, default=0, choices=[0, 1]) - args = parser.parse_args() - - run_pippy(run_master, args)
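
Note on the rank layout referenced above: the deleted test/local_test_ddp.py describes, in its argument-parsing comments, a row-major rank grid in which data-parallel (DP) ranks form contiguous rows of size dp_group_size and pipeline-parallel (PP) ranks form non-contiguous columns of size pp_group_size, with dp_group_size = world_size // pp_group_size. The sketch below is illustrative only — it is not part of the deleted files or of the PiPPy API, and the helper name rank_groups and its signature are assumptions made for this example.

    def rank_groups(world_size: int, pp_group_size: int):
        """Reproduce the DP/PP grouping described in local_test_ddp.py's comments.

        Illustrative sketch; `rank_groups` is a hypothetical helper, not PiPPy API.
        """
        assert world_size % pp_group_size == 0
        # Same formula as the deleted test: DP group size is the remaining factor.
        dp_group_size = world_size // pp_group_size
        # Row-major grid with pp_group_size rows and dp_group_size columns:
        # DP groups are the contiguous rows, PP groups are the strided columns.
        dp_groups = [
            list(range(row * dp_group_size, (row + 1) * dp_group_size))
            for row in range(pp_group_size)
        ]
        pp_groups = [
            list(range(col, world_size, dp_group_size))
            for col in range(dp_group_size)
        ]
        return dp_groups, pp_groups

    # With dp_group_size = 4 and pp_group_size = 3 (world_size = 12), this yields
    # the grouping given in the deleted test's comment:
    #   DP groups: [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]
    #   PP groups: [0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]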