Skip to content

Commit

Permalink
Merge pull request #106 from dmarx/dev
Browse files Browse the repository at this point in the history
draft easings interface
  • Loading branch information
dmarx authored Dec 14, 2023
2 parents 25b68fb + 914d974 commit c18ce25
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 28 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,11 @@ jobs:
pip install .[dev]
- name: Test with pytest
run: |
pytest
pytest --ignore-glob='*torch*' --ignore-glob='*numpy*'
# TODO: move this to a separate parallel job matrix
- name: Install torch and numpy
run: |
pip install torch numpy
- name: run tests for optional numpy and pytorch support
run: |
pytest -k "numpy or torch"
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ __pycache__
Pipfile
sandbox.ipynb
Pipfile.lock
.DS_Store
.DS_Store
build
*.ipynb_checkpoints
*.ipynb
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
'omegaconf',
'matplotlib',
#############
'numpy',
'torch'
#'numpy',
#'torch'
],
extras_require={
'dev': [
Expand Down
73 changes: 56 additions & 17 deletions src/keyframed/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,36 @@
from sortedcontainers import SortedDict
from typing import Tuple, Optional, Union, Dict, Callable

import numpy as np
import torch

from .interpolation import (
bisect_left_keyframe,
INTERPOLATORS,
EASINGS,
)
from .utils import id_generator, DictValuesArithmeticFriendly

def is_torch_tensor(obj):
try:
import torch
return isinstance(obj, torch.Tensor)
except ImportError:
pass
return False

def is_numpy_ndarray(obj):
try:
import numpy as np
return isinstance(obj, np.ndarray)
except ImportError:
pass
return False

def numpy_array_equal(a,b):
import numpy as np
return np.array_equal(a,b)

def torch_isequal(a,b):
import torch
return torch.equal(a,b)

# workhorse of Curve.__init__, should probably attach it as an instance method on Curve
def ensure_sorteddict_of_keyframes(
Expand All @@ -32,7 +53,8 @@ def ensure_sorteddict_of_keyframes(
sorteddict = curve
elif isinstance(curve, dict):
sorteddict = SortedDict(curve)
elif isinstance(curve, (Number, np.ndarray, torch.Tensor)):
#elif isinstance(curve, (Number, np.ndarray, torch.Tensor)):
elif (isinstance(curve, Number) or is_numpy_ndarray(curve) or is_torch_tensor(curve)):
sorteddict = SortedDict({0:Keyframe(t=0,value=curve, interpolation_method=default_interpolation, interpolator_arguments=default_interpolator_args)})
elif (isinstance(curve, list) or isinstance(curve, tuple)):
d_ = {}
Expand Down Expand Up @@ -81,7 +103,8 @@ def ensure_sorteddict_of_keyframes(
implied_interpolation = kf.interpolation_method
implied_interpolator_args = kf.interpolator_arguments
d_[k] = kf
elif isinstance(v, (Number, np.ndarray, torch.Tensor)):
#elif isinstance(v, (Number, np.ndarray, torch.Tensor)):
elif (isinstance(v, Number) or is_numpy_ndarray(v) or is_torch_tensor(v)):
d_[k] = Keyframe(t=k,value=v, interpolation_method=implied_interpolation, interpolator_arguments=implied_interpolator_args)
else:
raise NotImplementedError
Expand All @@ -98,13 +121,18 @@ def __init__(
value,
interpolation_method:Optional[Union[str,Callable]]=None,
interpolator_arguments=None,
label=None,
):
self.t=t
self.label = label
#self.value=value
### <chatgpt>
if isinstance(value, np.ndarray):
self.value = np.array(value) # Ensure a copy of the array is stored
elif isinstance(value, torch.Tensor):
#if isinstance(value, np.ndarray):
if is_numpy_ndarray(value):
#self.value = np.array(value) # Ensure a copy of the array is stored
self.value = deepcopy(value)
#elif isinstance(value, torch.Tensor):
elif is_torch_tensor(value):
self.value = value.clone() # Ensure a clone of the tensor is stored
else:
self.value = value
Expand All @@ -122,13 +150,18 @@ def interpolator_arguments(self):

def __eq__(self, other) -> bool:
### <chatgpt>
if isinstance(self.value, (np.ndarray, torch.Tensor)) and isinstance(other, (np.ndarray, torch.Tensor)):
if isinstance(self.value, np.ndarray):
return np.array_equal(self.value, np.array(other))
else:
return torch.equal(self.value, torch.tensor(other))
#if isinstance(self.value, (np.ndarray, torch.Tensor)) and isinstance(other, (np.ndarray, torch.Tensor)):
# if isinstance(self.value, np.ndarray):
# return np.array_equal(self.value, np.array(other))
# else:
# return torch.equal(self.value, torch.tensor(other))
### </chatgpt>
return self.value == other
if is_numpy_ndarray(self.value):
return numpy_array_equal(self.value, other)
elif is_torch_tensor(self.value):
return torch_isequal(self.value, other)
else:
return self.value == other
def __repr__(self) -> str:
#d = f"Keyframe(t={self.t}, value={self.value}, interpolation_method='{self.interpolation_method}')"
d = self.to_dict()
Expand All @@ -137,11 +170,15 @@ def _to_dict(self, *args, **kwargs) -> dict:
d = {'t':self.t, 'value':self.value, 'interpolation_method':self.interpolation_method}
if self.interpolator_arguments:
d['interpolator_arguments'] = self.interpolator_arguments
if self.label is not None:
d['label'] = self.label
### <chatgpt>
# Ensure the representation of numpy arrays and tensors are handled correctly
if isinstance(self.value, np.ndarray):
#if isinstance(self.value, np.ndarray):
if is_numpy_ndarray(self.value):
d['value'] = self.value.tolist()
elif isinstance(self.value, torch.Tensor):
#elif isinstance(self.value, torch.Tensor):
elif is_torch_tensor(self.value):
d['value'] = self.value.numpy().tolist()
else:
d['value'] = self.value
Expand Down Expand Up @@ -373,7 +410,9 @@ def __getitem__(self, k:Number) -> Number:
interp = left_value.interpolation_method

if (interp is None) or isinstance(interp, str):
f = INTERPOLATORS.get(interp)
f = EASINGS.get(interp)
if f is None:
f = INTERPOLATORS.get(interp)
if f is None:
raise ValueError(f"Unsupported interpolation method: {interp}")
elif isinstance(interp, Callable):
Expand Down
23 changes: 19 additions & 4 deletions src/keyframed/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import math
from numbers import Number
from typing import Callable
import numpy as np
import torch
#import numpy as np
#import torch
from functools import partial

def bisect_left_keyframe(k: Number, curve:'Curve', *args, **kargs) -> 'Keyframe':
"""
Expand Down Expand Up @@ -90,8 +91,11 @@ def exp_decay(t, curve, decay_rate):
return v0 * math.exp(-td * decay_rate)

def sine_wave(t, curve, wavelength=None, frequency=None, phase=0, amplitude=1):
if (wavelength is None) and (frequency is not None):
wavelength = 1/frequency
if (wavelength is None):
if (frequency is not None):
wavelength = 1/frequency
else:
wavelength = 4 # interpolate from 0 to pi/2
return amplitude * math.sin(2*math.pi*t / wavelength + phase)

INTERPOLATORS={
Expand All @@ -104,6 +108,17 @@ def sine_wave(t, curve, wavelength=None, frequency=None, phase=0, amplitude=1):
'sine_wave':sine_wave,
}

EASINGS={
None:bisect_left_value,
'previous':bisect_left_value,
'next':bisect_right_value,
'linear':partial(eased_lerp, ease=lambda t: t),
'sin':partial(eased_lerp, ease=lambda t: math.sin(t * math.pi / 2)),
'sin^2':partial(eased_lerp, ease=lambda t: math.sin(t * math.pi / 2)**2),

}


def register_interpolation_method(name:str, f:Callable):
"""
Adds a new interpolation method to the INTERPOLATORS registry.
Expand Down
19 changes: 16 additions & 3 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pytest
import math

from keyframed import Curve, Keyframe

from keyframed.interpolation import (
register_interpolation_method,
bisect_left_value,
INTERPOLATORS
INTERPOLATORS,
EASINGS
)


Expand Down Expand Up @@ -117,4 +118,16 @@ def test_curve_w_kf_specified_interpolator():
c1 = Curve({0:1, 5:Keyframe(t=5,value=1,interpolation_method='linear'), 9:5})
assert c1[7] == 3

###############
###############

def test_sin():
c1 = Curve({0:0, 2:1}, default_interpolation=EASINGS['sin'])
assert c1[1] == math.sin(math.pi/4)

def test_sin():
c1 = Curve({0:0, 2:2}, default_interpolation=EASINGS['sin'])
assert c1[1] == 2*math.sin(math.pi/4)

def test_sin2():
c1 = Curve({0:0, 2:1}, default_interpolation=EASINGS['sin^2'])
assert c1[1] == math.sin(math.pi/4)**2

0 comments on commit c18ce25

Please sign in to comment.