diff --git a/docs/custom-load-shape.rst b/docs/custom-load-shape.rst index 696ecca56c..340bce8378 100644 --- a/docs/custom-load-shape.rst +++ b/docs/custom-load-shape.rst @@ -40,3 +40,34 @@ This functionality is further demonstrated in the `examples on github `_. + + +Extend your shape with custom users +----------------------------------- + +Extending the return value of ``tick()`` with a third element, ``user_classes``, makes it possible to pick which users are created for a specific ``tick()``. + +.. code-block:: python + + class StagesShapeWithCustomUsers(LoadTestShape): + + stages = [ + {"duration": 10, "users": 10, "spawn_rate": 10, "user_classes": [UserA]}, + {"duration": 30, "users": 50, "spawn_rate": 10, "user_classes": [UserA, UserB]}, + {"duration": 60, "users": 100, "spawn_rate": 10, "user_classes": [UserB]}, + {"duration": 120, "users": 100, "spawn_rate": 10, "user_classes": [UserA, UserB]}, + ] + def tick(self): + run_time = self.get_run_time() + + for stage in self.stages: + if run_time < stage["duration"]: + try: + tick_data = (stage["users"], stage["spawn_rate"], stage["user_classes"]) + except KeyError: + tick_data = (stage["users"], stage["spawn_rate"]) + return tick_data + + return None + +This shape would create 10 users of ``UserA`` in the first 10 seconds. In the next twenty seconds, 40 more of type ``UserA`` / ``UserB``, and this continues until the stages end. 
\ No newline at end of file diff --git a/examples/custom_shape/staging_user_classes.py b/examples/custom_shape/staging_user_classes.py new file mode 100644 index 0000000000..12b381142b --- /dev/null +++ b/examples/custom_shape/staging_user_classes.py @@ -0,0 +1,58 @@ +from locust import HttpUser, TaskSet, task, constant +from locust import LoadTestShape + + +class UserTasks(TaskSet): + @task + def get_root(self): + self.client.get("/") + + +class WebsiteUserA(HttpUser): + wait_time = constant(0.5) + tasks = [UserTasks] + + +class WebsiteUserB(HttpUser): + wait_time = constant(0.5) + tasks = [UserTasks] + + +class StagesShapeWithCustomUsers(LoadTestShape): + """ + A simple load test shape class that has different user counts, spawn rates and + user classes at different stages. + + Keyword arguments: + + stages -- A list of dicts, each representing a stage with the following keys: + duration -- When this many seconds pass the test is advanced to the next stage + users -- Total user count + spawn_rate -- Number of users to start/stop per second + user_classes -- Optional list of user classes to spawn during the stage + + Stages without user_classes keep the previously dispatched user classes. 
+ """ + + stages = [ + {"duration": 60, "users": 10, "spawn_rate": 10, "user_classes": [WebsiteUserA]}, + {"duration": 100, "users": 50, "spawn_rate": 10, "user_classes": [WebsiteUserB]}, + {"duration": 180, "users": 100, "spawn_rate": 10, "user_classes": [WebsiteUserA]}, + {"duration": 220, "users": 30, "spawn_rate": 10}, + {"duration": 230, "users": 10, "spawn_rate": 10}, + {"duration": 240, "users": 1, "spawn_rate": 1}, + ] + + def tick(self): + run_time = self.get_run_time() + + for stage in self.stages: + if run_time < stage["duration"]: + # Not the smartest solution, TODO: find something better + try: + tick_data = (stage["users"], stage["spawn_rate"], stage["user_classes"]) + except KeyError: + tick_data = (stage["users"], stage["spawn_rate"]) + return tick_data + + return None diff --git a/locust/dispatch.py b/locust/dispatch.py index f4abb7b566..51594f1bec 100644 --- a/locust/dispatch.py +++ b/locust/dispatch.py @@ -56,6 +56,7 @@ def __init__(self, worker_nodes: List["WorkerNode"], user_classes: List[Type[Use """ self._worker_nodes = worker_nodes self._sort_workers() + self._original_user_classes = sorted(user_classes, key=attrgetter("__name__")) self._user_classes = sorted(user_classes, key=attrgetter("__name__")) assert len(user_classes) > 0 @@ -163,13 +164,18 @@ def _dispatcher(self) -> Generator[Dict[str, Dict[str, int]], None, None]: self._dispatch_in_progress = False - def new_dispatch(self, target_user_count: int, spawn_rate: float) -> None: + def new_dispatch(self, target_user_count: int, spawn_rate: float, user_classes: Optional[List] = None) -> None: """ Initialize a new dispatch cycle. 
:param target_user_count: The desired user count at the end of the dispatch cycle :param spawn_rate: The spawn rate + :param user_classes: The user classes to be used for the new dispatch """ + if user_classes is not None and self._user_classes != sorted(user_classes, key=attrgetter("__name__")): + self._user_classes = sorted(user_classes, key=attrgetter("__name__")) + self._user_generator = self._user_gen() + self._target_user_count = target_user_count self._spawn_rate = spawn_rate @@ -224,7 +230,7 @@ def _prepare_rebalance(self) -> None: # Reset users before recalculating since the current users is used to calculate how many # fixed users to add. self._users_on_workers = { - worker_node.id: {user_class.__name__: 0 for user_class in self._user_classes} + worker_node.id: {user_class.__name__: 0 for user_class in self._original_user_classes} for worker_node in self._worker_nodes } self._try_dispatch_fixed = True @@ -325,7 +331,7 @@ def _distribute_users( worker_gen = itertools.cycle(self._worker_nodes) users_on_workers = { - worker_node.id: {user_class.__name__: 0 for user_class in self._user_classes} + worker_node.id: {user_class.__name__: 0 for user_class in self._original_user_classes} for worker_node in self._worker_nodes } diff --git a/locust/runners.py b/locust/runners.py index 83560c8c74..a0f1cfce34 100644 --- a/locust/runners.py +++ b/locust/runners.py @@ -28,6 +28,7 @@ Type, Any, cast, + Union, ) from uuid import uuid4 @@ -64,7 +65,6 @@ logger = logging.getLogger(__name__) - STATE_INIT, STATE_SPAWNING, STATE_RUNNING, STATE_CLEANUP, STATE_STOPPING, STATE_STOPPED, STATE_MISSING = [ "ready", "spawning", @@ -84,7 +84,6 @@ CONNECT_TIMEOUT = 5 CONNECT_RETRY_COUNT = 60 - greenlet_exception_handler = greenlet_exception_logger(logger) @@ -119,7 +118,7 @@ def __init__(self, environment: "Environment") -> None: self.state = STATE_INIT self.spawning_greenlet: Optional[gevent.Greenlet] = None self.shape_greenlet: Optional[gevent.Greenlet] = None - self.shape_last_state: 
Optional[Tuple[int, float]] = None + self.shape_last_tick: Union[Tuple[int, float], Tuple[int, float, Optional[List[Type[User]]]], None] = None self.current_cpu_usage: int = 0 self.cpu_warning_emitted: bool = False self.worker_cpu_warning_emitted: bool = False @@ -330,7 +329,9 @@ def monitor_cpu_and_memory(self) -> NoReturn: gevent.sleep(CPU_MONITOR_INTERVAL) @abstractmethod - def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: + def start( + self, user_count: int, spawn_rate: float, wait: bool = False, user_classes: Optional[List[Type[User]]] = None + ) -> None: ... def start_shape(self) -> None: @@ -351,20 +352,26 @@ def start_shape(self) -> None: def shape_worker(self) -> None: logger.info("Shape worker starting") while self.state == STATE_INIT or self.state == STATE_SPAWNING or self.state == STATE_RUNNING: - new_state = self.environment.shape_class.tick() if self.environment.shape_class is not None else None - if new_state is None: + current_tick: Union[Tuple[int, float], Tuple[int, float, Optional[List[Type[User]]]], None] = ( + self.environment.shape_class.tick() if self.environment.shape_class is not None else None + ) + if current_tick is None: logger.info("Shape test stopping") if self.environment.parsed_options and self.environment.parsed_options.headless: self.quit() else: self.stop() self.shape_greenlet = None - self.shape_last_state = None + self.shape_last_tick = None return - elif self.shape_last_state == new_state: + elif self.shape_last_tick == current_tick: gevent.sleep(1) else: - user_count, spawn_rate = new_state + if len(current_tick) == 2: + user_count, spawn_rate = current_tick # type: ignore + user_classes = None + else: + user_count, spawn_rate, user_classes = current_tick # type: ignore logger.info("Shape test updating to %d users at %.2f spawn rate" % (user_count, spawn_rate)) # TODO: This `self.start()` call is blocking until the ramp-up is completed. 
This can leads # to unexpected behaviours such as the one in the following example: @@ -379,8 +386,8 @@ def shape_worker(self) -> None: # We should probably use a `gevent.timeout` with a duration a little over # `(user_count - prev_user_count) / spawn_rate` in order to limit the runtime # of each load test shape stage. - self.start(user_count=user_count, spawn_rate=spawn_rate) - self.shape_last_state = new_state + self.start(user_count=user_count, spawn_rate=spawn_rate, user_classes=user_classes) + self.shape_last_tick = current_tick def stop(self) -> None: """ @@ -403,7 +410,7 @@ def stop(self) -> None: if self.shape_greenlet is not None: self.shape_greenlet.kill(block=True) self.shape_greenlet = None - self.shape_last_state = None + self.shape_last_tick = None self.stop_users(self.user_classes_count) @@ -463,7 +470,7 @@ def on_user_error(user_instance, exception, tb): self.environment.events.user_error.add_listener(on_user_error) - def _start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: + def _start(self, user_count: int, spawn_rate: float, wait: bool = False, user_classes: Optional[List[Type[User]]] = None) -> None: """ Start running a load test @@ -472,6 +479,8 @@ def _start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None :param wait: If True calls to this method will block until all users are spawned. If False (the default), a greenlet that spawns the users will be started and the call to this method will return immediately. + :param user_classes: The user classes to be dispatched, None indicates to use the classes the dispatcher was + invoked with. 
""" self.target_user_count = user_count @@ -500,7 +509,7 @@ def _start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None logger.info("Ramping to %d users at a rate of %.2f per second" % (user_count, spawn_rate)) - cast(UsersDispatcher, self._users_dispatcher).new_dispatch(user_count, spawn_rate) + cast(UsersDispatcher, self._users_dispatcher).new_dispatch(user_count, spawn_rate, user_classes) try: for dispatched_users in self._users_dispatcher: @@ -542,7 +551,9 @@ def _start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None self.environment.events.spawning_complete.fire(user_count=sum(self.target_user_classes_count.values())) - def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: + def start( + self, user_count: int, spawn_rate: float, wait: bool = False, user_classes: Optional[List[Type[User]]] = None + ) -> None: if spawn_rate > 100: logger.warning( "Your selected spawn rate is very high (>100), and this is known to sometimes cause issues. Do you really need to ramp up that fast?" 
@@ -551,7 +562,9 @@ def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: if self.spawning_greenlet: # kill existing spawning_greenlet before we start a new one self.spawning_greenlet.kill(block=True) - self.spawning_greenlet = self.greenlet.spawn(lambda: self._start(user_count, spawn_rate, wait=wait)) + self.spawning_greenlet = self.greenlet.spawn( + lambda: self._start(user_count, spawn_rate, wait=wait, user_classes=user_classes) + ) self.spawning_greenlet.link_exception(greenlet_exception_handler) def stop(self) -> None: @@ -729,7 +742,9 @@ def cpu_log_warning(self) -> bool: warning_emitted = True return warning_emitted - def start(self, user_count: int, spawn_rate: float, wait=False) -> None: + def start( + self, user_count: int, spawn_rate: float, wait=False, user_classes: Optional[List[Type[User]]] = None + ) -> None: self.spawning_completed = False self.target_user_count = user_count @@ -771,7 +786,9 @@ def start(self, user_count: int, spawn_rate: float, wait=False) -> None: self.update_state(STATE_SPAWNING) - self._users_dispatcher.new_dispatch(target_user_count=user_count, spawn_rate=spawn_rate) + self._users_dispatcher.new_dispatch( + target_user_count=user_count, spawn_rate=spawn_rate, user_classes=user_classes + ) try: for dispatched_users in self._users_dispatcher: @@ -872,7 +889,7 @@ def stop(self, send_stop_to_client: bool = True) -> None: ): self.shape_greenlet.kill(block=True) self.shape_greenlet = None - self.shape_last_state = None + self.shape_last_tick = None self._users_dispatcher = None @@ -1204,7 +1221,9 @@ def on_user_error(user_instance: User, exception: Exception, tb: TracebackType) self.environment.events.user_error.add_listener(on_user_error) - def start(self, user_count: int, spawn_rate: float, wait: bool = False) -> None: + def start( + self, user_count: int, spawn_rate: float, wait: bool = False, user_classes: Optional[List[Type[User]]] = None + ) -> None: raise NotImplementedError("use start_worker") def 
start_worker(self, user_classes_count: Dict[str, int], **kwargs) -> None: diff --git a/locust/shape.py b/locust/shape.py index 2561386da0..ed61e311b2 100644 --- a/locust/shape.py +++ b/locust/shape.py @@ -1,5 +1,7 @@ import time -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Type, Union + +from . import User from .runners import Runner @@ -33,13 +35,13 @@ def get_current_user_count(self): """ return self.runner.user_count - def tick(self) -> Optional[Tuple[int, float]]: + def tick(self) -> Union[Tuple[int, float], Tuple[int, float, Optional[List[Type[User]]]], None]: """ Returns a tuple with 2 elements to control the running load test: user_count -- Total user count spawn_rate -- Number of users to start/stop per second when changing number of users - + user_classes -- None or a list of user classes to be spawned in this tick If `None` is returned then the running load test will be stopped. """ diff --git a/locust/test/test_dispatch.py b/locust/test/test_dispatch.py index d885c57567..b5d3b0bfa7 100644 --- a/locust/test/test_dispatch.py +++ b/locust/test/test_dispatch.py @@ -1009,6 +1009,49 @@ class User3(User): delta = time.perf_counter() - ts self.assertTrue(0 <= delta <= _TOLERANCE, delta) + # def test_ramp_down_users_on_workers_respecting_weight(self): + # class User1(User): + # weight = 1 + # + # class User2(User): + # weight = 1 + # + # class User3(User): + # weight = 1 + # + # user_classes = [User1, User2, User3] + # workers = [WorkerNode(str(i + 1)) for i in range(3)] + # + # user_dispatcher = UsersDispatcher(worker_nodes= workers, user_classes = user_classes) + # user_dispatcher.new_dispatch(target_user_count=7, spawn_rate=7) + # + # dispatched_users = next(user_dispatcher) + # self.assertDictEqual(dispatched_users, + # { + # "1": {"User1": 3, "User2": 0, "User3": 0}, + # "2": {"User1": 0, "User2": 2, "User3": 0}, + # "3": {"User1": 0, "User2": 0, "User3": 2} + # }) + # + # user_dispatcher.new_dispatch(target_user_count=16, 
spawn_rate=9) + # dispatched_users = next(user_dispatcher) + # self.assertDictEqual(dispatched_users, + # { + # "1": {"User1": 6, "User2": 0, "User3": 0}, + # "2": {"User1": 0, "User2": 5, "User3": 0}, + # "3": {"User1": 0, "User2": 0, "User3": 5} + # }) + # + # user_dispatcher.new_dispatch(target_user_count=3, spawn_rate=15) + # dispatched_users = next(user_dispatcher) + # self.assertDictEqual(dispatched_users, + # { + # "1": {"User1": 1, "User2": 0, "User3": 0}, + # "2": {"User1": 0, "User2": 1, "User3": 0}, + # "3": {"User1": 0, "User2": 0, "User3": 1} + # }) + # + def test_ramp_down_users_to_3_workers_with_spawn_rate_of_1(self): class User1(User): weight = 1 @@ -3591,6 +3634,434 @@ class User5(User): ) +class TestRampUpDifferentUsers(unittest.TestCase): + def test_ramp_up_different_users_for_each_dispatch(self): + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + worker_node1 = WorkerNode("1") + + sleep_time = 0.2 + + user_dispatcher = UsersDispatcher(worker_nodes=[worker_node1], user_classes=[User1, User2, User3]) + + user_dispatcher.new_dispatch(target_user_count=3, spawn_rate=3) + self.assertDictEqual(next(user_dispatcher), {"1": {"User1": 1, "User2": 1, "User3": 1}}) + user_dispatcher.new_dispatch(target_user_count=4, spawn_rate=1, user_classes=[User1]) + self.assertDictEqual(next(user_dispatcher), {"1": {"User1": 2, "User2": 1, "User3": 1}}) + + user_dispatcher.new_dispatch(target_user_count=5, spawn_rate=1, user_classes=[User2]) + self.assertDictEqual(next(user_dispatcher), {"1": {"User1": 2, "User2": 2, "User3": 1}}) + + user_dispatcher.new_dispatch(target_user_count=6, spawn_rate=1, user_classes=[User3]) + self.assertDictEqual(next(user_dispatcher), {"1": {"User1": 2, "User2": 2, "User3": 2}}) + + def test_ramp_up_only_one_kind_of_user(self): + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + worker_node1 = WorkerNode("1") + + sleep_time 
= 0.2 + + user_dispatcher = UsersDispatcher(worker_nodes=[worker_node1], user_classes=[User1, User2, User3]) + + user_dispatcher.new_dispatch(target_user_count=10, spawn_rate=10, user_classes=[User2]) + self.assertDictEqual(next(user_dispatcher), {"1": {"User1": 0, "User2": 10, "User3": 0}}) + + def test_ramp_up_first_half_user1_second_half_user2(self): + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + worker_node1 = WorkerNode("1") + + sleep_time = 0.2 + + user_dispatcher = UsersDispatcher(worker_nodes=[worker_node1], user_classes=[User1, User2, User3]) + + user_dispatcher.new_dispatch(target_user_count=10, spawn_rate=10, user_classes=[User2]) + self.assertDictEqual(next(user_dispatcher), {"1": {"User1": 0, "User2": 10, "User3": 0}}) + + user_dispatcher.new_dispatch(target_user_count=40, spawn_rate=30, user_classes=[User3]) + self.assertDictEqual(next(user_dispatcher), {"1": {"User1": 0, "User2": 10, "User3": 30}}) + + def test_ramp_up_first_one_user_then_all_classes(self): + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + worker_node1 = WorkerNode("1") + + sleep_time = 0.2 + + user_dispatcher = UsersDispatcher(worker_nodes=[worker_node1], user_classes=[User1, User2, User3]) + + user_dispatcher.new_dispatch(target_user_count=10, spawn_rate=10, user_classes=[User2]) + self.assertDictEqual(next(user_dispatcher), {"1": {"User1": 0, "User2": 10, "User3": 0}}) + + user_dispatcher.new_dispatch(target_user_count=40, spawn_rate=30, user_classes=[User1, User2, User3]) + self.assertDictEqual(next(user_dispatcher), {"1": {"User1": 10, "User2": 20, "User3": 10}}) + + def test_ramp_up_different_users_each_dispatch_multiple_worker(self): + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + worker_node1 = WorkerNode("1") + worker_node2 = WorkerNode("2") + worker_node3 = WorkerNode("3") + + sleep_time = 
0.2 + + user_dispatcher = UsersDispatcher( + worker_nodes=[worker_node1, worker_node2, worker_node3], user_classes=[User1, User2, User3] + ) + + user_dispatcher.new_dispatch(target_user_count=9, spawn_rate=9) + self.assertDictEqual( + next(user_dispatcher), + { + "1": {"User1": 3, "User2": 0, "User3": 0}, + "2": {"User1": 0, "User2": 3, "User3": 0}, + "3": {"User1": 0, "User2": 0, "User3": 3}, + }, + ) + + user_dispatcher.new_dispatch(target_user_count=12, spawn_rate=3, user_classes=[User3]) + self.assertDictEqual( + next(user_dispatcher), + { + "1": {"User1": 3, "User2": 0, "User3": 1}, + "2": {"User1": 0, "User2": 3, "User3": 1}, + "3": {"User1": 0, "User2": 0, "User3": 4}, + }, + ) + + user_dispatcher.new_dispatch(target_user_count=15, spawn_rate=3, user_classes=[User2]) + self.assertDictEqual( + next(user_dispatcher), + { + "1": {"User1": 3, "User2": 1, "User3": 1}, + "2": {"User1": 0, "User2": 4, "User3": 1}, + "3": {"User1": 0, "User2": 1, "User3": 4}, + }, + ) + + user_dispatcher.new_dispatch(target_user_count=18, spawn_rate=3, user_classes=[User1]) + self.assertDictEqual( + next(user_dispatcher), + { + "1": {"User1": 4, "User2": 1, "User3": 1}, + "2": {"User1": 1, "User2": 4, "User3": 1}, + "3": {"User1": 1, "User2": 1, "User3": 4}, + }, + ) + + def test_ramp_up_one_user_class_multiple_worker(self): + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + worker_node1 = WorkerNode("1") + worker_node2 = WorkerNode("2") + worker_node3 = WorkerNode("3") + + sleep_time = 0.2 + + user_dispatcher = UsersDispatcher( + worker_nodes=[worker_node1, worker_node2, worker_node3], user_classes=[User1, User2, User3] + ) + + user_dispatcher.new_dispatch(target_user_count=60, spawn_rate=60, user_classes=[User2]) + self.assertDictEqual( + next(user_dispatcher), + { + "1": {"User1": 0, "User2": 20, "User3": 0}, + "2": {"User1": 0, "User2": 20, "User3": 0}, + "3": {"User1": 0, "User2": 20, "User3": 0}, + }, + ) + + def 
test_ramp_down_custom_user_classes_respect_weighting(self): + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + worker_nodes = [WorkerNode(str(i + 1)) for i in range(3)] + user_dispatcher = UsersDispatcher(worker_nodes=worker_nodes, user_classes=[User1, User2, User3]) + + user_dispatcher.new_dispatch(target_user_count=20, spawn_rate=20, user_classes=[User3]) + dispatched_users = next(user_dispatcher) + self.assertDictEqual( + dispatched_users, + { + "1": {"User1": 0, "User2": 0, "User3": 7}, + "2": {"User1": 0, "User2": 0, "User3": 7}, + "3": {"User1": 0, "User2": 0, "User3": 6}, + }, + ) + + user_dispatcher.new_dispatch(target_user_count=9, spawn_rate=20, user_classes=[User3]) + dispatched_users = next(user_dispatcher) + self.assertDictEqual( + dispatched_users, + { + "1": {"User1": 0, "User2": 0, "User3": 3}, + "2": {"User1": 0, "User2": 0, "User3": 3}, + "3": {"User1": 0, "User2": 0, "User3": 3}, + }, + ) + + user_dispatcher.new_dispatch(target_user_count=3, spawn_rate=20, user_classes=[User1, User2, User3]) + dispatched_users = next(user_dispatcher) + self.assertDictEqual( + dispatched_users, + { + "1": {"User1": 0, "User2": 0, "User3": 1}, + "2": {"User1": 0, "User2": 0, "User3": 1}, + "3": {"User1": 0, "User2": 0, "User3": 1}, + }, + ) + + user_dispatcher.new_dispatch(target_user_count=21, spawn_rate=21, user_classes=[User1, User2, User3]) + dispatched_users = next(user_dispatcher) + print(dispatched_users) + self.assertDictEqual( + dispatched_users, + { + "1": {"User1": 0, "User2": 6, "User3": 1}, # 7 + "2": {"User1": 0, "User2": 0, "User3": 7}, # 7 + "3": {"User1": 6, "User2": 0, "User3": 1}, # 7 + }, + ) + + user_dispatcher.new_dispatch(target_user_count=9, spawn_rate=20, user_classes=[User1, User2, User3]) + dispatched_users = next(user_dispatcher) + + # this is disrespecting the weighting + + self.assertDictEqual( + dispatched_users, + { + "1": {"User1": 0, "User2": 2, "User3": 1}, + "2": 
{"User1": 0, "User2": 0, "User3": 3}, + "3": {"User1": 2, "User2": 0, "User3": 1}, + }, + ) + + def test_remove_worker_during_ramp_up_custom_classes(self): + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + user_classes = [User1, User2, User3] + + worker_nodes = [WorkerNode(str(i + 1)) for i in range(3)] + + users_dispatcher = UsersDispatcher(worker_nodes=worker_nodes, user_classes=user_classes) + + sleep_time = 0.2 # Speed-up test + + users_dispatcher.new_dispatch(target_user_count=9, spawn_rate=3, user_classes=[User2]) + users_dispatcher._wait_between_dispatch = sleep_time + + # Dispatch iteration 1 + ts = time.perf_counter() + dispatched_users = next(users_dispatcher) + delta = time.perf_counter() - ts + self.assertTrue(0 <= delta <= _TOLERANCE, delta) + self.assertDictEqual( + dispatched_users, + { + "1": {"User1": 0, "User2": 1, "User3": 0}, + "2": {"User1": 0, "User2": 1, "User3": 0}, + "3": {"User1": 0, "User2": 1, "User3": 0}, + }, + ) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 0, "User2": 3, "User3": 0}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 1) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 1) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 1) + + # Dispatch iteration 2 + ts = time.perf_counter() + dispatched_users = next(users_dispatcher) + delta = time.perf_counter() - ts + self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 0, "User2": 6, "User3": 0}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 2) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 2) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 2) + + self.assertFalse(users_dispatcher._rebalance) 
+ + users_dispatcher.remove_worker(worker_nodes[1]) + + self.assertTrue(users_dispatcher._rebalance) + + # Re-balance + ts = time.perf_counter() + dispatched_users = next(users_dispatcher) + delta = time.perf_counter() - ts + self.assertTrue(0 <= delta <= _TOLERANCE, f"Expected re-balance dispatch to be instantaneous but got {delta}s") + self.assertDictEqual( + dispatched_users, {"1": {"User1": 0, "User2": 3, "User3": 0}, "3": {"User1": 0, "User2": 3, "User3": 0}} + ) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 0, "User2": 6, "User3": 0}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 3) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 3) + + self.assertFalse(users_dispatcher._rebalance) + + # Dispatch iteration 3 + ts = time.perf_counter() + dispatched_users = next(users_dispatcher) + delta = time.perf_counter() - ts + self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 0, "User2": 9, "User3": 0}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 5) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 4) + + # New dispatch + users_dispatcher.new_dispatch(16, 7, [User3]) + dispatched_users = next(users_dispatcher) + self.assertDictEqual( + dispatched_users, {"1": {"User1": 0, "User2": 5, "User3": 3}, "3": {"User1": 0, "User2": 4, "User3": 4}} + ) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 0, "User2": 9, "User3": 7}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 8) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 8) + + def test_add_worker_during_ramp_up_custom_classes(self): + class User1(User): + weight = 1 + + class User2(User): + weight = 1 + + class User3(User): + weight = 1 + + user_classes = 
[User1, User2, User3] + + worker_nodes = [WorkerNode(str(i + 1)) for i in range(3)] + + users_dispatcher = UsersDispatcher(worker_nodes=[worker_nodes[0], worker_nodes[2]], user_classes=user_classes) + + sleep_time = 0.2 # Speed-up test + + users_dispatcher.new_dispatch(target_user_count=11, spawn_rate=3, user_classes=[User1]) + users_dispatcher._wait_between_dispatch = sleep_time + + # Dispatch iteration 1 + ts = time.perf_counter() + dispatched_users = next(users_dispatcher) + delta = time.perf_counter() - ts + self.assertTrue(0 <= delta <= _TOLERANCE, delta) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 3, "User2": 0, "User3": 0}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 2) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 1) + + # Dispatch iteration 2 + ts = time.perf_counter() + dispatched_users = next(users_dispatcher) + delta = time.perf_counter() - ts + self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 6, "User2": 0, "User3": 0}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 3) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 3) + + self.assertFalse(users_dispatcher._rebalance) + + users_dispatcher.add_worker(worker_nodes[1]) + + self.assertTrue(users_dispatcher._rebalance) + + # Re-balance + ts = time.perf_counter() + dispatched_users = next(users_dispatcher) + delta = time.perf_counter() - ts + self.assertTrue(0 <= delta <= _TOLERANCE, f"Expected re-balance dispatch to be instantaneous but got {delta}s") + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 6, "User2": 0, "User3": 0}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 2) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 2) + 
self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 2) + + self.assertFalse(users_dispatcher._rebalance) + + # Dispatch iteration 3 + ts = time.perf_counter() + dispatched_users = next(users_dispatcher) + delta = time.perf_counter() - ts + self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 9, "User2": 0, "User3": 0}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 3) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 3) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 3) + + # Dispatch iteration 4 + ts = time.perf_counter() + dispatched_users = next(users_dispatcher) + delta = time.perf_counter() - ts + self.assertTrue(sleep_time - _TOLERANCE <= delta <= sleep_time + _TOLERANCE, delta) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 11, "User2": 0, "User3": 0}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 4) + # without host-based balancing the following two values would be reversed + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 4) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 3) + + # New Dispatch + users_dispatcher.new_dispatch(target_user_count=18, spawn_rate=7, user_classes=[User3]) + dispatched_users = next(users_dispatcher) + self.assertDictEqual(_aggregate_dispatched_users(dispatched_users), {"User1": 11, "User2": 0, "User3": 7}) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[0].id), 6) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[1].id), 6) + self.assertEqual(_user_count_on_worker(dispatched_users, worker_nodes[2].id), 6) + + def _aggregate_dispatched_users(d: Dict[str, Dict[str, int]]) -> Dict[str, int]: user_classes = 
list(next(iter(d.values())).keys()) return {u: sum(d[u] for d in d.values()) for u in user_classes}