diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index e54c3406ff1b7f..e11dceee52e199 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -277,8 +277,8 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning) return {} - with cloud_open(tags_csv, "rb") as fp: - csv_reader = csv.reader(fp.read().decode("unicode_escape"), delimiter=",") + with cloud_open(tags_csv, "r") as fp: + csv_reader = csv.reader(fp.read(), delimiter=",") tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]} return tags @@ -291,15 +291,12 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> if isinstance(hparams, Namespace): hparams = vars(hparams) - # write to a buffer first since cloud_open doesn't support the newline setting - strbuffer = io.StringIO(newline="") - fieldnames = ["key", "value"] - writer = csv.DictWriter(strbuffer, fieldnames=fieldnames) - writer.writerow({"key": "key", "value": "value"}) - for k, v in hparams.items(): - writer.writerow({"key": k, "value": v}) - with cloud_open(tags_csv, "wb") as fp: - fp.write(strbuffer.getvalue().encode("unicode_escape")) + with cloud_open(tags_csv, "w", newline="") as fp: + fieldnames = ["key", "value"] + writer = csv.DictWriter(fp, fieldnames=fieldnames) + writer.writerow({"key": "key", "value": "value"}) + for k, v in hparams.items(): + writer.writerow({"key": k, "value": v}) def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: @@ -345,11 +342,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: hparams = dict(hparams) assert isinstance(hparams, dict) - # cloud_open doesnt support newline settings so write to a buffer first - strbuffer = io.StringIO(newline="") - yaml.dump(hparams, strbuffer) - with cloud_open(config_yaml, "w") as fp: - fp.write(strbuffer.getvalue()) + with cloud_open(config_yaml, "w", newline="") as fp: + yaml.dump(hparams, fp) def convert(val: str) -> Union[int, float, bool, str]: diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 03634138cdc28a..7329213b20d4fe 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -29,19 +29,19 @@ def load(path_or_url: str, map_location=None): return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location) -def cloud_open(path: pathlike, mode: str): +def cloud_open(path: pathlike, mode: str, newline:str = None): if not modern_gfile or sys.platform == "win32": log.debug( "tenosrboard.compat gfile does not work on older versions " "of tensorboard normal local file open." ) - return open(path, mode) + return open(path, mode, newline=newline) if sys.platform == "win32": log.debug( "gfile does not handle newlines correctly on windows so remote files are not" "supported falling back to normal local file open." ) - return open(path, mode) + return open(path, mode, newline=newline) try: return gfile.GFile(path, mode) except NotImplementedError as e: