Skip to content

Commit

Permalink
Merge pull request #70 from dmarx/bounce
Browse files Browse the repository at this point in the history
add "bounce" looping mode
  • Loading branch information
dmarx authored Feb 14, 2023
2 parents 60d42de + 5a15e8c commit 77964b0
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

setup(
name='keyframed',
version='0.3.5',
version='0.3.6',
author='David Marx',
long_description=README,
long_description_content_type='text/markdown',
Expand Down
30 changes: 26 additions & 4 deletions src/keyframed/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,20 @@ def duration(self) -> Number:
pass

@abstractmethod
def __getitem__(self) -> Number:
def __getitem__(self, k) -> Number:
pass

def _adjust_k_for_looping(self, k:Number) -> Number:
n = (self.duration + 1)
if self.loop and k >= max(self.keyframes):
k %= n
elif self.bounce:
n2 = 2*(n-1)
k %= n2
if k >= n:
k = n2 - k
return k

def plot(self, n:int=None, xs:list=None, eps:float=1e-9, *args, **kargs):
"""
Arguments
Expand Down Expand Up @@ -196,6 +207,7 @@ def __init__(self,
] = ((0,0),),
default_interpolation='previous',
loop: bool = False,
bounce: bool = False,
duration:Optional[float]=None,
label:str=None,
):
Expand All @@ -214,6 +226,7 @@ def __init__(self,

self.default_interpolation=default_interpolation
self.loop=loop
self.bounce=bounce
self._duration=duration
if label is None:
label = self.random_label()
Expand Down Expand Up @@ -268,7 +281,7 @@ def __get_slice(self, k:slice):
#loop = self.loop if end# to do: revisit the logic here
loop = False # let's just keep it like this for simplicity. if someone wants a slice output to loop, they can be explicit
return Curve(curve=d, loop=loop, duration=end)

def __getitem__(self, k:Number) -> Number:
"""
Under the hood, the values in our SortedDict should all be Keyframe objects,
Expand All @@ -277,8 +290,8 @@ def __getitem__(self, k:Number) -> Number:
if isinstance(k, slice):
return self.__get_slice(k)

if self.loop and k >= max(self.keyframes):
k %= (self.duration + 1)
k = self._adjust_k_for_looping(k)

if k in self._data.keys():
outv = self._data[k]
if isinstance(outv, Keyframe):
Expand Down Expand Up @@ -416,7 +429,10 @@ def __init__(
parameters:Union[Dict[str, Curve],'ParameterGroup', list, tuple],
weight:Optional[Union[Curve,Number]]=1,
label=None,
loop: bool = False,
bounce: bool = False,
):
self.loop, self.bounce = loop, bounce
if isinstance(parameters, list) or isinstance(parameters, tuple):
d = {}
for curve in parameters:
Expand Down Expand Up @@ -464,6 +480,7 @@ def __get_slice(self, k) -> 'ParameterGroup':
def __getitem__(self, k) -> dict:
if isinstance(k, slice):
return self.__get_slice(k)
k = self._adjust_k_for_looping(k)
wt = self.weight[k]
d = {name:param[k]*wt for name, param in self.parameters.items() }
return DictValuesArithmeticFriendly(d)
Expand Down Expand Up @@ -587,7 +604,10 @@ def __init__(
weight:Optional[Union[Curve,Number]]=1,
reduction:str=None,
label:str=None,
loop:bool=False,
bounce:bool=False,
):
self.loop, self.bounce = loop, bounce
self.reduction = reduction
super().__init__(parameters=parameters, weight=weight, label=label)
# uh.... let's try this I guess?
Expand All @@ -596,6 +616,8 @@ def __init__(
self.label = self.random_label()

def __getitem__(self, k) -> Union[Number,dict]:

k = self._adjust_k_for_looping(k)
f = REDUCTIONS.get(self.reduction)

vals = [curve[k] for curve in self.parameters.values()]
Expand Down
31 changes: 31 additions & 0 deletions tests/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,34 @@ def test_fancy_dict():
assert 2/d == {'a':-1/2,'b':0,'c':1/2}
d.pop('b')
assert 2/d == {'a':-2,'c':2}

###############


def test_comp_of_loops():
curve = Curve([(0,0),(2,2)], loop=True, default_interpolation='linear')
comp = curve *2
for i in range(10):
assert comp[i] == curve[i] * 2

# def test_comp_loop():
# curve = Curve([(0,0),(2,2)], default_interpolation='linear')
# curve_loop = curve.copy()
# curve_loop.loop = True
# pgroup = ParameterGroup({'p1': curve}, loop=True)
# for i in range(10):
# assert pgroup[i]['p1'] == curve_loop[i]

# def test_comp_bounce():
# curve = Curve([(0,0),(2,2)], default_interpolation='linear')
# curve_loop = curve.copy()
# curve_loop.bounce = True
# pgroup = ParameterGroup({'p1': curve}, bounce=True)
# for i in range(10):
# assert pgroup[i]['p1'] == curve_loop[i]

# def test_comp_of_bounces():
# curve = Curve([(0,0),(2,2)], bounce=True, default_interpolation='linear')
# pgroup = ParameterGroup({'p1': curve})
# for i in range(10):
# assert pgroup[i]['p1'] == curve[i]
10 changes: 10 additions & 0 deletions tests/test_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def test_curve_looping():
assert curve[15] == 0
assert curve[19] == 9

def test_curve_bounce():
curve = Curve(((0, 0), (9, 9)), bounce=True, default_interpolation='linear')
for i in range(20):
print(f"{i}:{curve[i]}")
assert curve[0] == 0
assert curve[8] == 8
assert curve[9] == 9
assert curve[10] == 8
assert curve[18] == 0
assert curve[20] == 2
#########################

# scavenged from test_callable_patterns.py
Expand Down
28 changes: 28 additions & 0 deletions tests/test_pgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,31 @@ def test_pgroup_nontrivial():
assert pgroup.weight[1] == 2
assert pgroup[0] == {'p1': 2, 'p2': 4}
assert pgroup[1] == {'p1': 2, 'p2': 4}

def test_pgroup_of_loops():
curve = Curve([(0,0),(2,2)], loop=True, default_interpolation='linear')
pgroup = ParameterGroup({'p1': curve})
for i in range(10):
assert pgroup[i]['p1'] == curve[i]

def test_pgroup_loop():
curve = Curve([(0,0),(2,2)], default_interpolation='linear')
curve_loop = curve.copy()
curve_loop.loop = True
pgroup = ParameterGroup({'p1': curve}, loop=True)
for i in range(10):
assert pgroup[i]['p1'] == curve_loop[i]

def test_pgroup_bounce():
curve = Curve([(0,0),(2,2)], default_interpolation='linear')
curve_loop = curve.copy()
curve_loop.bounce = True
pgroup = ParameterGroup({'p1': curve}, bounce=True)
for i in range(10):
assert pgroup[i]['p1'] == curve_loop[i]

def test_pgroup_of_bounces():
curve = Curve([(0,0),(2,2)], bounce=True, default_interpolation='linear')
pgroup = ParameterGroup({'p1': curve})
for i in range(10):
assert pgroup[i]['p1'] == curve[i]

0 comments on commit 77964b0

Please sign in to comment.