Allow dropping jobs (#1630)
* Add drop job endpoint

* Add support for drop (complete removal) of a job

* Fix tests

* Fix tests
romasku authored Jul 12, 2021
1 parent 37757c2 commit 4a582d1
Showing 12 changed files with 313 additions and 6 deletions.
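
Taken together, the changes add two job endpoints, POST /{job_id}/drop and POST /{job_id}/drop_progress. Below is a minimal client-side sketch of the intended flow; the /api/v1/jobs prefix, base URL, and bearer-token auth are deployment assumptions, not part of this diff.

```python
# Minimal sketch of the new retention flow; the "/api/v1/jobs" prefix
# and bearer-token auth are assumptions, not part of this diff.
import aiohttp


async def drop_job_and_report(base_url: str, token: str, job_id: str) -> None:
    headers = {"Authorization": f"Bearer {token}"}
    async with aiohttp.ClientSession(headers=headers) as session:
        # Step 1: mark the job as being dropped.
        async with session.post(f"{base_url}/api/v1/jobs/{job_id}/drop") as resp:
            assert resp.status == 204, await resp.text()
        # ...an external service (e.g. platform-monitoring) removes the logs...
        # Step 2: report progress; once every tracked resource is cleaned,
        # the server deletes the job record entirely.
        async with session.post(
            f"{base_url}/api/v1/jobs/{job_id}/drop_progress",
            json={"logs_removed": True},
        ) as resp:
            assert resp.status == 204, await resp.text()
```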
28 changes: 28 additions & 0 deletions alembic/versions/0ee42d5f1908_add_retention_index.py
@@ -0,0 +1,28 @@
"""add retention index
Revision ID: 0ee42d5f1908
Revises: 1497c0e2f5a2
Create Date: 2021-07-09 16:46:45.289722
"""
import sqlalchemy as sa
from alembic import op


# revision identifiers, used by Alembic.
revision = "0ee42d5f1908"
down_revision = "1497c0e2f5a2"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_index(
"jobs_being_dropped_index",
"jobs",
[sa.text("(((payload ->> 'being_dropped'::text))::boolean)")],
)


def downgrade() -> None:
op.drop_index("jobs_being_dropped_index", table_name="jobs")
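
The migration adds a functional index over the boolean extracted from the JSONB payload, matching the filter expression added to postgres.py below. A hedged illustration of the expression the index covers; the one-column table is a stand-in for the real jobs schema.

```python
# Hedged illustration: compile the SQLAlchemy filter this index serves
# (see filter_being_dropped in postgres.py below); the table definition
# is a minimal stand-in for the real jobs table.
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

metadata = sa.MetaData()
jobs = sa.Table("jobs", metadata, sa.Column("payload", postgresql.JSONB))

query = jobs.select().where(
    jobs.c.payload["being_dropped"].astext.cast(sa.Boolean) == sa.true()
)
print(
    query.compile(
        dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}
    )
)
# Roughly: SELECT jobs.payload FROM jobs
#          WHERE CAST(jobs.payload ->> 'being_dropped' AS BOOLEAN) = true
```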
58 changes: 52 additions & 6 deletions platform_api/handlers/jobs_handler.py
@@ -256,6 +256,8 @@ def create_job_response_validator() -> t.Trafaret:
"scheduler_enabled": t.Bool,
"preemptible_node": t.Bool,
"materialized": t.Bool,
"being_dropped": t.Bool,
"logs_removed": t.Bool,
t.Key("is_preemptible", optional=True): t.Bool,
t.Key("is_preemptible_node_required", optional=True): t.Bool,
"pass_config": t.Bool,
@@ -315,6 +317,14 @@ def _check_exactly_one(payload: Dict[str, Any]) -> Dict[str, Any]:
)


def create_drop_progress_validator() -> t.Trafaret:
return t.Dict(
{
t.Key("logs_removed", optional=True): t.Bool,
}
)


def convert_job_container_to_json(container: Container) -> Dict[str, Any]:
ret: Dict[str, Any] = {
"image": container.image,
@@ -424,6 +434,8 @@ def convert_job_to_job_response(job: Job) -> Dict[str, Any]:
"restart_policy": str(job.restart_policy),
"privileged": job.privileged,
"materialized": job.materialized,
"being_dropped": job.being_dropped,
"logs_removed": job.logs_removed,
}
if job.name:
response_payload["name"] = job.name
@@ -499,6 +511,7 @@ def __init__(self, *, app: aiohttp.web.Application, config: Config) -> None:
self._job_update_run_time_validator = (
create_job_update_max_run_time_minutes_validator()
)
self._drop_progress_validator = create_drop_progress_validator()
self._bulk_jobs_response_validator = t.Dict(
{"jobs": t.List(self._job_response_validator)}
)
@@ -524,6 +537,8 @@ def register(self, app: aiohttp.web.Application) -> None:
"/{job_id}/max_run_time_minutes",
self.handle_put_max_run_time_minutes,
),
aiohttp.web.post("/{job_id}/drop", self.handle_drop_job),
aiohttp.web.post("/{job_id}/drop_progress", self.handle_drop_progress),
)
)

@@ -890,6 +905,33 @@ async def handle_put_max_run_time_minutes(
else:
raise aiohttp.web.HTTPNoContent()

async def handle_drop_job(
self, request: aiohttp.web.Request
) -> aiohttp.web.StreamResponse:
job = await self._resolve_job(request, "write")
try:
await self._jobs_service.drop_job(job.id)
except JobError as e:
payload = {"error": str(e)}
return aiohttp.web.json_response(
payload, status=aiohttp.web.HTTPBadRequest.status_code
)
else:
raise aiohttp.web.HTTPNoContent()

async def handle_drop_progress(
self, request: aiohttp.web.Request
) -> aiohttp.web.StreamResponse:
job = await self._resolve_job(request, "write")

orig_payload = await request.json()
request_payload = self._drop_progress_validator.check(orig_payload)

await self._jobs_service.drop_progress(
job.id, logs_removed=request_payload.get("logs_removed")
)
raise aiohttp.web.HTTPNoContent()
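
Both handlers resolve the job with write access and answer 204 No Content on success; handle_drop_job maps a JobError to a 400 response with an {"error": ...} body. A small hedged sketch of the drop-progress payload validation, mirroring the trafaret shape defined above:

```python
# Hedged sketch of the drop_progress payload validation, mirroring
# create_drop_progress_validator() above; inputs are illustrative.
import trafaret as t

validator = t.Dict({t.Key("logs_removed", optional=True): t.Bool})

assert validator.check({}) == {}                        # field is optional
assert validator.check({"logs_removed": True}) == {"logs_removed": True}
# validator.check({"logs_removed": "yes"})  # would raise t.DataError
```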


class JobFilterException(ValueError):
pass
@@ -906,9 +948,10 @@ def create_from_query(self, query: MultiDictProxy) -> JobFilter:  # type: ignore
statuses = {JobStatus(s) for s in query.getall("status", [])}
tags = set(query.getall("tag", []))
hostname = query.get("hostname")
materialized = None
if "materialized" in query:
materialized = query["materialized"].lower() == "true"
bool_filters = {}
for name in ["materialized", "being_dropped", "logs_removed"]:
if name in query:
bool_filters[name] = query[name].lower() == "true"
if hostname is None:
job_name = self._job_name_validator.check(query.get("name"))
owners = {
@@ -934,7 +977,7 @@ def create_from_query(self, query: MultiDictProxy) -> JobFilter:  # type: ignore
tags=tags,
since=iso8601.parse_date(since) if since else JobFilter.since,
until=iso8601.parse_date(until) if until else JobFilter.until,
materialized=materialized,
**bool_filters,
)

for key in ("name", "owner", "cluster_name", "since", "until"):
@@ -945,7 +988,10 @@ def create_from_query(self, query: MultiDictProxy) -> JobFilter:  # type: ignore
job_name, sep, base_owner = label.rpartition(JOB_USER_NAMES_SEPARATOR)
if not sep:
return JobFilter(
statuses=statuses, ids={label}, tags=tags, materialized=materialized
statuses=statuses,
ids={label},
tags=tags,
**bool_filters,
)
job_name = self._job_name_validator.check(job_name)
base_owner = self._base_owner_name_validator.check(base_owner)
@@ -954,7 +1000,7 @@ def create_from_query(self, query: MultiDictProxy) -> JobFilter:  # type: ignore
base_owners={base_owner},
name=job_name,
tags=tags,
materialized=materialized,
**bool_filters,
)
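
The bool_filters loop earlier in this file generalizes the old materialized-only parsing to all three boolean query parameters. A hedged standalone rerun of that parsing, using multidict the way aiohttp exposes request.query:

```python
# Hedged rerun of the bool_filters parsing above; MultiDict mimics
# aiohttp's request.query, and the values are illustrative.
from multidict import MultiDict

query = MultiDict([("materialized", "true"), ("being_dropped", "false")])

bool_filters = {}
for name in ["materialized", "being_dropped", "logs_removed"]:
    if name in query:
        bool_filters[name] = query[name].lower() == "true"

print(bool_filters)  # {'materialized': True, 'being_dropped': False}
```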


26 changes: 26 additions & 0 deletions platform_api/orchestrator/job.py
@@ -297,6 +297,10 @@ class JobRecord:
last_billed: Optional[datetime] = None
total_price_credits: Decimal = Decimal("0")

# Retention (allows other services, such as platform-monitoring, to clean up job resources)
being_dropped: bool = False
logs_removed: bool = False

# for testing only
allow_empty_cluster_name: bool = False

@@ -479,6 +483,10 @@ def to_primitive(self) -> Dict[str, Any]:
result["tags"] = self.tags
if self.last_billed:
result["last_billed"] = self.last_billed.isoformat()
if self.being_dropped:
result["being_dropped"] = self.being_dropped
if self.logs_removed:
result["logs_removed"] = self.logs_removed
return result

@classmethod
@@ -518,6 +526,8 @@ def from_primitive(
last_billed=datetime.fromisoformat(payload["last_billed"])
if "last_billed" in payload
else None,
being_dropped=payload.get("being_dropped", False),
logs_removed=payload.get("logs_removed", False),
)

@staticmethod
@@ -694,6 +704,22 @@ def materialized(self) -> bool:
def materialized(self, value: bool) -> None:
self._record.materialized = value

@property
def being_dropped(self) -> bool:
return self._record.being_dropped

@being_dropped.setter
def being_dropped(self, value: bool) -> None:
self._record.being_dropped = value

@property
def logs_removed(self) -> bool:
return self._record.logs_removed

@logs_removed.setter
def logs_removed(self, value: bool) -> None:
self._record.logs_removed = value

@property
def schedule_timeout(self) -> Optional[float]:
return self._record.schedule_timeout
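
The serialization above writes being_dropped and logs_removed only when they are True and defaults them to False on read, so pre-existing job payloads deserialize unchanged. A standalone sketch of that convention, with a hypothetical minimal record standing in for JobRecord:

```python
# Standalone sketch of the flag round-trip; MiniRecord is a
# hypothetical stand-in for JobRecord.
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class MiniRecord:
    being_dropped: bool = False
    logs_removed: bool = False

    def to_primitive(self) -> Dict[str, Any]:
        result: Dict[str, Any] = {}
        if self.being_dropped:  # flags are written only when True,
            result["being_dropped"] = self.being_dropped
        if self.logs_removed:  # so old payloads stay unchanged
            result["logs_removed"] = self.logs_removed
        return result

    @classmethod
    def from_primitive(cls, payload: Dict[str, Any]) -> "MiniRecord":
        return cls(
            being_dropped=payload.get("being_dropped", False),
            logs_removed=payload.get("logs_removed", False),
        )


assert MiniRecord.from_primitive({}) == MiniRecord()  # backward compatible
```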
19 changes: 19 additions & 0 deletions platform_api/orchestrator/jobs_service.py
@@ -454,3 +454,22 @@ async def get_not_billed_jobs(self) -> AsyncIterator[Job]:
) as it:
async for record in it:
yield await self._get_cluster_job(record)

async def drop_job(
self,
job_id: str,
) -> None:
async with self._jobs_storage.try_update_job(job_id) as record:
record.being_dropped = True

async def drop_progress(
self, job_id: str, *, logs_removed: Optional[bool] = None
) -> None:
async with self._jobs_storage.try_update_job(job_id) as record:
if not record.being_dropped:
raise JobError(f"Job {job_id} is not being dropped")
if logs_removed:
record.logs_removed = logs_removed
all_resources_cleaned = record.logs_removed
if all_resources_cleaned:
await self._jobs_storage.drop_job(job_id)
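
drop_job only marks the record; the actual deletion happens in drop_progress once every tracked resource (currently just logs) is reported clean, and reporting progress for a job that was never marked raises JobError. A condensed, self-contained sketch of that state machine; the stub record is hypothetical:

```python
# Condensed sketch of the drop state machine above; JobRecordStub is
# a hypothetical stand-in for JobRecord.
from dataclasses import dataclass
from typing import Optional


class JobError(Exception):
    pass


@dataclass
class JobRecordStub:
    id: str
    being_dropped: bool = False
    logs_removed: bool = False


def drop_progress(record: JobRecordStub, logs_removed: Optional[bool] = None) -> bool:
    """Return True when the record can be deleted from storage."""
    if not record.being_dropped:
        raise JobError(f"Job {record.id} is not being dropped")
    if logs_removed:
        record.logs_removed = logs_removed
    return record.logs_removed  # logs are the only tracked resource today


record = JobRecordStub("job-1")
record.being_dropped = True  # what drop_job() does
assert drop_progress(record, logs_removed=True) is True
```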
10 changes: 10 additions & 0 deletions platform_api/orchestrator/jobs_storage/base.py
@@ -60,6 +60,8 @@ class JobFilter:
until: datetime = datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)
materialized: Optional[bool] = None
fully_billed: Optional[bool] = None
being_dropped: Optional[bool] = None
logs_removed: Optional[bool] = None

def check(self, job: JobRecord) -> bool:
if self.statuses and job.status not in self.statuses:
@@ -91,6 +93,10 @@ def check(self, job: JobRecord) -> bool:
return self.materialized == job.materialized
if self.fully_billed is not None:
return self.fully_billed == job.fully_billed
if self.being_dropped is not None:
return self.being_dropped == job.being_dropped
if self.logs_removed is not None:
return self.logs_removed == job.logs_removed
return True


@@ -145,6 +151,10 @@ async def set_job(self, job: JobRecord) -> None:
async def get_job(self, job_id: str) -> JobRecord:
pass

@abstractmethod
async def drop_job(self, job_id: str) -> None:
pass

@abstractmethod
def try_update_job(self, job_id: str) -> AsyncContextManager[JobRecord]:
pass
5 changes: 5 additions & 0 deletions platform_api/orchestrator/jobs_storage/in_memory.py
@@ -55,6 +55,11 @@ async def get_job(self, job_id: str) -> JobRecord:
raise JobError(f"no such job {job_id}")
return self._parse_job_payload(payload)

async def drop_job(self, job_id: str) -> None:
payload = self._job_records.pop(job_id, None)
if payload is None:
raise JobError(f"no such job {job_id}")

@asynccontextmanager
async def try_update_job(self, job_id: str) -> AsyncIterator[JobRecord]:
job = await self.get_job(job_id)
26 changes: 26 additions & 0 deletions platform_api/orchestrator/jobs_storage/postgres.py
@@ -188,6 +188,16 @@ async def get_job(self, job_id: str) -> JobRecord:
record = await self._select_row(job_id)
return self._record_to_job(record)

async def drop_job(self, job_id: str) -> None:
query = (
self._tables.jobs.delete()
.where(self._tables.jobs.c.id == job_id)
.returning(self._tables.jobs.c.id)
)
result = await self._fetchrow(query)
if result is None:
raise JobError(f"no such job {job_id}")

@asynccontextmanager
async def try_create_job(self, job: JobRecord) -> AsyncIterator[JobRecord]:
# No need to do any checks -- INSERT cannot be executed twice
@@ -496,6 +506,18 @@ def filter_fully_billed(self, fully_billed: bool) -> None:
== fully_billed
)

def filter_being_dropped(self, being_dropped: bool) -> None:
self._clauses.append(
self._tables.jobs.c.payload["being_dropped"].astext.cast(Boolean)
== being_dropped
)

def filter_logs_removed(self, logs_removed: bool) -> None:
self._clauses.append(
self._tables.jobs.c.payload["logs_removed"].astext.cast(Boolean)
== logs_removed
)

def build(self) -> sasql.ClauseElement:
return and_(*self._clauses)

@@ -522,6 +544,10 @@ def by_job_filter(
builder.filter_materialized(job_filter.materialized)
if job_filter.fully_billed is not None:
builder.filter_fully_billed(job_filter.fully_billed)
if job_filter.being_dropped is not None:
builder.filter_being_dropped(job_filter.being_dropped)
if job_filter.logs_removed is not None:
builder.filter_logs_removed(job_filter.logs_removed)
builder.filter_since(job_filter.since)
builder.filter_until(job_filter.until)
return builder.build()
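
drop_job above uses DELETE ... RETURNING so a single round trip both removes the row and reveals whether it existed; a None result becomes JobError. A hedged compile of the same pattern, with a one-column stand-in table:

```python
# Hedged illustration of the single-round-trip delete used by
# drop_job above; the one-column jobs table is a stand-in.
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

metadata = sa.MetaData()
jobs = sa.Table("jobs", metadata, sa.Column("id", sa.String, primary_key=True))

query = jobs.delete().where(jobs.c.id == "job-1").returning(jobs.c.id)
print(query.compile(dialect=postgresql.dialect()))
# DELETE FROM jobs WHERE jobs.id = %(id_1)s RETURNING jobs.id
# fetchrow() yields the deleted id, or None for a missing job,
# which drop_job converts into JobError.
```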
32 changes: 32 additions & 0 deletions tests/integration/api.py
@@ -317,6 +317,38 @@ async def delete_job(
response.status == HTTPNoContent.status_code
), await response.text()

async def drop_job(
self,
job_id: str,
assert_success: bool = True,
headers: Optional[Dict[str, str]] = None,
) -> None:
url = self._api_config.generate_job_url(job_id) + "/drop"
async with self._client.post(url, headers=headers or self._headers) as response:
if assert_success:
assert (
response.status == HTTPNoContent.status_code
), await response.text()

async def drop_progress(
self,
job_id: str,
logs_removed: Optional[bool] = None,
assert_success: bool = True,
headers: Optional[Dict[str, str]] = None,
) -> None:
url = self._api_config.generate_job_url(job_id) + "/drop_progress"
payload = {}
if logs_removed is not None:
payload["logs_removed"] = logs_removed
async with self._client.post(
url, json=payload, headers=headers or self._headers
) as response:
if assert_success:
assert (
response.status == HTTPNoContent.status_code
), await response.text()


@pytest.fixture
async def jobs_client_factory(
