Skip to content

Commit

Permalink
support newline setting in cloud_io
Browse files Browse the repository at this point in the history
  • Loading branch information
f4hy committed Aug 7, 2020
1 parent c77816c commit 4d9d028
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
26 changes: 10 additions & 16 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
return {}

with cloud_open(tags_csv, "rb") as fp:
csv_reader = csv.reader(fp.read().decode("unicode_escape"), delimiter=",")
with cloud_open(tags_csv, "r") as fp:
csv_reader = csv.reader(fp.read(), delimiter=",")
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

return tags
Expand All @@ -291,15 +291,12 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) ->
if isinstance(hparams, Namespace):
hparams = vars(hparams)

# write to a buffer first since cloud_open doesn't support the newline setting
strbuffer = io.StringIO(newline="")
fieldnames = ["key", "value"]
writer = csv.DictWriter(strbuffer, fieldnames=fieldnames)
writer.writerow({"key": "key", "value": "value"})
for k, v in hparams.items():
writer.writerow({"key": k, "value": v})
with cloud_open(tags_csv, "wb") as fp:
fp.write(strbuffer.getvalue().encode("unicode_escape"))
with cloud_open(tags_csv, "w", newline="") as fp:
fieldnames = ["key", "value"]
writer = csv.DictWriter(fp, fieldnames=fieldnames)
writer.writerow({"key": "key", "value": "value"})
for k, v in hparams.items():
writer.writerow({"key": k, "value": v})


def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -345,11 +342,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
hparams = dict(hparams)
assert isinstance(hparams, dict)

# cloud_open doesnt support newline settings so write to a buffer first
strbuffer = io.StringIO(newline="")
yaml.dump(hparams, strbuffer)
with cloud_open(config_yaml, "w") as fp:
fp.write(strbuffer.getvalue())
with cloud_open(config_yaml, "w", newline="") as fp:
yaml.dump(hparams, fp)


def convert(val: str) -> Union[int, float, bool, str]:
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ def load(path_or_url: str, map_location=None):
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)


def cloud_open(path: pathlike, mode: str):
def cloud_open(path: pathlike, mode: str, newline:str = None):
if not modern_gfile or sys.platform == "win32":
log.debug(
"tenosrboard.compat gfile does not work on older versions "
"of tensorboard normal local file open."
)
return open(path, mode)
return open(path, mode, newline=newline)
if sys.platform == "win32":
log.debug(
"gfile does not handle newlines correctly on windows so remote files are not"
"supported falling back to normal local file open."
)
return open(path, mode)
return open(path, mode, newline=newline)
try:
return gfile.GFile(path, mode)
except NotImplementedError as e:
Expand Down

0 comments on commit 4d9d028

Please sign in to comment.