Skip to content

Commit

Permalink
Merge pull request #89 from dmarx/dev
Browse files Browse the repository at this point in the history
fix compositional pgroup arithmetic
  • Loading branch information
dmarx committed Mar 4, 2023
2 parents 2b916a2 + aa2face commit 64b76c2
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 38 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@

setup(
name='keyframed',
version='0.3.10',
version='0.3.11',
author='David Marx',
long_description=README,
long_description_content_type='text/markdown',
short_description=st,
install_requires=[
'sortedcontainers',
'omegaconf',
'matplotlib',
],
extras_require={
'dev': [
Expand Down
82 changes: 60 additions & 22 deletions src/keyframed/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@ def __neg__(self) -> 'CurveBase':
return self * (-1)

def __eq__(self, other) -> bool:
return self.to_dict() == other.to_dict()
return self.to_dict(simplify=True, ignore_labels=True) == other.to_dict(simplify=True, ignore_labels=True)
@abstractmethod
def to_dict(simplify=False, for_yaml=False):
def to_dict(simplify=False, for_yaml=False, ignore_labels=False):
raise NotImplementedError

class Curve(CurveBase):
Expand Down Expand Up @@ -385,12 +385,9 @@ def __str__(self) -> str:
return f"Curve({d_}"

def __add__(self, other) -> CurveBase:
if isinstance(other, CurveBase):
return self.__add_curves__(other)
outv = self.copy()
for k in self.keyframes:
outv[k]= outv[k] + other
return outv
if not isinstance(other, CurveBase):
other = Curve(other)
return self.__add_curves__(other)

def __add_curves__(self, other) -> 'Composition':
if isinstance(other, ParameterGroup):
Expand Down Expand Up @@ -420,7 +417,7 @@ def __mul_curves__(self, other) -> 'Composition':
def from_function(cls, f:Callable) -> CurveBase:
return cls({0:f(0)}, default_interpolation=lambda k, _: f(k))

def to_dict(self, simplify=False, for_yaml=False):
def to_dict(self, simplify=False, for_yaml=False, ignore_labels=False):

if for_yaml:
d_curve = tuple([kf._to_tuple(simplify=simplify) for k, kf in self._data.items()])
Expand Down Expand Up @@ -478,6 +475,9 @@ def to_dict(self, simplify=False, for_yaml=False):
label=self.label,
)

if ignore_labels and 'label' in outv:
outv.pop('label')

return outv

def append(self, other):
Expand Down Expand Up @@ -545,6 +545,7 @@ def weight(self):
# defining this as a property so we can override the label to
# always match the label of the associated ParameterGroup
self._weight.label = f"{self.label}_WEIGHT"
self._weight._using_default_label = True
return self._weight

def __get_slice(self, k) -> 'ParameterGroup':
Expand Down Expand Up @@ -589,6 +590,9 @@ def __rtruediv__(self, other) -> 'ParameterGroup':
outv.parameters[k] = other / v
return outv

def __eq__(self, other) -> bool:
return self.to_dict(simplify=True, ignore_labels=True)['parameters'] == other.to_dict(simplify=True, ignore_labels=True)['parameters']

@property
def duration(self) -> Number:
return max(curve.duration for curve in self.parameters.values())
Expand Down Expand Up @@ -616,16 +620,19 @@ def values(self) -> list:

def random_label(self) -> str:
return f"pgroup({','.join([c.label for c in self.parameters.values()])})"
def to_dict(self, simplify=False, for_yaml=False):
params = {k:v.to_dict(simplify=simplify, for_yaml=for_yaml) for k,v in self.parameters.items()}
weight = self.weight.to_dict(simplify=simplify, for_yaml=for_yaml)
def to_dict(self, simplify=False, for_yaml=False, ignore_labels=False):
params = {k:v.to_dict(simplify=simplify, for_yaml=for_yaml, ignore_labels=ignore_labels) for k,v in self.parameters.items()}
weight = self.weight.to_dict(simplify=simplify, for_yaml=for_yaml, ignore_labels=ignore_labels)

if not simplify:
return dict(
parameters=params,
weight=weight,
label=self.label,
)
outv= dict(
parameters=params,
weight=weight,
#label=self.label,
)
if not ignore_labels:
outv['label'] = self.label
return outv
else:
for k in list(params.keys()):
if 'label' in params[k]:
Expand All @@ -635,10 +642,12 @@ def to_dict(self, simplify=False, for_yaml=False):
wt2 = deepcopy(weight)
if 'label' in wt2:
wt2.pop('label')
if wt2 != Curve(1).to_dict(simplify=simplify, for_yaml=for_yaml):
if wt2 != Curve(1).to_dict(simplify=simplify, for_yaml=for_yaml, ignore_labels=ignore_labels):
outv['weight'] = wt2 #weight
if not hasattr(self, '_using_default_label'):
if not hasattr(self, '_using_default_label') and not ignore_labels:
outv['label'] = self.label
if ignore_labels and 'label' in outv:
outv.pop('label')
return outv

REDUCTIONS = {
Expand Down Expand Up @@ -700,7 +709,10 @@ def __getitem__(self, k) -> Union[Number,dict]:
outv = reduce(f, vals)
if self.reduction in ('avg', 'average', 'mean'):
outv = outv * (1/ len(vals))
outv = outv * self.weight[k]
# TO DO: this only fixes equality test for unmodified pgroup weight.
# if pgroup weight is anything non-standard, equality test will fail with isinstance(k, slice)
if self.weight != Curve({0:1}):
outv = outv * self.weight[k]
return outv

def random_label(self, d=None) ->str:
Expand All @@ -709,12 +721,30 @@ def random_label(self, d=None) ->str:
basename = ', '.join([str(keyname) for keyname in d.keys()])
return f"{self.reduction}({basename})_{id_generator()}"

def __sub__(self, other) -> 'Composition':
# if other is pgroup, delegate control of arithmetic ops to it
#if isinstance(other, ParameterGroup) and not isinstance(other, Composition):
if isinstance(other, ParameterGroup) and not isinstance(other, type(self)):
return NotImplemented
return super().__sub__(other)

def __radd__(self, other) -> 'Composition':
if isinstance(other, ParameterGroup) and not isinstance(other, type(self)):
return NotImplemented
return super().__radd__(other)

def __add__(self, other) -> 'Composition':
# if other is pgroup, delegate control of arithmetic ops to it
#if isinstance(other, ParameterGroup) and not isinstance(other, Composition):
if isinstance(other, ParameterGroup) and not isinstance(other, type(self)):
return NotImplemented

from loguru import logger
logger.debug((self.label, self))
logger.debug(other)
if not isinstance(other, CurveBase):
other = Curve(other)
logger.debug(other.label)

pg_copy = self.copy()
if self.reduction in ('sum', 'add'):
Expand All @@ -725,6 +755,8 @@ def __add__(self, other) -> 'Composition':
return Composition(parameters=d, weight=pg_copy.weight, reduction='sum')

def __mul__(self, other) -> 'ParameterGroup':
if isinstance(other, ParameterGroup) and not isinstance(other, type(self)):
return NotImplemented
if not isinstance(other, CurveBase):
other = Curve(other)

Expand All @@ -737,6 +769,8 @@ def __mul__(self, other) -> 'ParameterGroup':
return Composition(parameters=d, reduction='prod')

def __truediv__(self, other) -> 'Composition':
if isinstance(other, ParameterGroup) and not isinstance(other, type(self)):
return NotImplemented
if not isinstance(other, CurveBase):
other = Curve(other)

Expand All @@ -745,6 +779,8 @@ def __truediv__(self, other) -> 'Composition':
return Composition(parameters=d, reduction='truediv')

def __rtruediv__(self, other) -> 'Composition':
if isinstance(other, ParameterGroup) and not isinstance(other, type(self)):
return NotImplemented
if not isinstance(other, CurveBase):
other = Curve(other)

Expand Down Expand Up @@ -793,7 +829,9 @@ def plot(self, n:int=None, xs:list=None, eps:float=1e-9, *args, **kargs):
kfx = self.keyframes
kfy = [self[x][label] for x in kfx]
plt.scatter(kfx, kfy)
def to_dict(self, simplify=False, for_yaml=False):
outv = super().to_dict(simplify=simplify, for_yaml=for_yaml)
def to_dict(self, simplify=False, for_yaml=False, ignore_labels=False):
outv = super().to_dict(simplify=simplify, for_yaml=for_yaml, ignore_labels=ignore_labels)
outv['reduction'] = self.reduction
if ignore_labels and 'label' in outv:
outv.pop('label')
return outv
4 changes: 2 additions & 2 deletions src/keyframed/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def from_dict(d:dict):

raise NotImplementedError

def to_yaml(obj:CurveBase, simplify=True):
d = obj.to_dict(simplify=simplify, for_yaml=True)
def to_yaml(obj:CurveBase, simplify=True, ignore_labels=False):
d = obj.to_dict(simplify=simplify, for_yaml=True, ignore_labels=ignore_labels)
cfg = OmegaConf.create(d)
return OmegaConf.to_yaml(cfg)

Expand Down
67 changes: 67 additions & 0 deletions tests/test_bug_20230303.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from keyframed import SinusoidalCurve, ParameterGroup, Curve, Composition

import math
import matplotlib.pyplot as plt

import pytest

def test_bug0():


low, high = 0, 0.3
step1 = 50
step2 = 2 * step1

curves = ParameterGroup((
#SmoothCurve({0:low, step1:high}, bounce=True),
SinusoidalCurve(wavelength=step1*2, phase=3*math.pi/2) ,#+ Curve(high/2),
#SmoothCurve({0:high, step1:low}, bounce=True),
SinusoidalCurve(wavelength=step1*2, phase=math.pi/2) ,#+ Curve(high/2),
#SmoothCurve({0:low, step2:high}, bounce=True),
SinusoidalCurve(wavelength=step2*2, phase=3*math.pi/2) ,#+ Curve(high/2),
#SmoothCurve({0:high, step2:low}, bounce=True),
SinusoidalCurve(wavelength=step2*2, phase=math.pi/2) ,#+ Curve(high/2),
#SinusoidalCurve(wavelength=step1*4, amplitude=high/2) + high/2

#SinusoidalCurve(wavelength=step1, phase=math.pi),
#SinusoidalCurve(wavelength=step2),
#SinusoidalCurve(wavelength=step2, phase=math.pi),
))


#curves.plot(n=1000)
#plt.show()

# Define another curve implicitly, extrapolating from a function
#fancy = Curve.from_function(lambda k: high + math.sin(2*k/(step1+step2)))
#fancy = SinusoidalCurve(wavelength=(step2+step1)/math.pi) #+ .001 #Curve(1) #Curve(high),

fancy = SinusoidalCurve(wavelength=(step2+step1)/math.pi) + Curve(high) # breaks
# This does it too
#fancy = SinusoidalCurve(wavelength=(step2+step1)*math.pi) + Curve({0:high})

#fancy.plot(1000)
#fancy = SinusoidalCurve(wavelength=(step2+step1)*math.pi) + high # the addition here just modifies the first keyframe
#fancy = SinusoidalCurve(wavelength=(step2+step1)*math.pi) + Curve({0:high}) # maybe the issue here is conflicting keyframes?

#fancy = SinusoidalCurve(wavelength=(step2+step1)*math.pi)

#fancy.plot(1000)
#plt.show()

# arithmetic on curves
curves_plus_fancy = curves + fancy +1 + Curve(high) #+ Curve(high/2)
curves_summed_by_frame = Composition(curves_plus_fancy, reduction='sum')
really_fancy = curves_plus_fancy / curves_summed_by_frame

# isolate a single channel
## This breaks after modifying the implementation for "fancy"
channel_name = list(really_fancy[0].keys())[-1]

# this is not the desired behavior. remove 'with' statement to catch the bug.
#with pytest.raises(NotImplementedError):
print(channel_name)
# print(really_fancy.to_dict(simplify=True))
print(really_fancy[0]) # weird name reuse here
print(really_fancy[0][channel_name])
red_channel = Curve.from_function(lambda k: really_fancy[k][channel_name])
57 changes: 44 additions & 13 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def test_compositional_pgroup_from_yamldict():
curves2 = curves + 1
d = curves2.to_dict(simplify=False, for_yaml=True)
curves3 = from_dict(d)
print(curves2.to_dict(simplify=True))
print(curves3.to_dict(simplify=True))
print(curves2.to_dict(simplify=True) == curves3.to_dict(simplify=True))
print(list(curves2.to_dict(simplify=True).keys()))
print(list(curves3.to_dict(simplify=True).keys()))
print(curves2.to_dict(simplify=True)['parameters'] == curves3.to_dict(simplify=True)['parameters'])
assert curves2 == curves3


Expand All @@ -91,20 +97,45 @@ def test_curve_sum_to_yaml():
c0 = Curve({1:1}, label='foo', default_interpolation='linear')
c1 = c0 + 1
c2 = 1 + c0
txt1 = to_yaml(c1, simplify=False)
txt2 = to_yaml(c2, simplify=False)
#txt1 = to_yaml(c1, simplify=False)
#txt2 = to_yaml(c2, simplify=False)


i=0
for k,v in list(c1.parameters.items()):
if '_' in k:
temp = c1.parameters.pop(k)
c1.parameters[f"_{i}"] = temp
i+=1

i=0
for k,v in list(c2.parameters.items()):
if '_' in k:
temp = c2.parameters.pop(k)
c2.parameters[f"_{i}"] = temp
i+=1


txt1 = to_yaml(c1, simplify=True, ignore_labels=True).strip()
txt2 = to_yaml(c2, simplify=True, ignore_labels=True).strip()
print(txt1)
print(txt2)
assert txt1 == txt2
assert txt1.strip() == """curve:
- - 0
- 1
- linear
- - 1
- 2
- linear
loop: false
bounce: false
duration: 1
label: foo"""
assert txt1 == """
parameters:
foo:
curve:
- - 0
- 0
- linear
- - 1
- 1
_0:
curve:
- - 0
- 1
reduction: add
""".strip()

def test_curve_prod_to_yaml():
c0 = Curve({1:1}, label='foo', default_interpolation='linear')
Expand Down
5 changes: 5 additions & 0 deletions tests/test_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ def test_simple_comp_slicing():
c0=Curve(1, label='foo')
c1 = c0 + 1
c2 = c1[:]
print(c1.to_dict(simplify=True, ignore_labels=True))
print(c2.to_dict(simplify=True, ignore_labels=True))
# WAAAT
#{'parameters': {'foo': {'curve': {0: {'value': 1}}}, 'curve_CRMVJ4': {'curve': {0: {'value': 1}}}}, 'reduction': 'add'}
#{'parameters': {'foo+curve_CRMVJ4': {'parameters': {'foo': {'curve': {0: {'value': 1}}}, 'curve_CRMVJ4': {'curve': {0: {'value': 1}}}}, 'reduction': 'add'}, 'foo+curve_CRMVJ4_WEIGHT': {'curve': {0: {'value': 1}}}}, 'reduction': 'prod'}
assert c1 == c2

def test_comp_interpolated_slicing_left():
Expand Down

0 comments on commit 64b76c2

Please sign in to comment.