From 4a582d12ae23723717fbf28e72d9af84e4b21f16 Mon Sep 17 00:00:00 2001
From: Roman Skurikhin
Date: Mon, 12 Jul 2021 17:11:37 +0300
Subject: [PATCH] Allow dropping jobs (#1630)

* Add drop job endpoint

* Add support for dropping (completely removing) a job

* Fix tests

* Fix tests
---
 .../0ee42d5f1908_add_retention_index.py      | 28 +++++++++
 platform_api/handlers/jobs_handler.py        | 58 +++++++++++++++++--
 platform_api/orchestrator/job.py             | 26 +++++++++
 platform_api/orchestrator/jobs_service.py    | 19 ++++++
 .../orchestrator/jobs_storage/base.py        | 10 ++++
 .../orchestrator/jobs_storage/in_memory.py   |  5 ++
 .../orchestrator/jobs_storage/postgres.py    | 26 +++++++++
 tests/integration/api.py                     | 32 ++++++++++
 tests/integration/test_api.py                | 39 +++++++++++++
 tests/integration/test_jobs_storage.py       | 46 +++++++++++++++
 tests/unit/test_job_rest_validator.py        | 12 ++++
 tests/unit/test_models.py                    | 18 ++++++
 12 files changed, 313 insertions(+), 6 deletions(-)
 create mode 100644 alembic/versions/0ee42d5f1908_add_retention_index.py

diff --git a/alembic/versions/0ee42d5f1908_add_retention_index.py b/alembic/versions/0ee42d5f1908_add_retention_index.py
new file mode 100644
index 000000000..a68e69f2a
--- /dev/null
+++ b/alembic/versions/0ee42d5f1908_add_retention_index.py
@@ -0,0 +1,28 @@
+"""add retention index
+
+Revision ID: 0ee42d5f1908
+Revises: 1497c0e2f5a2
+Create Date: 2021-07-09 16:46:45.289722
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = "0ee42d5f1908"
+down_revision = "1497c0e2f5a2"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_index(
+        "jobs_being_dropped_index",
+        "jobs",
+        [sa.text("(((payload ->> 'being_dropped'::text))::boolean)")],
+    )
+
+
+def downgrade() -> None:
+    op.drop_index("jobs_being_dropped_index", table_name="jobs")
diff --git a/platform_api/handlers/jobs_handler.py b/platform_api/handlers/jobs_handler.py
index e8788926b..e65b6abe2 100644
--- a/platform_api/handlers/jobs_handler.py
+++ b/platform_api/handlers/jobs_handler.py
@@ -256,6 +256,8 @@ def create_job_response_validator() -> t.Trafaret:
             "scheduler_enabled": t.Bool,
             "preemptible_node": t.Bool,
             "materialized": t.Bool,
+            "being_dropped": t.Bool,
+            "logs_removed": t.Bool,
             t.Key("is_preemptible", optional=True): t.Bool,
             t.Key("is_preemptible_node_required", optional=True): t.Bool,
             "pass_config": t.Bool,
@@ -315,6 +317,14 @@ def _check_exactly_one(payload: Dict[str, Any]) -> Dict[str, Any]:
     )
 
 
+def create_drop_progress_validator() -> t.Trafaret:
+    return t.Dict(
+        {
+            t.Key("logs_removed", optional=True): t.Bool,
+        }
+    )
+
+
 def convert_job_container_to_json(container: Container) -> Dict[str, Any]:
     ret: Dict[str, Any] = {
         "image": container.image,
@@ -424,6 +434,8 @@ def convert_job_to_job_response(job: Job) -> Dict[str, Any]:
         "restart_policy": str(job.restart_policy),
         "privileged": job.privileged,
         "materialized": job.materialized,
+        "being_dropped": job.being_dropped,
+        "logs_removed": job.logs_removed,
     }
     if job.name:
         response_payload["name"] = job.name
@@ -499,6 +511,7 @@ def __init__(self, *, app: aiohttp.web.Application, config: Config) -> None:
         self._job_update_run_time_validator = (
             create_job_update_max_run_time_minutes_validator()
         )
+        self._drop_progress_validator = create_drop_progress_validator()
         self._bulk_jobs_response_validator = t.Dict(
             {"jobs": t.List(self._job_response_validator)}
         )
@@ -524,6 +537,8 @@ def register(self, app: aiohttp.web.Application) -> None:
                     "/{job_id}/max_run_time_minutes",
                     self.handle_put_max_run_time_minutes,
                 ),
+                aiohttp.web.post("/{job_id}/drop", self.handle_drop_job),
+                aiohttp.web.post("/{job_id}/drop_progress", self.handle_drop_progress),
             )
         )
 
@@ -890,6 +905,33 @@ async def handle_put_max_run_time_minutes(
         else:
             raise aiohttp.web.HTTPNoContent()
 
+    async def handle_drop_job(
+        self, request: aiohttp.web.Request
+    ) -> aiohttp.web.StreamResponse:
+        job = await self._resolve_job(request, "write")
+        try:
+            await self._jobs_service.drop_job(job.id)
+        except JobError as e:
+            payload = {"error": str(e)}
+            return aiohttp.web.json_response(
+                payload, status=aiohttp.web.HTTPBadRequest.status_code
+            )
+        else:
+            raise aiohttp.web.HTTPNoContent()
+
+    async def handle_drop_progress(
+        self, request: aiohttp.web.Request
+    ) -> aiohttp.web.StreamResponse:
+        job = await self._resolve_job(request, "write")
+
+        orig_payload = await request.json()
+        request_payload = self._drop_progress_validator.check(orig_payload)
+
+        await self._jobs_service.drop_progress(
+            job.id, logs_removed=request_payload.get("logs_removed")
+        )
+        raise aiohttp.web.HTTPNoContent()
+
 
 class JobFilterException(ValueError):
     pass
@@ -906,9 +948,10 @@ def create_from_query(self, query: MultiDictProxy) -> JobFilter:  # type: ignore
         statuses = {JobStatus(s) for s in query.getall("status", [])}
         tags = set(query.getall("tag", []))
         hostname = query.get("hostname")
-        materialized = None
-        if "materialized" in query:
-            materialized = query["materialized"].lower() == "true"
+        bool_filters = {}
+        for name in ["materialized", "being_dropped", "logs_removed"]:
+            if name in query:
+                bool_filters[name] = query[name].lower() == "true"
         if hostname is None:
             job_name = self._job_name_validator.check(query.get("name"))
             owners = {
@@ -934,7 +977,7 @@ def create_from_query(self, query: MultiDictProxy) -> JobFilter:  # type: ignore
                 tags=tags,
                 since=iso8601.parse_date(since) if since else JobFilter.since,
                 until=iso8601.parse_date(until) if until else JobFilter.until,
-                materialized=materialized,
+                **bool_filters,
             )
 
         for key in ("name", "owner", "cluster_name", "since", "until"):
@@ -945,7 +988,10 @@ def create_from_query(self, query: MultiDictProxy) -> JobFilter:  # type: ignore
         job_name, sep, base_owner = label.rpartition(JOB_USER_NAMES_SEPARATOR)
         if not sep:
             return JobFilter(
-                statuses=statuses, ids={label}, tags=tags, materialized=materialized
+                statuses=statuses,
+                ids={label},
+                tags=tags,
+                **bool_filters,
             )
         job_name = self._job_name_validator.check(job_name)
         base_owner = self._base_owner_name_validator.check(base_owner)
@@ -954,7 +1000,7 @@
             base_owners={base_owner},
             name=job_name,
             tags=tags,
-            materialized=materialized,
+            **bool_filters,
         )
diff --git a/platform_api/orchestrator/job.py b/platform_api/orchestrator/job.py
index 3e331cf94..c0e4a8d8b 100644
--- a/platform_api/orchestrator/job.py
+++ b/platform_api/orchestrator/job.py
@@ -297,6 +297,10 @@ class JobRecord:
     last_billed: Optional[datetime] = None
     total_price_credits: Decimal = Decimal("0")
 
+    # Retention (allows other services, such as platform-monitoring, to clean up job resources)
+    being_dropped: bool = False
+    logs_removed: bool = False
+
     # for testing only
     allow_empty_cluster_name: bool = False
 
@@ -479,6 +483,10 @@ def to_primitive(self) -> Dict[str, Any]:
             result["tags"] = self.tags
         if self.last_billed:
             result["last_billed"] = self.last_billed.isoformat()
+        if self.being_dropped:
+            result["being_dropped"] = self.being_dropped
+        if self.logs_removed:
+            result["logs_removed"] = self.logs_removed
         return result
 
     @classmethod
@@ -518,6 +526,8 @@ def from_primitive(
             last_billed=datetime.fromisoformat(payload["last_billed"])
             if "last_billed" in payload
             else None,
+            being_dropped=payload.get("being_dropped", False),
+            logs_removed=payload.get("logs_removed", False),
         )
 
     @staticmethod
@@ -694,6 +704,22 @@ def materialized(self) -> bool:
     def materialized(self, value: bool) -> None:
         self._record.materialized = value
 
+    @property
+    def being_dropped(self) -> bool:
+        return self._record.being_dropped
+
+    @being_dropped.setter
+    def being_dropped(self, value: bool) -> None:
+        self._record.being_dropped = value
+
+    @property
+    def logs_removed(self) -> bool:
+        return self._record.logs_removed
+
+    @logs_removed.setter
+    def logs_removed(self, value: bool) -> None:
+        self._record.logs_removed = value
+
     @property
     def schedule_timeout(self) -> Optional[float]:
         return self._record.schedule_timeout
diff --git a/platform_api/orchestrator/jobs_service.py b/platform_api/orchestrator/jobs_service.py
index 7b5be8bb8..4687a9d17 100644
--- a/platform_api/orchestrator/jobs_service.py
+++ b/platform_api/orchestrator/jobs_service.py
@@ -454,3 +454,22 @@ async def get_not_billed_jobs(self) -> AsyncIterator[Job]:
         ) as it:
             async for record in it:
                 yield await self._get_cluster_job(record)
+
+    async def drop_job(
+        self,
+        job_id: str,
+    ) -> None:
+        async with self._jobs_storage.try_update_job(job_id) as record:
+            record.being_dropped = True
+
+    async def drop_progress(
+        self, job_id: str, *, logs_removed: Optional[bool] = None
+    ) -> None:
+        async with self._jobs_storage.try_update_job(job_id) as record:
+            if not record.being_dropped:
+                raise JobError(f"Job {job_id} is not being dropped")
+            if logs_removed:
+                record.logs_removed = logs_removed
+            all_resources_cleaned = record.logs_removed
+            if all_resources_cleaned:
+                await self._jobs_storage.drop_job(job_id)
diff --git a/platform_api/orchestrator/jobs_storage/base.py b/platform_api/orchestrator/jobs_storage/base.py
index a1ab99c12..082a4a7c3 100644
--- a/platform_api/orchestrator/jobs_storage/base.py
+++ b/platform_api/orchestrator/jobs_storage/base.py
@@ -60,6 +60,8 @@ class JobFilter:
     until: datetime = datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)
     materialized: Optional[bool] = None
     fully_billed: Optional[bool] = None
+    being_dropped: Optional[bool] = None
+    logs_removed: Optional[bool] = None
 
     def check(self, job: JobRecord) -> bool:
         if self.statuses and job.status not in self.statuses:
@@ -91,6 +93,10 @@ def check(self, job: JobRecord) -> bool:
             return self.materialized == job.materialized
         if self.fully_billed is not None:
             return self.fully_billed == job.fully_billed
+        if self.being_dropped is not None:
+            return self.being_dropped == job.being_dropped
+        if self.logs_removed is not None:
+            return self.logs_removed == job.logs_removed
         return True
 
 
@@ -145,6 +151,10 @@ async def set_job(self, job: JobRecord) -> None:
     async def get_job(self, job_id: str) -> JobRecord:
         pass
 
+    @abstractmethod
+    async def drop_job(self, job_id: str) -> None:
+        pass
+
     @abstractmethod
     def try_update_job(self, job_id: str) -> AsyncContextManager[JobRecord]:
         pass
diff --git a/platform_api/orchestrator/jobs_storage/in_memory.py b/platform_api/orchestrator/jobs_storage/in_memory.py
index e50ad284f..ea8a24c1b 100644
--- a/platform_api/orchestrator/jobs_storage/in_memory.py
+++ b/platform_api/orchestrator/jobs_storage/in_memory.py
@@ -55,6 +55,11 @@ async def get_job(self, job_id: str) -> JobRecord:
             raise JobError(f"no such job {job_id}")
         return self._parse_job_payload(payload)
 
+    async def drop_job(self, job_id: str) -> None:
+        payload = self._job_records.pop(job_id, None)
+        if payload is None:
+            raise JobError(f"no such job {job_id}")
+
     @asynccontextmanager
     async def try_update_job(self, job_id: str) -> AsyncIterator[JobRecord]:
         job = await self.get_job(job_id)
diff --git a/platform_api/orchestrator/jobs_storage/postgres.py b/platform_api/orchestrator/jobs_storage/postgres.py
index 0c6059aab..280294ac6 100644
--- a/platform_api/orchestrator/jobs_storage/postgres.py
+++ b/platform_api/orchestrator/jobs_storage/postgres.py
@@ -188,6 +188,16 @@ async def get_job(self, job_id: str) -> JobRecord:
         record = await self._select_row(job_id)
         return self._record_to_job(record)
 
+    async def drop_job(self, job_id: str) -> None:
+        query = (
+            self._tables.jobs.delete()
+            .where(self._tables.jobs.c.id == job_id)
+            .returning(self._tables.jobs.c.id)
+        )
+        result = await self._fetchrow(query)
+        if result is None:
+            raise JobError(f"no such job {job_id}")
+
     @asynccontextmanager
     async def try_create_job(self, job: JobRecord) -> AsyncIterator[JobRecord]:
         # No need to do any checks -- INSERT cannot be executed twice
@@ -496,6 +506,18 @@ def filter_fully_billed(self, fully_billed: bool) -> None:
             == fully_billed
         )
 
+    def filter_being_dropped(self, being_dropped: bool) -> None:
+        self._clauses.append(
+            self._tables.jobs.c.payload["being_dropped"].astext.cast(Boolean)
+            == being_dropped
+        )
+
+    def filter_logs_removed(self, logs_removed: bool) -> None:
+        self._clauses.append(
+            self._tables.jobs.c.payload["logs_removed"].astext.cast(Boolean)
+            == logs_removed
+        )
+
     def build(self) -> sasql.ClauseElement:
         return and_(*self._clauses)
 
@@ -522,6 +544,10 @@ def by_job_filter(
             builder.filter_materialized(job_filter.materialized)
         if job_filter.fully_billed is not None:
             builder.filter_fully_billed(job_filter.fully_billed)
+        if job_filter.being_dropped is not None:
+            builder.filter_being_dropped(job_filter.being_dropped)
+        if job_filter.logs_removed is not None:
+            builder.filter_logs_removed(job_filter.logs_removed)
         builder.filter_since(job_filter.since)
         builder.filter_until(job_filter.until)
         return builder.build()
diff --git a/tests/integration/api.py b/tests/integration/api.py
index 9232c7bba..bcd3b2f21 100644
--- a/tests/integration/api.py
+++ b/tests/integration/api.py
@@ -317,6 +317,38 @@ async def delete_job(
                 response.status == HTTPNoContent.status_code
             ), await response.text()
 
+    async def drop_job(
+        self,
+        job_id: str,
+        assert_success: bool = True,
+        headers: Optional[Dict[str, str]] = None,
+    ) -> None:
+        url = self._api_config.generate_job_url(job_id) + "/drop"
+        async with self._client.post(url, headers=headers or self._headers) as response:
+            if assert_success:
+                assert (
+                    response.status == HTTPNoContent.status_code
+                ), await response.text()
+
+    async def drop_progress(
+        self,
+        job_id: str,
+        logs_removed: Optional[bool] = None,
+        assert_success: bool = True,
+        headers: Optional[Dict[str, str]] = None,
+    ) -> None:
+        url = self._api_config.generate_job_url(job_id) + "/drop_progress"
+        payload = {}
+        if logs_removed is not None:
+            payload["logs_removed"] = logs_removed
+        async with self._client.post(
+            url, json=payload, headers=headers or self._headers
+        ) as response:
+            if assert_success:
+                assert (
+                    response.status == HTTPNoContent.status_code
+                ), await response.text()
+
 
 @pytest.fixture
 async def jobs_client_factory(
diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py
index 62250366e..28e8304f4 100644
--- a/tests/integration/test_api.py
+++ b/tests/integration/test_api.py
@@ -4285,6 +4285,35 @@ async def test_delete_not_exist(
         result = await response.json()
         assert result["error"] == f"no such job {job_id}"
 
+    @pytest.mark.asyncio
+    async def test_drop_job(
+        self,
+        api: ApiConfig,
+        client: aiohttp.ClientSession,
+        job_submit: Dict[str, Any],
+        jobs_client: JobsClient,
+        regular_user: _User,
+    ) -> None:
+        url = api.jobs_base_url
+        async with client.post(
+            url, headers=regular_user.headers, json=job_submit
+        ) as response:
+            assert response.status == HTTPAccepted.status_code, await response.text()
+            result = await response.json()
+            assert result["status"] in ["pending"]
+            job_id = result["id"]
+        await jobs_client.long_polling_by_job_id(job_id=job_id, status="succeeded")
+        await jobs_client.drop_job(job_id=job_id)
+
+        jobs = await jobs_client.get_all_jobs()
+        assert len(jobs) == 1
+        assert jobs[0]["being_dropped"]
+        assert not jobs[0]["logs_removed"]
+        await jobs_client.drop_progress(job_id=job_id, logs_removed=True)
+
+        jobs = await jobs_client.get_all_jobs()
+        assert len(jobs) == 0
+
     @pytest.mark.asyncio
     async def test_create_validation_failure(
         self, api: ApiConfig, client: aiohttp.ClientSession, regular_user: _User
@@ -4463,6 +4492,8 @@ async def test_create_with_custom_volumes(
             "uri": f"job://test-cluster/{regular_user.name}/{job_id}",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
 
         response_payload = await jobs_client.long_polling_by_job_id(
@@ -4510,6 +4541,8 @@ async def test_create_with_custom_volumes(
             "uri": f"job://test-cluster/{regular_user.name}/{job_id}",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
 
     @pytest.mark.asyncio
@@ -4603,6 +4636,8 @@ async def test_job_failed(
             "uri": f"job://test-cluster/{regular_user.name}/{job_id}",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
 
     @pytest.mark.asyncio
@@ -4708,6 +4743,8 @@ async def test_create_gpu_model(
             "uri": f"job://test-cluster/{regular_user.name}/{job_id}",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
 
     @pytest.mark.asyncio
@@ -4807,6 +4844,8 @@ async def test_create_tpu_model(
             "uri": f"job://test-cluster/{regular_user.name}/{job_id}",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
diff --git a/tests/integration/test_jobs_storage.py b/tests/integration/test_jobs_storage.py
index a8fe0a249..5dddd30a9 100644
--- a/tests/integration/test_jobs_storage.py
+++ b/tests/integration/test_jobs_storage.py
@@ -100,6 +100,24 @@ async def test_set_get(self, storage: JobsStorage) -> None:
         assert job.id == original_job.id
         assert job.status == original_job.status
 
+    @pytest.mark.asyncio
+    async def test_drop_job(self, storage: JobsStorage) -> None:
+        original_job = self._create_pending_job()
+        await storage.set_job(original_job)
+
+        job = await storage.get_job(original_job.id)
+        assert job.id == original_job.id
+
+        await storage.drop_job(original_job.id)
+        with pytest.raises(JobError):
+            await storage.get_job(original_job.id)
+
+    @pytest.mark.asyncio
+    async def test_drop_nonexistent_job(self, storage: JobsStorage) -> None:
+        original_job = self._create_pending_job()
+        with pytest.raises(JobError):
+            await storage.drop_job(original_job.id)
+
     @pytest.mark.asyncio
     async def test_try_create_job__no_name__ok(self, storage: JobsStorage) -> None:
 
@@ -958,6 +976,34 @@ async def test_get_all_filter_by_fully_billed(self, storage: JobsStorage) -> None:
         job_ids = [job.id for job in await storage.get_all_jobs(job_filter)]
         assert job_ids == [jobs[0].id, jobs[2].id]
 
+    @pytest.mark.asyncio
+    async def test_get_all_filter_by_being_dropped(self, storage: JobsStorage) -> None:
+        jobs = [
+            self._create_job(being_dropped=True),
+            self._create_job(being_dropped=False),
+            self._create_job(being_dropped=True),
+        ]
+        for job in jobs:
+            async with storage.try_create_job(job):
+                pass
+        job_filter = JobFilter(being_dropped=True)
+        job_ids = [job.id for job in await storage.get_all_jobs(job_filter)]
+        assert job_ids == [jobs[0].id, jobs[2].id]
+
+    @pytest.mark.asyncio
+    async def test_get_all_filter_by_logs_removed(self, storage: JobsStorage) -> None:
+        jobs = [
+            self._create_job(logs_removed=True),
+            self._create_job(logs_removed=False),
+            self._create_job(logs_removed=True),
+        ]
+        for job in jobs:
+            async with storage.try_create_job(job):
+                pass
+        job_filter = JobFilter(logs_removed=True)
+        job_ids = [job.id for job in await storage.get_all_jobs(job_filter)]
+        assert job_ids == [jobs[0].id, jobs[2].id]
+
     @pytest.mark.asyncio
     async def test_get_all_filter_by_cluster_and_owner(
         self, storage: JobsStorage
diff --git a/tests/unit/test_job_rest_validator.py b/tests/unit/test_job_rest_validator.py
index 4ccaf1610..856aa35e0 100644
--- a/tests/unit/test_job_rest_validator.py
+++ b/tests/unit/test_job_rest_validator.py
@@ -270,6 +270,8 @@ def test_job_details_with_name(self) -> None:
             "uri": "job://cluster-name/tests/test-job-id",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
         validator = create_job_response_validator()
         assert validator.check(response)
@@ -310,6 +312,8 @@ def test_job_empty_description(self) -> None:
             "uri": "job://cluster-name/tests/test-job-id",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
         validator = create_job_response_validator()
         assert validator.check(response)
@@ -348,6 +352,8 @@ def test_job_details_without_name(self) -> None:
             "uri": "job://cluster-name/tests/test-job-id",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
         validator = create_job_response_validator()
         assert validator.check(response)
@@ -390,6 +396,8 @@ def test_with_entrypoint_and_cmd(self) -> None:
             "uri": "job://cluster-name/tests/test-job-id",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
         validator = create_job_response_validator()
         assert validator.check(response)
@@ -430,6 +438,8 @@ def test_with_absolute_working_dir(self) -> None:
             "uri": "job://cluster-name/tests/test-job-id",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
         validator = create_job_response_validator()
         assert validator.check(response)
@@ -512,6 +522,8 @@ def test_with_max_run_time_minutes(self) -> None:
             "uri": "job://cluster-name/tests/test-job-id",
             "restart_policy": "never",
             "privileged": False,
+            "being_dropped": False,
+            "logs_removed": False,
         }
         validator = create_job_response_validator()
         assert validator.check(response)
diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py
index d60ffd668..ee2d9e563 100644
--- a/tests/unit/test_models.py
+++ b/tests/unit/test_models.py
@@ -850,6 +850,18 @@ def test_create_from_query(self) -> None:
         query = MultiDict([("materialized", "False")])
         assert factory(query) == JobFilter(materialized=False)
 
+        query = MultiDict([("being_dropped", "True")])
+        assert factory(query) == JobFilter(being_dropped=True)
+
+        query = MultiDict([("being_dropped", "False")])
+        assert factory(query) == JobFilter(being_dropped=False)
+
+        query = MultiDict([("logs_removed", "True")])
+        assert factory(query) == JobFilter(logs_removed=True)
+
+        query = MultiDict([("logs_removed", "False")])
+        assert factory(query) == JobFilter(logs_removed=False)
+
     def test_create_from_query_with_status(self) -> None:
         factory = JobFilterFactory().create_from_query
 
@@ -1539,6 +1551,8 @@ async def test_job_to_job_response(mock_orchestrator: MockOrchestrator) -> None:
         "uri": f"job://test-cluster/compute/{job.id}",
         "restart_policy": "never",
         "privileged": False,
+        "being_dropped": False,
+        "logs_removed": False,
     }
 
 
@@ -1645,6 +1659,8 @@ async def test_job_to_job_response_with_job_name_and_http_exposed(
         "uri": f"job://test-cluster/{owner_name}/{job.id}",
         "restart_policy": "never",
         "privileged": False,
+        "being_dropped": False,
+        "logs_removed": False,
     }
 
 
@@ -1711,6 +1727,8 @@ async def test_job_to_job_response_with_job_name_and_http_exposed_too_long_name(
         "uri": f"job://test-cluster/{owner_name}/{job.id}",
         "restart_policy": "never",
         "privileged": False,
+        "being_dropped": False,
+        "logs_removed": False,
     }
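
Note: taken together, the changes above form a two-step retention flow. POST /{job_id}/drop marks a record as being dropped, an external cleaner (such as platform-monitoring) removes the job's resources, and POST /{job_id}/drop_progress reports what was cleaned; once logs_removed is confirmed, the service calls jobs_storage.drop_job and the record is deleted for good. Below is a minimal sketch of a cleaner driving these endpoints. The base URL, the auth token, and the log-store deletion step are illustrative assumptions, not part of this patch; only the two endpoints, the 204 responses, the being_dropped query filter, and the {"jobs": [...]} bulk response shape come from the code above.

    import asyncio

    import aiohttp

    # Hypothetical values -- substitute a real platform-api URL and token.
    PLATFORM_API_URL = "http://localhost:8080/api/v1/jobs"
    HEADERS = {"Authorization": "Bearer <service-token>"}


    async def report_logs_removed(session: aiohttp.ClientSession, job_id: str) -> None:
        # POST /{job_id}/drop_progress; once every tracked resource is removed
        # (only logs in this patch), the job record itself is deleted from storage.
        async with session.post(
            f"{PLATFORM_API_URL}/{job_id}/drop_progress",
            headers=HEADERS,
            json={"logs_removed": True},
        ) as resp:
            assert resp.status == 204, await resp.text()


    async def cleanup_pass(session: aiohttp.ClientSession) -> None:
        # Find jobs already marked via POST /{job_id}/drop; the new
        # jobs_being_dropped_index keeps this filter cheap on Postgres.
        async with session.get(
            PLATFORM_API_URL, headers=HEADERS, params={"being_dropped": "true"}
        ) as resp:
            jobs = (await resp.json())["jobs"]
        for job in jobs:
            # ... delete the job's logs from the log store here ...
            await report_logs_removed(session, job["id"])


    async def main() -> None:
        async with aiohttp.ClientSession() as session:
            await cleanup_pass(session)


    if __name__ == "__main__":
        asyncio.run(main())

Splitting removal into drop and drop_progress keeps the job record queryable while cleanup is in flight, so a crashed cleaner can simply re-run the pass: both endpoints are idempotent with respect to the flags they set.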