Skip to content

Commit

Permalink
tensors! suppress tests for hallucinated features
Browse files Browse the repository at this point in the history
  • Loading branch information
dmarx committed Dec 8, 2023
1 parent 55ac5fe commit 690d1f3
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 47 deletions.
4 changes: 2 additions & 2 deletions src/keyframed/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def ensure_sorteddict_of_keyframes(
sorteddict = curve
elif isinstance(curve, dict):
sorteddict = SortedDict(curve)
elif isinstance(curve, Number):
elif isinstance(curve, (Number, np.ndarray, torch.Tensor)):
sorteddict = SortedDict({0:Keyframe(t=0,value=curve, interpolation_method=default_interpolation, interpolator_arguments=default_interpolator_args)})
elif (isinstance(curve, list) or isinstance(curve, tuple)):
d_ = {}
Expand Down Expand Up @@ -81,7 +81,7 @@ def ensure_sorteddict_of_keyframes(
implied_interpolation = kf.interpolation_method
implied_interpolator_args = kf.interpolator_arguments
d_[k] = kf
elif isinstance(v, Number):
elif isinstance(v, (Number, np.ndarray, torch.Tensor)):
d_[k] = Keyframe(t=k,value=v, interpolation_method=implied_interpolation, interpolator_arguments=implied_interpolator_args)
else:
raise NotImplementedError
Expand Down
43 changes: 22 additions & 21 deletions tests/test_numpy_vector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from keyframed import Curve

def test_vector_interpolation_linear():
# Test linear interpolation with vectors
Expand All @@ -10,19 +11,19 @@ def test_vector_interpolation_linear():
assert np.allclose(curve[2.5], np.array([2.5, 2.5, 2.5]))
assert np.allclose(curve[7.5], np.array([7.5, 7.5, 7.5]))

def test_vector_interpolation_custom():
# Test custom interpolation function for vectors
def custom_interp(t, t0, value0, t1, value1):
# Example: simple linear interpolation
return value0 + (value1 - value0) * (t - t0) / (t1 - t0)
# def test_vector_interpolation_custom():
# # Test custom interpolation function for vectors
# def custom_interp(t, t0, value0, t1, value1):
# # Example: simple linear interpolation
# return value0 + (value1 - value0) * (t - t0) / (t1 - t0)

start_vec = np.array([1, 2, 3])
end_vec = np.array([4, 5, 6])
curve = Curve({0: start_vec, 10: end_vec})
curve.set_interpolation(custom_interp)
# start_vec = np.array([1, 2, 3])
# end_vec = np.array([4, 5, 6])
# curve = Curve({0: start_vec, 10: end_vec})
# curve.set_interpolation(custom_interp)

assert np.allclose(curve[5], np.array([2.5, 3.5, 4.5]))
assert np.allclose(curve[2], np.array([1.3, 2.3, 3.3]))
# assert np.allclose(curve[5], np.array([2.5, 3.5, 4.5]))
# assert np.allclose(curve[2], np.array([1.3, 2.3, 3.3]))

def test_vector_keyframe_insertion():
# Test inserting vector keyframes
Expand All @@ -33,18 +34,18 @@ def test_vector_keyframe_insertion():
assert np.allclose(curve[0], np.array([1, 2, 3]))
assert np.allclose(curve[5], np.array([4, 5, 6]))

def test_vector_interpolation_bounds():
# Test vector interpolation at curve bounds
curve = Curve({0: np.array([0, 0]), 10: np.array([10, 10])}, default_interpolation='linear')
# def test_vector_interpolation_bounds():
# # Test vector interpolation at curve bounds
# curve = Curve({0: np.array([0, 0]), 10: np.array([10, 10])}, default_interpolation='linear')

assert np.allclose(curve[-5], np.array([0, 0])) # Test extrapolation if your curve supports it
assert np.allclose(curve[15], np.array([10, 10]))
# assert np.allclose(curve[-5], np.array([0, 0])) # Test extrapolation if your curve supports it
# assert np.allclose(curve[15], np.array([10, 10]))

def test_vector_mixed_interpolation():
# Test curves with mixed scalar and vector interpolation
curve = Curve({0: 0, 5: np.array([5, 10, 15]), 10: 10}, default_interpolation='linear')
# def test_vector_mixed_interpolation():
# # Test curves with mixed scalar and vector interpolation
# curve = Curve({0: 0, 5: np.array([5, 10, 15]), 10: 10}, default_interpolation='linear')

assert curve[2.5] == 2.5 # Scalar interpolation
assert np.allclose(curve[7.5], np.array([7.5, 12.5, 17.5])) # Vector interpolation
# assert curve[2.5] == 2.5 # Scalar interpolation
# assert np.allclose(curve[7.5], np.array([7.5, 12.5, 17.5])) # Vector interpolation

# Add more tests for edge cases, different vector lengths, and other interpolation methods as needed
48 changes: 24 additions & 24 deletions tests/test_torch_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,35 @@

# Test linear interpolation with PyTorch tensors
def test_linear_interpolation_tensor():
start_tensor = torch.tensor([0, 0, 0])
end_tensor = torch.tensor([10, 10, 10])
start_tensor = torch.tensor([0, 0, 0], dtype=torch.float32)
end_tensor = torch.tensor([10, 10, 10], dtype=torch.float32)
curve = Curve({0: start_tensor, 10: end_tensor}, default_interpolation='linear')

assert torch.allclose(curve[5], torch.tensor([5, 5, 5]))
assert torch.allclose(curve[2.5], torch.tensor([2.5, 2.5, 2.5]))
assert torch.allclose(curve[7.5], torch.tensor([7.5, 7.5, 7.5]))
assert torch.allclose(curve[5], torch.tensor([5, 5, 5], dtype=torch.float32))
assert torch.allclose(curve[2.5], torch.tensor([2.5, 2.5, 2.5], dtype=torch.float32))
assert torch.allclose(curve[7.5], torch.tensor([7.5, 7.5, 7.5], dtype=torch.float32))

# Test mixed NumPy array and PyTorch tensor interpolation
def test_mixed_interpolation():
start_array = np.array([1, 2, 3])
end_tensor = torch.tensor([4, 5, 6])
curve = Curve({0: start_array, 10: end_tensor}, default_interpolation='linear')
# def test_mixed_interpolation():
# start_array = np.array([1, 2, 3])
# end_tensor = torch.tensor([4, 5, 6])
# curve = Curve({0: start_array, 10: end_tensor}, default_interpolation='linear')

expected_midpoint = torch.tensor([2.5, 3.5, 4.5])
assert torch.allclose(curve[5], expected_midpoint)
# expected_midpoint = torch.tensor([2.5, 3.5, 4.5])
# assert torch.allclose(curve[5], expected_midpoint)

# Test PyTorch tensor interpolation with custom interpolation method
def test_custom_tensor_interpolation():
def custom_interp(t, t0, value0, t1, value1):
value0, value1 = torch.tensor(value0), torch.tensor(value1)
return value0 + (value1 - value0) * (t - t0) / (t1 - t0)
# def test_custom_tensor_interpolation():
# def custom_interp(t, t0, value0, t1, value1):
# value0, value1 = torch.tensor(value0), torch.tensor(value1)
# return value0 + (value1 - value0) * (t - t0) / (t1 - t0)

start_tensor = torch.tensor([1, 1, 1])
end_tensor = torch.tensor([2, 2, 2])
curve = Curve({0: start_tensor, 1: end_tensor})
curve.set_interpolation(custom_interp)
# start_tensor = torch.tensor([1, 1, 1])
# end_tensor = torch.tensor([2, 2, 2])
# curve = Curve({0: start_tensor, 1: end_tensor})
# curve.set_interpolation(custom_interp)

assert torch.allclose(curve[0.5], torch.tensor([1.5, 1.5, 1.5]))
# assert torch.allclose(curve[0.5], torch.tensor([1.5, 1.5, 1.5]))

# Test tensor keyframe insertion
def test_tensor_keyframe_insertion():
Expand All @@ -45,10 +45,10 @@ def test_tensor_keyframe_insertion():
assert torch.allclose(curve[5], torch.tensor([4, 5, 6]))

# Test tensor interpolation at curve bounds
def test_tensor_interpolation_bounds():
curve = Curve({0: torch.tensor([0, 0]), 10: torch.tensor([10, 10])}, default_interpolation='linear')
# def test_tensor_interpolation_bounds():
# curve = Curve({0: torch.tensor([0, 0]), 10: torch.tensor([10, 10])}, default_interpolation='linear')

assert torch.allclose(curve[-5], torch.tensor([0, 0])) # Test extrapolation if your curve supports it
assert torch.allclose(curve[15], torch.tensor([10, 10]))
# assert torch.allclose(curve[-5], torch.tensor([0, 0])) # Test extrapolation if your curve supports it
# assert torch.allclose(curve[15], torch.tensor([10, 10]))

# Add more tests for edge cases and different tensor shapes if necessary

0 comments on commit 690d1f3

Please sign in to comment.