
Commit

Merge pull request #15 from tensorflow/gfile
google3 updates
BeenKim committed Oct 4, 2018
2 parents 2ac8c98 + dceb3b6 commit 9efd5e6
Showing 3 changed files with 27 additions and 14 deletions.
2 changes: 1 addition & 1 deletion cav_test.py
@@ -60,7 +60,7 @@ def setUp(self):
if os.path.exists(self.cav_dir):
shutil.rmtree(self.cav_dir)
os.mkdir(self.cav_dir)
with open(self.save_path, 'w') as pkl_file:
with tf.gfile.Open(self.save_path, 'w') as pkl_file:
pickle.dump({
'concepts': self.concepts,
'bottleneck': self.bottleneck,
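The change above swaps the builtin open for tf.gfile.Open, which accepts the same mode strings but also works with non-local filesystems (for example Google Cloud Storage paths). A minimal standalone sketch of the pattern, assuming TensorFlow 1.x (where tf.gfile.Open is available) and a hypothetical local path; binary modes are used here for Python 3, whereas the test above relies on Python 2-era text mode:

import pickle

import tensorflow as tf  # TF 1.x API; tf.gfile.Open wraps file access through gfile

save_path = '/tmp/tcav_example.pkl'  # hypothetical path; could equally be a remote URI

# Write a pickle through gfile so the same code works on local and remote storage.
with tf.gfile.Open(save_path, 'wb') as pkl_file:
  pickle.dump({'concepts': ['striped', 'dotted'], 'bottleneck': 'mixed4c'}, pkl_file)

# Read it back the same way.
with tf.gfile.Open(save_path, 'rb') as pkl_file:
  data = pickle.load(pkl_file)

print(data['bottleneck'])  # mixed4c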
15 changes: 15 additions & 0 deletions model.py
@@ -91,6 +91,21 @@ def adjust_prediction(self, pred_t):
"""
return pred_t

def reshape_activations(self, layer_acts):
"""Reshapes layer activations as needed to feed through the model network.
Override this for models that require reshaping of the activations for use
in TCAV.
Args:
layer_acts: Activations as returned by run_imgs.
Returns:
Activations in model-dependent form; the default is a squeezed array (i.e.
with all dimensions of size 1 removed).
"""
return np.asarray(layer_acts).squeeze()

@abstractmethod
def label_to_id(self, label):
"""Convert label (string) to index in the logit layer (id)."""
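The new reshape_activations hook in model.py lets a subclass control how raw activations from run_imgs are arranged before TCAV consumes them; the default simply squeezes out size-1 dimensions. A hedged sketch of an override, with a made-up class and bottleneck name (only run_imgs and reshape_activations mirror the interface shown above):

import numpy as np

class FlattenedConvModel(object):
  """Illustrative model wrapper that overrides reshape_activations."""

  def run_imgs(self, imgs, bottleneck_name):
    # Stand-in for a real forward pass: pretend the bottleneck is a 7x7x512 conv layer.
    return np.random.rand(len(imgs), 7, 7, 512).astype(np.float32)

  def reshape_activations(self, layer_acts):
    # The default in model.py is np.asarray(layer_acts).squeeze(); this variant
    # instead keeps the batch axis and flattens the spatial/channel axes.
    acts = np.asarray(layer_acts)
    return acts.reshape(acts.shape[0], -1)

model = FlattenedConvModel()
acts = model.reshape_activations(model.run_imgs(imgs=list(range(4)), bottleneck_name='mixed4c'))
print(acts.shape)  # (4, 25088)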
24 changes: 11 additions & 13 deletions tcav_helpers.py
@@ -41,27 +41,27 @@ def make_key(a_dict, key):

def save_np_array(array, path):
"""Save an array as numpy array (loading time is better than pickling."""
with open(path, 'w') as f:
with tf.gfile.Open(path, 'w') as f:
np.save(f, array, allow_pickle=False)


def read_np_array(path):
"""Read a saved numpy array and return."""
with open(path) as f:
with tf.gfile.Open(path) as f:
data = np.load(f)
return data


def read_file(path):
"""Read a file in path ."""
with open(path, 'r') as f:
with tf.gfile.Open(path, 'r') as f:
data = f.read()
return data


def write_file(data, path, mode='w'):
"""Wrtie data to path to cns."""
with open(path, mode) as f:
with tf.gfile.Open(path, mode) as f:
if mode == 'a':
f.write('\n')
f.write(data)
@@ -153,17 +153,14 @@ def load_images_from_files(filenames, max_imgs=500, return_filenames=False,
else:
return np.array(imgs)

""" highe level overview.
""" high level overview.
get_acts_from_images: run images on a model and return activations.
get_imgs_and_acts_save: loads images from the image path, calls
get_acts_from_images to get their activations,
and saves them.
"""




def get_acts_from_images(imgs, model, bottleneck_name):
"""Run images in the model to get the activations.
@@ -175,21 +172,22 @@ def get_acts_from_images(imgs, model, bottleneck_name):
Returns:
numpy array of activations.
"""
return np.asarray(model.run_imgs(imgs, bottleneck_name)).squeeze()
img_acts = model.run_imgs(imgs, bottleneck_name)
return model.reshape_activations(img_acts)


def get_imgs_and_acts_save(model, bottleneck_name, img_paths, acts_path,
img_shape, max_images=500):
"""Get images from files, process acts and saves.
Args:
model: a model instance
bottleneck_name: name of the bottleneck that activations are from
img_paths: where the images live
acts_path: where to store activations
img_shape: shape of the image.
max_images: maximum number of images to save to acts_path
Returns:
success or not.
"""
@@ -198,7 +196,7 @@ def get_imgs_and_acts_save(model, bottleneck_name, img_paths, acts_path,
tf.logging.info('got %s imgs' % (len(imgs)))
acts = get_acts_from_images(imgs, model, bottleneck_name)
tf.logging.info('Writing acts to {}'.format(acts_path))
with open(acts_path, 'w') as f:
with tf.gfile.Open(acts_path, 'w') as f:
np.save(f, acts, allow_pickle=False)
del acts
del imgs
@@ -251,7 +249,7 @@ def process_and_load_activations(model, bottleneck_names, concepts,
max_images=max_images)

if bottleneck_name not in acts[concept].keys():
with open(acts_path) as f:
with tf.gfile.Open(acts_path) as f:
acts[concept][bottleneck_name] = np.load(f).squeeze()
tf.logging.info('Loaded {} shape {}'.format(
acts_path,
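The helper changes above route all activation I/O through tf.gfile.Open as well, and get_acts_from_images now defers reshaping to model.reshape_activations instead of squeezing in place. A small sketch of the numpy save/load round trip used by save_np_array, get_imgs_and_acts_save, and process_and_load_activations, with a hypothetical path and random stand-in activations; again TensorFlow 1.x is assumed, and binary modes are used for Python 3:

import numpy as np
import tensorflow as tf  # TF 1.x

acts_path = '/tmp/acts_concept_mixed4c'  # hypothetical activation cache path
acts = np.random.rand(10, 7, 7, 512).astype(np.float32)  # stand-in activations

# Save, as in save_np_array / get_imgs_and_acts_save.
with tf.gfile.Open(acts_path, 'wb') as f:
  np.save(f, acts, allow_pickle=False)

# Load, as in read_np_array / process_and_load_activations.
with tf.gfile.Open(acts_path, 'rb') as f:
  loaded = np.load(f).squeeze()

print(loaded.shape)  # (10, 7, 7, 512)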
