Skip to content

Commit

Permalink
refactor softness and hardness on softmax and softmin methods - s…
Browse files Browse the repository at this point in the history
…houldn't break API as these have always needed to be kwargs
  • Loading branch information
peterdsharpe committed Mar 21, 2024
1 parent 749d96c commit 0e12d55
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions aerosandbox/numpy/surrogate_model_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

def softmax(
*args: Union[float, _np.ndarray],
softness: float = 1.,
hardness: float = None,
softness: float = None,
) -> Union[float, _np.ndarray]:
"""
An element-wise softmax between two or more arrays. Also referred to as the logsumexp() function.
Expand Down Expand Up @@ -35,16 +35,12 @@ def softmax(
Soft maximum of the supplied values.
"""
### Set defaults for hardness/softness
if not (hardness is None) ^ (softness is None):
raise ValueError("You must provide exactly one of `hardness` or `softness`.")
if hardness is not None:
if softness is not None:
raise ValueError("You can't specify both `hardness` and `softness`.")
else:
if softness is not None:
hardness = 1 / softness
else:
hardness = 1.0
softness = 1 / hardness

if _np.any(hardness <= 0):
if _np.any(softness <= 0):
if softness is not None:
raise ValueError("The value of `softness` must be positive.")
else:
Expand All @@ -54,8 +50,8 @@ def softmax(
raise ValueError("You must call softmax with the value of two or more arrays that you'd like to take the "
"element-wise softmax of.")

### Scale the args by hardness
args = [arg * hardness for arg in args]
### Scale the args by softness
args = [arg / softness for arg in args]

### Find the element-wise max and min of the arrays:
min = args[0]
Expand All @@ -68,14 +64,14 @@ def softmax(
[_np.exp(_np.maximum(array - max, -500)) for array in args]
)
)
out = out / hardness
out = out * softness
return out


def softmin(
*args: Union[float, _np.ndarray],
softness: float = 1.,
hardness: float = None,
softness: float = None,
) -> Union[float, _np.ndarray]:
"""
An element-wise softmin between two or more arrays. Related to the logsumexp() function.
Expand Down Expand Up @@ -106,8 +102,8 @@ def softmin(
"""
return -softmax(
*[-arg for arg in args],
softness=softness,
hardness=hardness,
softness=softness
)


Expand Down

0 comments on commit 0e12d55

Please sign in to comment.