diff --git a/cloud/envs/gcp.py b/cloud/envs/gcp.py index b114bc6..a04bff2 100644 --- a/cloud/envs/gcp.py +++ b/cloud/envs/gcp.py @@ -305,7 +305,7 @@ def add(self, *args, **kwargs): return tpu return super().add(*args, **kwargs) - def get(self, preemptible=True, name=None, version='v3-8', zone=None): + def get(self, mode='preemptible', name=None, version='v3-8', zone=None, tpu_type='tpu-node'): tpu = None assert re.match(r"v\d-\d+", version) for tpu in self.resources: @@ -318,20 +318,24 @@ def get(self, preemptible=True, name=None, version='v3-8', zone=None): break else: logger.debug("creating tpu") - tpu = self.up(preemptible=preemptible, name=name, version=version, zone=zone) + tpu = self.up(mode=mode, name=name, version=version, zone=zone, tpu_type=tpu_type) tpu.in_use() return tpu - def _up(self, name, ip, preemptible, version, zone, background): + def _up(self, name, ip, mode, version, zone, background, tpu_type): logger.info("Trying to acquire TPU with name: {} ip: {}".format(name, ip)) - cmd = [ - "gcloud", "compute", "tpus", "create", name, "--range={}".format(ip), - "--accelerator-type={}".format(version), "--version={}".format(self.tf_version), "--network=default" - ] + assert tpu_type in ['tpu-node', 'tpu-vm'], f"TPU type must be one of tpu-node or tpu-vm, got {tpu_type}" + assert mode in ["on_demand", "preemptible", "reserved"], f"TPU mode must be one of on_demand, reserved preemptible - got {mode}" + if tpu_type == 'tpu-vm': + command_insertion = ["gcloud", "alpha", "compute", "tpus", "tpu-vm"] + else: + command_insertion = ["gcloud", "compute", "tpus"] + cmd = command_insertion + ["create", name, "--range={}".format(ip), + "--accelerator-type={}".format(version), "--version={}".format(self.tf_version), "--network=default"] if zone: cmd += ["--zone={}".format(zone)] - if preemptible: - cmd += ["--preemptible"] + if mode != "on_demand" : + cmd += [f"--{mode}"] if background: cmd += ["--async"] @@ -341,17 +345,18 @@ def _up(self, name, ip, preemptible, 
version, zone, background): raise Exception("Failed to create TPU with name: {} ip: {} error: \n{}".format(name, ip, err)) - def up(self, preemptible=True, background=False, attempts=5, name=None, version='v3-8', zone=None): + def up(self, mode='preemptible', background=False, attempts=5, name=None, version='v3-8', zone=None, tpu_type='tpu-node'): if not name: name = self._new_name() for i in range(attempts): try: tpu = self._up(name, self._new_ip(), - preemptible=preemptible, + mode=mode, version=version, zone=zone, - background=background) + background=background, + tpu_type=tpu_type) tpu.manager = self self.resources.append(tpu) return tpu