Removed dependency on pandas, instead use generic csv (#736)
* removed dependency on pandas, instead use generic csv

* remove mnist files, pushed by accident

* added docstring and small fixes

* Update memory.py

* fixed path

Co-authored-by: William Falcon <waf2107@columbia.edu>
SkafteNicki and williamFalcon committed Jan 29, 2020
1 parent deffbab commit 9a6838d
Showing 4 changed files with 70 additions and 30 deletions.
15 changes: 8 additions & 7 deletions pytorch_lightning/core/lightning.py
@@ -4,12 +4,11 @@
 import warnings
 from abc import ABC, abstractmethod
 from argparse import Namespace
+import csv
 
 
-import pandas as pd
 import torch
 import torch.distributed as dist
-#
 
 from pytorch_lightning.core.decorators import data_loader
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import ModelHooks
@@ -1217,10 +1216,12 @@ def load_hparams_from_tags_csv(tags_csv):
         logging.warning(f'Missing Tags: {tags_csv}.')
         return Namespace()
 
-    tags_df = pd.read_csv(tags_csv)
-    dic = tags_df.to_dict(orient='records')
-    ns_dict = {row['key']: convert(row['value']) for row in dic}
-    ns = Namespace(**ns_dict)
+    tags = {}
+    with open(tags_csv) as f:
+        csv_reader = csv.reader(f, delimiter=',')
+        for row in list(csv_reader)[1:]:
+            tags[row[0]] = convert(row[1])
+    ns = Namespace(**tags)
     return ns

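For context, a minimal sketch of the new parsing logic in isolation. The file name and hyperparameter values below are illustrative, and the real function additionally passes each value through the existing convert() helper to restore int/float types.

import csv
from argparse import Namespace

# Write an illustrative meta_tags.csv: a 'key,value' header row followed by
# one hyperparameter per row (values are made up for this example).
with open('meta_tags.csv', 'w', newline='') as f:
    f.write('key,value\n'
            'batch_size,32\n'
            'learning_rate,0.02\n')

# Parse it the same way the rewritten load_hparams_from_tags_csv does:
# skip the header row, then map each key to its value.
tags = {}
with open('meta_tags.csv') as f:
    for row in list(csv.reader(f, delimiter=','))[1:]:
        tags[row[0]] = row[1]

print(Namespace(**tags))
# Namespace(batch_size='32', learning_rate='0.02')
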
68 changes: 51 additions & 17 deletions pytorch_lightning/core/memory.py
@@ -9,7 +9,6 @@
 from subprocess import PIPE
 
 import numpy as np
-import pandas as pd
 import torch
 
 
@@ -146,24 +145,14 @@ def make_summary(self):
         Layer Name, Layer Type, Input Size, Output Size, Number of Parameters
         '''
 
-        cols = ['Name', 'Type', 'Params']
-        if self.model.example_input_array is not None:
-            cols.extend(['In_sizes', 'Out_sizes'])
-
-        df = pd.DataFrame(np.zeros((len(self.layer_names), len(cols))))
-        df.columns = cols
-
-        df['Name'] = self.layer_names
-        df['Type'] = self.layer_types
-        df['Params'] = self.param_nums
-        df['Params'] = df['Params'].map(get_human_readable_count)
-
+        arrays = [['Name', self.layer_names],
+                  ['Type', self.layer_types],
+                  ['Params', list(map(get_human_readable_count, self.param_nums))]]
         if self.model.example_input_array is not None:
-            df['In_sizes'] = self.in_sizes
-            df['Out_sizes'] = self.out_sizes
+            arrays.append(['In sizes', self.in_sizes])
+            arrays.append(['Out sizes', self.out_sizes])
 
-        self.summary = df
+        self.summary = _format_summary_table(*arrays)
         return
 
     def summarize(self):
@@ -176,6 +165,51 @@ def summarize(self):
         self.make_summary()
 
 
+def _format_summary_table(*cols):
+    '''
+    Takes in a number of arrays, each specifying a column in
+    the summary table, and combines them all into one big
+    string defining the summary table that is nicely formatted.
+    '''
+    n_rows = len(cols[0][1])
+    n_cols = 1 + len(cols)
+
+    # Layer counter
+    counter = list(map(str, list(range(n_rows))))
+    counter_len = max([len(c) for c in counter])
+
+    # Get formatting length of each column
+    length = []
+    for c in cols:
+        str_l = len(c[0])  # default length is header length
+        for a in c[1]:
+            if isinstance(a, np.ndarray):
+                array_string = '[' + ', '.join([str(j) for j in a]) + ']'
+                str_l = max(len(array_string), str_l)
+            else:
+                str_l = max(len(a), str_l)
+        length.append(str_l)
+
+    # Formatting
+    s = '{:<{}}'
+    full_length = sum(length) + 3 * n_cols
+    header = [s.format(' ', counter_len)] + [s.format(c[0], l) for c, l in zip(cols, length)]
+
+    # Summary = header + divider + rest of the table
+    summary = ' | '.join(header) + '\n' + '-' * full_length
+    for i in range(n_rows):
+        line = s.format(counter[i], counter_len)
+        for c, l in zip(cols, length):
+            if isinstance(c[1][i], np.ndarray):
+                array_string = '[' + ', '.join([str(j) for j in c[1][i]]) + ']'
+                line += ' | ' + array_string + ' ' * (l - len(array_string))
+            else:
+                line += ' | ' + s.format(c[1][i], l)
+        summary += '\n' + line
+
+    return summary
+
+
 def print_mem_stack():  # pragma: no cover
     for obj in gc.get_objects():
         try:
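To illustrate what the new helper renders, a small sketch with made-up layer data; it assumes _format_summary_table is importable from pytorch_lightning.core.memory as added above.

import numpy as np
from pytorch_lightning.core.memory import _format_summary_table

# Toy columns in the [header, values] layout that make_summary() now builds;
# the layer names, types, and parameter counts are invented for the example.
cols = [
    ['Name', ['encoder', 'head']],
    ['Type', ['Linear', 'ReLU']],
    ['Params', ['25 K', '0']],
    ['In sizes', [np.array([1, 64]), np.array([1, 512])]],
]

print(_format_summary_table(*cols))
# prints roughly:
#   | Name    | Type   | Params | In sizes
# ------------------------------------------
# 0 | encoder | Linear | 25 K   | [1, 64]
# 1 | head    | ReLU   | 0      | [1, 512]
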
13 changes: 9 additions & 4 deletions pytorch_lightning/logging/tensorboard.py
@@ -4,7 +4,7 @@
 from pkg_resources import parse_version
 
 import torch
-import pandas as pd
+import csv
 from torch.utils.tensorboard import SummaryWriter
 
 from .base import LightningLoggerBase, rank_zero_only
@@ -108,12 +108,17 @@ def save(self):
         dir_path = os.path.join(self.save_dir, self.name, 'version_%s' % self.version)
         if not os.path.isdir(dir_path):
             dir_path = self.save_dir
 
         # prepare the file path
         meta_tags_path = os.path.join(dir_path, self.NAME_CSV_TAGS)
 
         # save the metatags file
-        df = pd.DataFrame({'key': list(self.tags.keys()),
-                           'value': list(self.tags.values())})
-        df.to_csv(meta_tags_path, index=False)
+        with open(meta_tags_path, 'w', newline='') as csvfile:
+            fieldnames = ['key', 'value']
+            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+            writer.writerow({'key': 'key', 'value': 'value'})
+            for k, v in self.tags.items():
+                writer.writerow({'key': k, 'value': v})
 
     @rank_zero_only
     def finalize(self, status):
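One note on the hand-written header row above: writer.writerow({'key': 'key', 'value': 'value'}) produces the same 'key,value' line that csv.DictWriter.writeheader() would. A standalone sketch of the file this save() now emits; the tag names, values, and output path are illustrative.

import csv

tags = {'batch_size': 32, 'drop_prob': 0.2}  # made-up tags for the example

with open('meta_tags.csv', 'w', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=['key', 'value'])
    writer.writeheader()  # writes the 'key,value' header line
    for k, v in tags.items():
        writer.writerow({'key': k, 'value': v})

# meta_tags.csv now contains:
#   key,value
#   batch_size,32
#   drop_prob,0.2
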
4 changes: 2 additions & 2 deletions requirements.txt
@@ -3,6 +3,6 @@ tqdm>=4.35.0
 numpy>=1.16.4
 torch>=1.1
 torchvision>=0.4.0, < 0.5 # the 0.5. has some issues with torch JIT
-pandas>=0.24 # lower version do not support py3.7
 tensorboard>=1.14
-future>=0.17.1 # required for builtins in setup.py
+future>=0.17.1 # required for builtins in setup.py
