diff --git a/aerosandbox/numpy/surrogate_model_tools.py b/aerosandbox/numpy/surrogate_model_tools.py index fa145cc3..d09e0862 100644 --- a/aerosandbox/numpy/surrogate_model_tools.py +++ b/aerosandbox/numpy/surrogate_model_tools.py @@ -4,8 +4,8 @@ def softmax( *args: Union[float, _np.ndarray], + softness: float = 1., hardness: float = None, - softness: float = None, ) -> Union[float, _np.ndarray]: """ An element-wise softmax between two or more arrays. Also referred to as the logsumexp() function. @@ -35,16 +35,12 @@ def softmax( Soft maximum of the supplied values. """ ### Set defaults for hardness/softness + if not (hardness is None) ^ (softness is None): + raise ValueError("You must provide exactly one of `hardness` or `softness`.") if hardness is not None: - if softness is not None: - raise ValueError("You can't specify both `hardness` and `softness`.") - else: - if softness is not None: - hardness = 1 / softness - else: - hardness = 1.0 + softness = 1 / hardness - if _np.any(hardness <= 0): + if _np.any(softness <= 0): if softness is not None: raise ValueError("The value of `softness` must be positive.") else: @@ -54,8 +50,8 @@ def softmax( raise ValueError("You must call softmax with the value of two or more arrays that you'd like to take the " "element-wise softmax of.") - ### Scale the args by hardness - args = [arg * hardness for arg in args] + ### Scale the args by softness + args = [arg / softness for arg in args] ### Find the element-wise max and min of the arrays: min = args[0] @@ -68,14 +64,14 @@ def softmax( [_np.exp(_np.maximum(array - max, -500)) for array in args] ) ) - out = out / hardness + out = out * softness return out def softmin( *args: Union[float, _np.ndarray], + softness: float = 1., hardness: float = None, - softness: float = None, ) -> Union[float, _np.ndarray]: """ An element-wise softmin between two or more arrays. Related to the logsumexp() function. @@ -106,8 +102,8 @@ def softmin( """ return -softmax( *[-arg for arg in args], + softness=softness, hardness=hardness, - softness=softness )