Skip to content

Commit

Permalink
[pycaffe] pep8 coord map
Browse files Browse the repository at this point in the history
  • Loading branch information
shelhamer committed Feb 28, 2016
1 parent dd1c9b5 commit 2ad9995
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions python/caffe/coord_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,28 @@
from caffe import layers as L

PASS_THROUGH_LAYERS = ['AbsVal', 'BatchNorm', 'Bias', 'BNLL', 'Dropout',
'Eltwise', 'ELU', 'Log', 'LRN', 'Exp', 'MVN', 'Power', 'ReLU', 'PReLU',
'Scale', 'Sigmoid', 'Split', 'TanH', 'Threshold']
'Eltwise', 'ELU', 'Log', 'LRN', 'Exp', 'MVN', 'Power',
'ReLU', 'PReLU', 'Scale', 'Sigmoid', 'Split', 'TanH',
'Threshold']


def conv_params(fn):
params = fn.params.get('convolution_param', fn.params)
axis = params.get('axis', 1)
ks = np.array(params['kernel_size'], ndmin=1)
dilation = np.array(params.get('dilation', 1), ndmin=1)
assert len({'pad_h', 'pad_w', 'kernel_h', 'kernel_w', 'stride_h',
'stride_w'} & set(fn.params)) == 0, \
'cropping does not support legacy _h/_w params'
'stride_w'} & set(fn.params)) == 0, \
'cropping does not support legacy _h/_w params'
return (axis, np.array(params.get('stride', 1), ndmin=1),
(ks - 1) * dilation + 1,
np.array(params.get('pad', 0), ndmin=1))


class UndefinedMapException(Exception):
pass


def coord_map(fn):
if fn.type_name in ['Convolution', 'Pooling', 'Im2col']:
axis, stride, ks, pad = conv_params(fn)
Expand All @@ -36,9 +40,11 @@ def coord_map(fn):
else:
raise UndefinedMapException


class AxisMismatchException(Exception):
pass


def compose((ax1, a1, b1), (ax2, a2, b2)):
if ax1 is None:
ax = ax2
Expand All @@ -48,9 +54,11 @@ def compose((ax1, a1, b1), (ax2, a2, b2)):
raise AxisMismatchException
return ax, a1 * a2, a1 * b2 + b1


def inverse((ax, a, b)):
return ax, 1 / a, -b / a


def coord_map_from_to(top_from, top_to):
# We need to find a common ancestor of top_from and top_to.
# We'll assume that all ancestors are equivalent here (otherwise the graph
Expand Down Expand Up @@ -84,12 +92,16 @@ def coord_map_from_to(top_from, top_to):
continue

# if we got here, we did not find a blob in common
raise RuntimeError, 'Could not compute map between tops; are they connected ' \
'by spatial layers?'
raise RuntimeError('Could not compute map between tops; are they '
'connected by spatial layers?')


def crop(top_from, top_to):
ax, a, b = coord_map_from_to(top_from, top_to)
assert (a == 1).all(), 'scale mismatch on crop (a = {})'.format(a)
assert (b <= 0).all(), 'cannot crop negative width (b = {})'.format(b)
assert (np.round(b) == b).all(), 'cannot crop noninteger width (b = {})'.format(b)
return L.Crop(top_from, top_to, crop_param=dict(axis=ax, crop=list(-np.round(b).astype(int))))
assert (np.round(b) == b).all(),
'cannot crop noninteger width (b = {})'.format(b)
return L.Crop(top_from, top_to,
crop_param=dict(axis=ax,
crop=list(-np.round(b).astype(int))))

0 comments on commit 2ad9995

Please sign in to comment.