pass for sarl
Ming Zhou committed Dec 14, 2023
1 parent c9840a8 commit 030407c
Showing 28 changed files with 322 additions and 134 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -134,4 +134,5 @@ dmypy.json
_build
logs
demos
prof/
prof/
runs
25 changes: 16 additions & 9 deletions examples/sarl/ppo_gym.py
@@ -23,22 +23,29 @@ class FeatureHandler(BaseFeature):


def feature_handler_meta_gen(env_desc, agent_id):
def f(device):
"""Return a generator of feature handler meta.
Args:
env_desc (_type_): _description_
agent_id (_type_): _description_
"""

def f(device="cpu"):
# define the data schema
_spaces = {
Episode.DONE: spaces.Discrete(1),
Episode.CUR_OBS: env_desc["observation_spaces"][agent_id],
Episode.ACTION: env_desc["action_spaces"][agent_id],
Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32),
Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32),
Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id],
}

# the maximum size of the replay buffer must be known before training
np_memory = {
k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items()
k: np.zeros((10000,) + v.shape, dtype=v.dtype) for k, v in _spaces.items()
}

return FeatureHandler(_spaces, np_memory, device)
return FeatureHandler(_spaces, np_memory, device=device)

return f

@@ -51,7 +58,7 @@ def f(device):

args = parser.parse_args()

trainer_config = DEFAULT_CONFIG["training_config"].copy()
trainer_config = DEFAULT_CONFIG.TRAINING_CONFIG.copy()
trainer_config["total_timesteps"] = int(1e6)
trainer_config["use_cuda"] = args.use_cuda

@@ -80,13 +87,13 @@ def f(device):
rollout_config=RolloutConfig(
num_workers=1,
),
agent_mapping_func=lambda agent: agent,
stopping_conditions={
"training": {"max_iteration": int(1e10)},
"rollout": {"max_iteration": 1000, "minimum_reward_improvement": 1.0},
"golbal": {"max_iteration": 1000, "minimum_reward_improvement": 1.0},
"rollout": {"max_iteration": 1},
"training": {"max_iteration": 1},
},
)

results = sarl_scenario.execution_plan(scenario=scenario, verbose=True)
results = sarl_scenario.execution_plan(scenario=scenario, verbose=False)

print(results)
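
For orientation, a minimal sketch of how the feature handler generator defined above might be exercised outside the scenario runner. The env_desc dict, agent id, and spaces below are illustrative stand-ins, not values taken from the example:

import numpy as np
from gym import spaces  # assuming the same gym spaces module used in the example

# hypothetical single-agent environment description
env_desc = {
    "observation_spaces": {"agent": spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)},
    "action_spaces": {"agent": spaces.Discrete(2)},
}

meta_gen = feature_handler_meta_gen(env_desc, "agent")  # returns f(device="cpu")
handler = meta_gen("cpu")  # builds a FeatureHandler backed by 10000-slot buffers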
2 changes: 2 additions & 0 deletions malib/backend/dataset_server/data_loader.py
@@ -39,6 +39,8 @@ def __init__(
self.max_message_length = max_message_length

def start_server(self):
"""Launch a dataset service."""

self.server_port = find_free_port()
self.server = service_wrapper(
self.grpc_thread_num_workers,
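
start_server binds the dataset service to a dynamically chosen port via find_free_port. A helper of that kind usually just asks the OS for an ephemeral port; a minimal sketch of the common idiom (an illustration, not necessarily MALib's actual implementation):

import socket

def find_free_port() -> int:
    # Bind to port 0 so the OS assigns an unused ephemeral port, then release it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]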
16 changes: 14 additions & 2 deletions malib/backend/dataset_server/feature.py
@@ -18,14 +18,26 @@ def __init__(
block_size: int = None,
device: str = "cpu",
) -> None:
"""Constructing a feature handler for data preprocessing.
Args:
spaces (Dict[str, spaces.Space]): A dict of spaces
np_memory (Dict[str, np.ndarray]): A dict of memory placeholders
block_size (int, optional): Block size. Defaults to None.
device (str, optional): Device name. Defaults to "cpu".
"""

self.rw_lock = rwlock.RWLockFair()
self._device = device
self._spaces = spaces
self._block_size = min(block_size or np.iinfo(np.longlong).max, list(np_memory.values())[0].shape[0])
self._block_size = min(
block_size or np.iinfo(np.longlong).max,
list(np_memory.values())[0].shape[0],
)
self._available_size = 0
self._flag = 0
self._shared_memory = {
k: torch.from_numpy(v[:self._block_size]).to(device).share_memory_()
k: torch.from_numpy(v[: self._block_size]).to(device).share_memory_()
for k, v in np_memory.items()
}

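
The clamp above means a block_size of None falls back to the first buffer's full length, while an explicit block_size can only shrink it. A small worked illustration with made-up shapes:

import numpy as np

# illustrative buffers of 10000 rows each
np_memory = {
    "obs": np.zeros((10000, 4), dtype=np.float32),
    "rew": np.zeros((10000,), dtype=np.float32),
}

block_size = None
clamped = min(
    block_size or np.iinfo(np.longlong).max,
    list(np_memory.values())[0].shape[0],
)
assert clamped == 10000  # with block_size=2048, clamped would be 2048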
17 changes: 11 additions & 6 deletions malib/common/task.py
@@ -18,9 +18,9 @@ class Task:

@dataclass
class RolloutTask(Task):
strategy_specs: Dict[str, Any] = field(default_factory=dict())
stopping_conditions: Dict[str, Any] = field(default_factory=dict())
data_entrypoints: Dict[str, Any] = field(default_factory=dict())
strategy_specs: Dict[str, Any] = field(default_factory=dict)
stopping_conditions: Dict[str, Any] = field(default_factory=dict)
data_entrypoints: Dict[str, Any] = field(default_factory=dict)

@classmethod
def from_raw(
@@ -36,15 +36,20 @@ def from_raw(

@dataclass
class OptimizationTask(Task):
stop_conditions: Dict[str, Any]
stopping_conditions: Dict[str, Any]
"""stopping conditions for optimization task, e.g., max iteration, max time, etc."""

strategy_specs: Dict[str, Any] = field(default_factory=dict())
"""a dict of strategy specs, which defines the strategy spec for each agent."""
# strategy_specs: Dict[str, Any] = field(default_factory=dict)
# """a dict of strategy specs, which defines the strategy spec for each agent."""

active_agents: List[AgentID] = field(default_factory=list)
"""a list of active agents, which defines the agents that will be trained in this optimization task. None for all"""

save_interval: int = 2
"""the interval of saving checkpoints"""

model_dir: str = ""

@classmethod
def from_raw(
cls, dict_style: Union[Dict[str, Any], "OptimizationTask"], **kwargs
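
The dataclass fix above (field(default_factory=dict()) to field(default_factory=dict)) matters because default_factory must be a zero-argument callable; passing an already-built dict makes instantiation fail. A standalone illustration:

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class Broken:
    # dict() is an instance, not a callable factory; Broken() raises
    # TypeError: 'dict' object is not callable
    specs: Dict[str, Any] = field(default_factory=dict())

@dataclass
class Fixed:
    # dict is the callable; each instance gets its own fresh empty dict
    specs: Dict[str, Any] = field(default_factory=dict)

print(Fixed())  # Fixed(specs={})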
4 changes: 3 additions & 1 deletion malib/learner/indepdent_learner.py
@@ -32,5 +32,7 @@


class IndependentAgent(Learner):
def multiagent_post_process(self, batch: Dict[AgentID, Dict[str, torch.Tensor]]) -> Dict[str, Any]:
def multiagent_post_process(
self, batch: Dict[AgentID, Dict[str, torch.Tensor]]
) -> Dict[str, Any]:
return to_torch(batch, device=self.device)
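
multiagent_post_process here simply moves the sampled multi-agent batch onto the learner's device via to_torch. Conceptually the conversion amounts to something like the following sketch (an illustration, not the actual to_torch from malib.utils):

import numpy as np
import torch

def to_torch_sketch(batch, device="cpu"):
    # Recursively convert numpy arrays in a (possibly nested) dict into torch tensors on `device`.
    if isinstance(batch, dict):
        return {k: to_torch_sketch(v, device) for k, v in batch.items()}
    if isinstance(batch, np.ndarray):
        return torch.from_numpy(batch).to(device)
    return batch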
50 changes: 44 additions & 6 deletions malib/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@

import time
import traceback
import os

import json
import torch
import ray

@@ -40,6 +42,7 @@
from malib.utils.logging import Logger
from malib.utils.tianshou_batch import Batch
from malib.utils.monitor import write_to_tensorboard
from malib.utils.stopping_conditions import get_stopper
from malib.remote.interface import RemoteInterface
from malib.common.task import OptimizationTask
from malib.common.strategy_spec import StrategySpec
@@ -107,9 +110,18 @@ def __init__(
self._governed_agents = governed_agents
self._strategy_spec = strategy_spec
self._custom_config = custom_config
# Do not add the policy to the strategy spec yet, since the spec is only
# updated when a new checkpoint is ready.
self._policy = strategy_spec.gen_policy(device=device)

self._summary_writer = tensorboard.SummaryWriter(log_dir=log_dir)
self._model_dir = os.path.join(log_dir, "models")

if not os.path.exists(self._model_dir):
os.makedirs(self._model_dir)

# save metastate to current log_dir
self.save_metastate(log_dir)

# load policy for trainer
self._trainer: Trainer = algorithm.trainer(
@@ -140,6 +152,16 @@ def __init__(
self._total_epoch = 0
self._verbose = verbose

def save_metastate(self, log_dir):
with open("{}/metastate.json".format(log_dir), "w") as f:
json.dump(
{
"runtime_id": self._runtime_id,
"governed_agents": self.governed_agents,
},
f,
)

@abstractmethod
def multiagent_post_process(
self,
@@ -215,13 +237,12 @@ def get_interface_state(self) -> Dict[str, Any]:
"total_epoch": self._total_epoch,
"policy_num": len(self._strategy_spec),
}

def step(self, prints: bool = False):
while (
self.data_loader.dataset.readable_block_size
< self.data_loader.batch_size
self.data_loader.dataset.readable_block_size < self.data_loader.batch_size
):
time.sleep(1)
return

for data in self.data_loader:
batch_dict = self.multiagent_post_process(data)
@@ -239,10 +260,13 @@ def step(self, prints: bool = False):
prefix=f"Learner/{self._runtime_id}",
)
if prints:
print(self._total_step, step_info)
print(self._total_epoch, self._total_step, step_info)

self._total_epoch += 1

# TODO(ming): step info should be merged before returning
return step_info_list

def train(self, task: OptimizationTask) -> Dict[str, Any]:
"""Executes a optimization task and returns the final interface state.
@@ -255,10 +279,22 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]:
"""

self.set_running(True)
stopper = get_stopper(task.stopping_conditions)

try:
while self.is_running():
self.step()
results = self.step()
if results is None: # indicates the dataset is not ready
break
if self._total_epoch % task.save_interval == 0:
ck_path = os.path.join(
self._model_dir, f"checkpoint-{self._total_epoch}.ckpt"
)
torch.save(self.policy.state_dict(), ck_path)
Logger.info("save checkpoint to {}".format(ck_path))
self.strategy_spec.register_policy_id(ck_path)
if stopper.should_stop(results):
break
except Exception as e:
Logger.warning(
f"training pipe is terminated. caused by: {traceback.format_exc()}"
@@ -270,6 +306,8 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]:
self._total_epoch, self._total_step
)
)
# force the running flag to False to stop training
self.set_running(False)
return self.get_interface_state()

def reset(self):
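
The reworked train loop above checkpoints the policy every save_interval epochs with torch.save and registers each checkpoint path on the strategy spec. Restoring follows the usual state_dict pattern; a sketch, assuming the policy produced by strategy_spec.gen_policy supports the torch state_dict protocol and using an illustrative path:

import torch

ck_path = "logs/models/checkpoint-10.ckpt"  # illustrative path, as written by the loop above

# rebuild a policy from the same spec, then load the saved weights
policy = strategy_spec.gen_policy(device="cpu")
policy.load_state_dict(torch.load(ck_path, map_location="cpu"))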
33 changes: 23 additions & 10 deletions malib/learner/manager.py
@@ -117,7 +117,6 @@ def __init__(
observation_space=group_info["observation_space"][rid],
action_space=group_info["action_space"][rid],
algorithm=algorithm,
agent_mapping_func=agent_mapping_func,
governed_agents=agents,
custom_config=learner_config.custom_config,
feature_handler_gen=learner_config.feature_handler_meta_gen(
@@ -152,6 +151,7 @@ def __init__(
self._agent_mapping_func = agent_mapping_func
self._learners = learners
self._thread_pool = ThreadPoolExecutor(max_workers=len(learners))
# FIXME(ming): deprecated
self._stopping_conditions = stopping_conditions

# init strategy spec
@@ -211,6 +211,12 @@ def runtime_ids(self) -> Tuple[str]:

return self._runtime_ids

def get_strategy_specs(self) -> Dict[str, StrategySpec]:
values = ray.get(
[v.get_strategy_spec.remote() for v in self._learners.values()]
)
return dict(zip(self._learners.keys(), values))

def add_policies(
self, interface_ids: Sequence[str] = None, n: Union[int, Dict[str, int]] = 1
) -> Dict[str, Type[StrategySpec]]:
@@ -240,22 +246,29 @@ def add_policies(

return strategy_spec_dict

def submit(self, task: OptimizationTask):
def submit(self, task: OptimizationTask, wait: bool = False):
"""Submit a training task, the manager will distribute it to the corresponding learners.
Args:
task (OptimizationTask): A task description.
"""

# retrieve learners with active agents
for aid in task.active_agents:
rid = self._agent_mapping_func(aid)
if rid not in self._learners:
raise RuntimeError(f"Agent {aid} is not registered in training manager")
else:
learner = self._learners[rid]
ray_task = learner.train.remote(task)
self.pending_tasks.append(ray_task)
rids = (
list(self._learners.keys())
if task.active_agents is None
else [self._agent_mapping_func(aid) for aid in task.active_agents]
)

for rid in rids:
learner = self._learners[rid]
ray_task = learner.train.remote(task)
self.pending_tasks.append(ray_task)
if wait:
result_list = self.wait()
return result_list
else:
return None

def retrive_results(self) -> Generator:
"""Return a generator of results.
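
With the new wait flag, submit either queues the Ray training tasks and returns immediately, or blocks until every learner finishes and hands back the collected results. A usage sketch (field values are illustrative, and any fields required by the Task base class are omitted):

task = OptimizationTask(
    stopping_conditions={"max_iteration": 100},
    active_agents=None,      # None -> train all registered learners
    save_interval=2,
    model_dir="logs/models",
)

manager.submit(task)                        # non-blocking: tasks land in manager.pending_tasks
results = manager.submit(task, wait=True)   # blocking: distribute, wait, and collect results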