Skip to content

Commit

Permalink
use is_tensor instead == (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
mishinma authored and arnaudvl committed Dec 27, 2019
1 parent 31fb0b8 commit 08267bd
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions alibi_detect/utils/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def state_aegmm(od: OutlierAEGMM) -> Dict:
od
Outlier detector object.
"""
if None in [od.phi, od.mu, od.cov, od.L, od.log_det_cov]:
if not all(tf.is_tensor(_) for _ in [od.phi, od.mu, od.cov, od.L, od.log_det_cov]):
logger.warning('Saving AEGMM detector that has not been fit.')

state_dict = {'threshold': od.threshold,
Expand All @@ -187,7 +187,7 @@ def state_vaegmm(od: OutlierVAEGMM) -> Dict:
od
Outlier detector object.
"""
if None in [od.phi, od.mu, od.cov, od.L, od.log_det_cov]:
if not all(tf.is_tensor(_) for _ in [od.phi, od.mu, od.cov, od.L, od.log_det_cov]):
logger.warning('Saving VAEGMM detector that has not been fit.')

state_dict = {'threshold': od.threshold,
Expand Down Expand Up @@ -608,7 +608,7 @@ def init_od_aegmm(state_dict: Dict,
od.L = state_dict['L']
od.log_det_cov = state_dict['log_det_cov']

if None in [od.phi, od.mu, od.cov, od.L, od.log_det_cov]:
if not all(tf.is_tensor(_) for _ in [od.phi, od.mu, od.cov, od.L, od.log_det_cov]):
logger.warning('Loaded AEGMM detector has not been fit.')

return od
Expand Down Expand Up @@ -639,7 +639,7 @@ def init_od_vaegmm(state_dict: Dict,
od.L = state_dict['L']
od.log_det_cov = state_dict['log_det_cov']

if None in [od.phi, od.mu, od.cov, od.L, od.log_det_cov]:
if not all(tf.is_tensor(_) for _ in [od.phi, od.mu, od.cov, od.L, od.log_det_cov]):
logger.warning('Loaded VAEGMM detector has not been fit.')

return od
Expand Down

0 comments on commit 08267bd

Please sign in to comment.