Skip to content

Commit

Permalink
Merge pull request #104 from dmarx/tensor_support
Browse files Browse the repository at this point in the history
Tensor support
  • Loading branch information
dmarx authored Dec 9, 2023
2 parents 57d214d + 66764cd commit 25b68fb
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ __pycache__
*.egg-info
Pipfile
sandbox.ipynb
Pipfile.lock
Pipfile.lock
.DS_Store
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
'sortedcontainers',
'omegaconf',
'matplotlib',
#############
'numpy',
'torch'
],
extras_require={
'dev': [
Expand Down
33 changes: 30 additions & 3 deletions src/keyframed/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from sortedcontainers import SortedDict
from typing import Tuple, Optional, Union, Dict, Callable

import numpy as np
import torch

from .interpolation import (
bisect_left_keyframe,
INTERPOLATORS,
Expand All @@ -29,7 +32,7 @@ def ensure_sorteddict_of_keyframes(
sorteddict = curve
elif isinstance(curve, dict):
sorteddict = SortedDict(curve)
elif isinstance(curve, Number):
elif isinstance(curve, (Number, np.ndarray, torch.Tensor)):
sorteddict = SortedDict({0:Keyframe(t=0,value=curve, interpolation_method=default_interpolation, interpolator_arguments=default_interpolator_args)})
elif (isinstance(curve, list) or isinstance(curve, tuple)):
d_ = {}
Expand Down Expand Up @@ -78,7 +81,7 @@ def ensure_sorteddict_of_keyframes(
implied_interpolation = kf.interpolation_method
implied_interpolator_args = kf.interpolator_arguments
d_[k] = kf
elif isinstance(v, Number):
elif isinstance(v, (Number, np.ndarray, torch.Tensor)):
d_[k] = Keyframe(t=k,value=v, interpolation_method=implied_interpolation, interpolator_arguments=implied_interpolator_args)
else:
raise NotImplementedError
Expand All @@ -97,7 +100,15 @@ def __init__(
interpolator_arguments=None,
):
self.t=t
self.value=value
#self.value=value
### <chatgpt>
if isinstance(value, np.ndarray):
self.value = np.array(value) # Ensure a copy of the array is stored
elif isinstance(value, torch.Tensor):
self.value = value.clone() # Ensure a clone of the tensor is stored
else:
self.value = value
### </chatgpt>
self.interpolation_method=interpolation_method
if interpolator_arguments is None:
interpolator_arguments = {}
Expand All @@ -110,6 +121,13 @@ def interpolator_arguments(self):
return {}

def __eq__(self, other) -> bool:
### <chatgpt>
if isinstance(self.value, (np.ndarray, torch.Tensor)) and isinstance(other, (np.ndarray, torch.Tensor)):
if isinstance(self.value, np.ndarray):
return np.array_equal(self.value, np.array(other))
else:
return torch.equal(self.value, torch.tensor(other))
### </chatgpt>
return self.value == other
def __repr__(self) -> str:
#d = f"Keyframe(t={self.t}, value={self.value}, interpolation_method='{self.interpolation_method}')"
Expand All @@ -119,6 +137,15 @@ def _to_dict(self, *args, **kwargs) -> dict:
d = {'t':self.t, 'value':self.value, 'interpolation_method':self.interpolation_method}
if self.interpolator_arguments:
d['interpolator_arguments'] = self.interpolator_arguments
### <chatgpt>
# Ensure the representation of numpy arrays and tensors are handled correctly
if isinstance(self.value, np.ndarray):
d['value'] = self.value.tolist()
elif isinstance(self.value, torch.Tensor):
d['value'] = self.value.numpy().tolist()
else:
d['value'] = self.value
### </chatgpt>
return d
def _to_tuple(self, *args, **kwags):
if not self.interpolator_arguments:
Expand Down
15 changes: 14 additions & 1 deletion src/keyframed/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import math
from numbers import Number
from typing import Callable

import numpy as np
import torch

def bisect_left_keyframe(k: Number, curve:'Curve', *args, **kargs) -> 'Keyframe':
"""
Expand Down Expand Up @@ -65,6 +66,18 @@ def linear(k, curve, *args, **kargs):
right = bisect_right_keyframe(k, curve)
x0, x1 = left.t, right.t
y0, y1 = left.value, right.value

# ### <chatgpt>
# # Handle both NumPy arrays and PyTorch tensors
# if isinstance(y0, (np.ndarray, torch.Tensor)) and isinstance(y1, (np.ndarray, torch.Tensor)):
# d = x1 - x0
# t = (x1 - k) / d
# if isinstance(y0, np.ndarray):
# return t * y0 + (1 - t) * y1
# else:
# return t * y0 + (1 - t) * y1
# ### </chatgpt>

d = x1-x0
t = (x1-k)/d
outv = t*y0 + (1-t)*y1
Expand Down
51 changes: 51 additions & 0 deletions tests/test_numpy_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
from keyframed import Curve

def test_vector_interpolation_linear():
# Test linear interpolation with vectors
start_vec = np.array([0, 0, 0])
end_vec = np.array([10, 10, 10])
curve = Curve({0: start_vec, 10: end_vec}, default_interpolation='linear')

assert np.allclose(curve[5], np.array([5, 5, 5]))
assert np.allclose(curve[2.5], np.array([2.5, 2.5, 2.5]))
assert np.allclose(curve[7.5], np.array([7.5, 7.5, 7.5]))

# def test_vector_interpolation_custom():
# # Test custom interpolation function for vectors
# def custom_interp(t, t0, value0, t1, value1):
# # Example: simple linear interpolation
# return value0 + (value1 - value0) * (t - t0) / (t1 - t0)

# start_vec = np.array([1, 2, 3])
# end_vec = np.array([4, 5, 6])
# curve = Curve({0: start_vec, 10: end_vec})
# curve.set_interpolation(custom_interp)

# assert np.allclose(curve[5], np.array([2.5, 3.5, 4.5]))
# assert np.allclose(curve[2], np.array([1.3, 2.3, 3.3]))

def test_vector_keyframe_insertion():
# Test inserting vector keyframes
curve = Curve()
curve[0] = np.array([1, 2, 3])
curve[5] = np.array([4, 5, 6])

assert np.allclose(curve[0], np.array([1, 2, 3]))
assert np.allclose(curve[5], np.array([4, 5, 6]))

# def test_vector_interpolation_bounds():
# # Test vector interpolation at curve bounds
# curve = Curve({0: np.array([0, 0]), 10: np.array([10, 10])}, default_interpolation='linear')

# assert np.allclose(curve[-5], np.array([0, 0])) # Test extrapolation if your curve supports it
# assert np.allclose(curve[15], np.array([10, 10]))

# def test_vector_mixed_interpolation():
# # Test curves with mixed scalar and vector interpolation
# curve = Curve({0: 0, 5: np.array([5, 10, 15]), 10: 10}, default_interpolation='linear')

# assert curve[2.5] == 2.5 # Scalar interpolation
# assert np.allclose(curve[7.5], np.array([7.5, 12.5, 17.5])) # Vector interpolation

# Add more tests for edge cases, different vector lengths, and other interpolation methods as needed
11 changes: 11 additions & 0 deletions tests/test_torch_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch
from keyframed import Curve


def test_matrix_interpolation():
start_matrix = torch.tensor([[0, 0], [0, 0]], dtype=torch.float32)
end_matrix = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)
curve = Curve({0: start_matrix, 10: end_matrix}, default_interpolation='linear')

expected_mid = torch.tensor([[0.5, 0.5], [0.5, 0.5]], dtype=torch.float32)
assert torch.allclose(curve[5], expected_mid)
54 changes: 54 additions & 0 deletions tests/test_torch_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import numpy as np
import torch
from keyframed import Curve, Keyframe

# Test linear interpolation with PyTorch tensors
def test_linear_interpolation_tensor():
start_tensor = torch.tensor([0, 0, 0], dtype=torch.float32)
end_tensor = torch.tensor([10, 10, 10], dtype=torch.float32)
curve = Curve({0: start_tensor, 10: end_tensor}, default_interpolation='linear')

assert torch.allclose(curve[5], torch.tensor([5, 5, 5], dtype=torch.float32))
assert torch.allclose(curve[2.5], torch.tensor([2.5, 2.5, 2.5], dtype=torch.float32))
assert torch.allclose(curve[7.5], torch.tensor([7.5, 7.5, 7.5], dtype=torch.float32))

# Test mixed NumPy array and PyTorch tensor interpolation
# def test_mixed_interpolation():
# start_array = np.array([1, 2, 3])
# end_tensor = torch.tensor([4, 5, 6])
# curve = Curve({0: start_array, 10: end_tensor}, default_interpolation='linear')

# expected_midpoint = torch.tensor([2.5, 3.5, 4.5])
# assert torch.allclose(curve[5], expected_midpoint)

# Test PyTorch tensor interpolation with custom interpolation method
# def test_custom_tensor_interpolation():
# def custom_interp(t, t0, value0, t1, value1):
# value0, value1 = torch.tensor(value0), torch.tensor(value1)
# return value0 + (value1 - value0) * (t - t0) / (t1 - t0)

# start_tensor = torch.tensor([1, 1, 1])
# end_tensor = torch.tensor([2, 2, 2])
# curve = Curve({0: start_tensor, 1: end_tensor})
# curve.set_interpolation(custom_interp)

# assert torch.allclose(curve[0.5], torch.tensor([1.5, 1.5, 1.5]))

# Test tensor keyframe insertion
def test_tensor_keyframe_insertion():
curve = Curve()
curve[0] = torch.tensor([1, 2, 3])
curve[5] = torch.tensor([4, 5, 6])

assert torch.allclose(curve[0], torch.tensor([1, 2, 3]))
assert torch.allclose(curve[5], torch.tensor([4, 5, 6]))

# Test tensor interpolation at curve bounds
# def test_tensor_interpolation_bounds():
# curve = Curve({0: torch.tensor([0, 0]), 10: torch.tensor([10, 10])}, default_interpolation='linear')

# assert torch.allclose(curve[-5], torch.tensor([0, 0])) # Test extrapolation if your curve supports it
# assert torch.allclose(curve[15], torch.tensor([10, 10]))

# Add more tests for edge cases and different tensor shapes if necessary

0 comments on commit 25b68fb

Please sign in to comment.