address review comments
parikls committed Oct 9, 2024
1 parent: 3861a9a · commit: 65f7b7a
Showing 3 changed files with 12 additions and 28 deletions.

neuro_config_client/entities.py (3 additions, 9 deletions)
@@ -65,12 +65,8 @@ class NodePoolOptions:
     available_cpu: float
     memory: int
     available_memory: int
-    nvidia_gpu: int | None = None
-    amd_gpu: int | None = None
-    intel_gpu: int | None = None
-    nvidia_gpu_model: str | None = None
-    amd_gpu_model: str | None = None
-    intel_gpu_model: str | None = None
+    gpu: int | None = None
+    gpu_model: str | None = None


 @dataclass(frozen=True)
@@ -521,9 +517,7 @@ class ResourcePoolType:
 class Resources:
     cpu_m: int
     memory: int
-    nvidia_gpu: int = 0
-    amd_gpu: int = 0
-    intel_gpu: int = 0
+    gpu: int = 0


 @dataclass(frozen=True)
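
To make the entity change concrete, here is a brief, hedged sketch of how a node-pool payload maps onto the slimmed-down options after this commit. Only the fields visible in the hunks above are modelled; the standalone layout and the helper name `parse_node_pool_options` are illustrative assumptions, not the library's public API, and the real `NodePoolOptions` carries additional fields outside the diff context.

```python
# Hedged sketch: models only the fields visible in the diff above.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class NodePoolOptions:
    available_cpu: float
    memory: int
    available_memory: int
    gpu: int | None = None        # generic GPU count (was nvidia_gpu/amd_gpu/intel_gpu)
    gpu_model: str | None = None  # generic GPU model (was *_gpu_model)


def parse_node_pool_options(payload: dict[str, Any]) -> NodePoolOptions:
    # Unknown keys such as "extra_info" are simply ignored, as the test fixture expects.
    return NodePoolOptions(
        available_cpu=payload["available_cpu"],
        memory=payload["memory"],
        available_memory=payload["available_memory"],
        gpu=payload.get("gpu"),
        gpu_model=payload.get("gpu_model"),
    )
```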

neuro_config_client/factories.py (5 additions, 15 deletions)
@@ -130,12 +130,8 @@ def create_node_pool_options(payload: dict[str, Any]) -> NodePoolOptions:
             available_cpu=payload["available_cpu"],
             memory=payload["memory"],
             available_memory=payload["available_memory"],
-            nvidia_gpu=payload.get("nvidia_gpu"),
-            amd_gpu=payload.get("amd_gpu"),
-            intel_gpu=payload.get("intel_gpu"),
-            nvidia_gpu_model=payload.get("nvidia_gpu_model"),
-            amd_gpu_model=payload.get("amd_gpu_model"),
-            intel_gpu_model=payload.get("intel_gpu_model")
+            gpu=payload.get("gpu"),
+            gpu_model=payload.get("gpu_model")
         )

     @classmethod
@@ -336,9 +332,7 @@ def create_resources(self, payload: dict[str, Any]) -> Resources:
         return Resources(
             cpu_m=payload["cpu_m"],
             memory=payload["memory"],
-            nvidia_gpu=payload.get("nvidia_gpu", 0),
-            amd_gpu=payload.get("amd_gpu", 0),
-            intel_gpu=payload.get("intel_gpu", 0)
+            gpu=payload.get("gpu", 0)
         )

     def create_storage(self, payload: dict[str, Any]) -> StorageConfig:
@@ -977,12 +971,8 @@ def _create_idle_job(cls, idle_job: IdleJobConfig) -> dict[str, Any]:
     @classmethod
     def _create_resources(cls, resources: Resources) -> dict[str, Any]:
         result = {"cpu_m": resources.cpu_m, "memory": resources.memory}
-        if resources.nvidia_gpu:
-            result["nvidia_gpu"] = resources.nvidia_gpu
-        if resources.amd_gpu:
-            result["amd_gpu"] = resources.amd_gpu
-        if resources.intel_gpu:
-            result["intel_gpu"] = resources.intel_gpu
+        if resources.gpu:
+            result["gpu"] = resources.gpu
         return result

     @classmethod

tests/test_factories.py (4 additions, 4 deletions)
@@ -1273,8 +1273,8 @@ def node_pool_options_response(self) -> dict[str, Any]:
             "available_cpu": 23,
             "memory": 458752,
             "available_memory": 452608,
-            "nvidia_gpu": 4,
-            "nvidia_gpu_model": "nvidia-tesla-p40",
+            "gpu": 4,
+            "gpu_model": "nvidia-tesla-p40",
             "extra_info": "will be ignored",
         }

@@ -1287,8 +1287,8 @@ def node_pool_options(self) -> NodePoolOptions:
             available_cpu=23,
             memory=458752,
             available_memory=452608,
-            nvidia_gpu=4,
-            nvidia_gpu_model="nvidia-tesla-p40",
+            gpu=4,
+            gpu_model="nvidia-tesla-p40",
         )

     def test_aws_cloud_provider_options(
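
For readers skimming the diff, a minimal, self-contained sketch of the round trip these hunks converge on. The `Resources` fields and the parse/serialize logic mirror the hunks above; the standalone module layout and the name `resources_to_payload` are illustrative stand-ins (the library's own entry points are `create_resources` and `_create_resources` shown in factories.py).

```python
# Minimal sketch of the post-commit behaviour, under assumed standalone layout.
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Resources:
    cpu_m: int
    memory: int
    gpu: int = 0  # single generic count, replacing nvidia_gpu/amd_gpu/intel_gpu


def create_resources(payload: dict[str, Any]) -> Resources:
    # "gpu" is optional in the payload and defaults to 0, as in the factory hunk above.
    return Resources(
        cpu_m=payload["cpu_m"],
        memory=payload["memory"],
        gpu=payload.get("gpu", 0),
    )


def resources_to_payload(resources: Resources) -> dict[str, Any]:
    # "gpu" is emitted only when non-zero, matching _create_resources above.
    result: dict[str, Any] = {"cpu_m": resources.cpu_m, "memory": resources.memory}
    if resources.gpu:
        result["gpu"] = resources.gpu
    return result


if __name__ == "__main__":
    payload = {"cpu_m": 1000, "memory": 4096, "gpu": 1}
    assert resources_to_payload(create_resources(payload)) == payload
    # A payload without "gpu" parses to gpu=0 and serializes without the key.
    assert resources_to_payload(create_resources({"cpu_m": 500, "memory": 2048})) == {
        "cpu_m": 500,
        "memory": 2048,
    }
```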
