Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve usage of symplot.SliderKwargs #109

Merged
merged 3 commits into from
Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/usage/dynamics/k-matrix.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@
" # Concatenate flipped domain for reverse animation\n",
" domain = np.linspace(0.5, 3.0, 50)\n",
" domain = np.concatenate((domain, np.flip(domain[1:])))\n",
" sliders._SliderKwargs__sliders[\"m1\"] = domain"
" sliders._sliders[\"m1\"] = domain"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/usage/interactive.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@
" # Concatenate flipped domain for reverse animation\n",
" domain = np.linspace(0.8, 2.2, 100)\n",
" domain = np.concatenate((domain, np.flip(domain[1:])))\n",
" sliders._SliderKwargs__sliders[\"m_f(0)(980)\"] = domain # dirty hack"
" sliders._sliders[\"m_f(0)(980)\"] = domain"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/usage/symplot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
" # Concatenate flipped domain for reverse animation\n",
" domain = np.linspace(-1, 1, 50)\n",
" domain = np.concatenate((domain, np.flip(domain[1:])))\n",
" sliders._SliderKwargs__sliders[\"a\"] = domain # dirty hack"
" sliders._sliders[\"a\"] = domain"
]
},
{
Expand Down
54 changes: 31 additions & 23 deletions src/symplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Slider = Union[FloatSlider, IntSlider]
RangeDefinition = Union[
Tuple[float, float],
Tuple[float, float, int],
Tuple[float, float, Union[float, int]],
]


Expand All @@ -52,11 +52,11 @@ def __init__(
arg_to_symbol: Mapping[str, str],
) -> None:
self._verify_arguments(sliders, arg_to_symbol)
self.__sliders = dict(sliders)
self.__arg_to_symbol = {
self._sliders = dict(sliders)
self._arg_to_symbol = {
arg: symbol
for arg, symbol in arg_to_symbol.items()
if symbol in self.__sliders
if symbol in self._sliders
}

@staticmethod
Expand Down Expand Up @@ -90,30 +90,30 @@ def __getitem__(self, key: Union[str, sp.Symbol]) -> Slider:
"""Get slider by symbol, symbol name, or argument name."""
if isinstance(key, sp.Symbol):
key = key.name
if key in self.__arg_to_symbol:
slider_name = self.__arg_to_symbol[key]
if key in self._arg_to_symbol:
slider_name = self._arg_to_symbol[key]
else:
slider_name = key
if slider_name not in self.__sliders:
if slider_name not in self._sliders:
raise KeyError(f'"{key}" is neither an argument nor a symbol name')
return self.__sliders[slider_name]
return self._sliders[slider_name]

def __iter__(self) -> Iterator[str]:
"""Iterate over the arguments of the `.LambdifiedExpression`.

This is useful for unpacking an instance of `SliderKwargs` as
:term:`kwargs <python:keyword argument>`.
"""
return self.__arg_to_symbol.__iter__()
return self._arg_to_symbol.__iter__()

def __len__(self) -> int:
return len(self.__sliders)
return len(self._sliders)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"sliders={self.__sliders}, "
f"arg_to_symbol={self.__arg_to_symbol})"
f"sliders={self._sliders}, "
f"arg_to_symbol={self._arg_to_symbol})"
)

def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None:
Expand All @@ -124,11 +124,11 @@ def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None:
with p.group(indent=2, open=f"{class_name}("):
p.breakable()
p.text("sliders=")
p.pretty(self.__sliders)
p.pretty(self._sliders)
p.text(",")
p.breakable()
p.text("arg_to_symbol=")
p.pretty(self.__arg_to_symbol)
p.pretty(self._arg_to_symbol)
p.text(",")
p.breakable()
p.text(")")
Expand All @@ -151,10 +151,15 @@ def set_values(self, *args: Dict[str, float], **kwargs: float) -> None:
)
continue

def set_ranges(
def set_ranges( # noqa: R701
self, *args: Dict[str, RangeDefinition], **kwargs: RangeDefinition
) -> None:
"""Set min, max and (optionally) the nr of steps for each slider."""
"""Set min, max and (optionally) the nr of steps for each slider.

.. tip::
:code:`n_steps` becomes the step **size** if its value is
`float`.
"""
range_definitions = _merge_args_kwargs(*args, **kwargs)
for slider_name, range_def in range_definitions.items():
if not isinstance(range_def, tuple):
Expand All @@ -169,7 +174,10 @@ def set_ranges(
min_, max_, n_steps = range_def # type: ignore
if n_steps <= 0:
raise ValueError("Number of steps has to be positive")
step_size = (max_ - min_) / n_steps
if isinstance(n_steps, float):
step_size = n_steps
else:
step_size = (max_ - min_) / n_steps
else:
raise ValueError(
f'Range definition {range_def} for slider "{slider_name}"'
Expand Down Expand Up @@ -224,16 +232,16 @@ def prepare_sliders(
SliderKwargs(...)
"""
slider_symbols = _extract_slider_symbols(expression, plot_symbol)
sliders_mapping = {
symbol.name: create_slider(symbol) for symbol in slider_symbols
}
lambdified_expression = sp.lambdify(
(plot_symbol, *slider_symbols),
expression,
"numpy",
modules="numpy",
)
symbols_names = list(map(lambda s: s.name, (plot_symbol, *slider_symbols)))
arg_names = list(inspect.signature(lambdified_expression).parameters)
sliders_mapping = {
symbol.name: create_slider(symbol) for symbol in slider_symbols
}
symbols_names = map(lambda s: s.name, (plot_symbol, *slider_symbols))
arg_names = inspect.signature(lambdified_expression).parameters
arg_to_symbol = dict(zip(arg_names, symbols_names))
sliders = SliderKwargs(sliders_mapping, arg_to_symbol)
return lambdified_expression, sliders
Expand Down
14 changes: 12 additions & 2 deletions tests/symplot/test_symplot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
# pylint: disable=eval-used, no-self-use, protected-access, redefined-outer-name
import logging
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Pattern, Type, no_type_check
from typing import (
Any,
Callable,
Dict,
Optional,
Pattern,
Type,
Union,
no_type_check,
)

import pytest
import sympy as sp
Expand Down Expand Up @@ -93,14 +102,15 @@ def test_repr(
("n", 10, 20, 20, 1),
("alpha", -1.5, 0.7, None, 0.1),
("alpha", 8, 10, 4, 0.5),
("alpha", 8, 10, 0.5, 0.5),
],
)
def test_set_ranges( # pylint: disable=too-many-arguments
self,
slider_name: str,
min_: float,
max_: float,
n_steps: Optional[int],
n_steps: Optional[Union[float, int]],
step_size: float,
slider_kwargs: SliderKwargs,
) -> None:
Expand Down