Skip to content

Commit

Permalink
aigc
Browse files Browse the repository at this point in the history
  • Loading branch information
dmarx committed Dec 8, 2023
1 parent 495e4dc commit 55ac5fe
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
29 changes: 28 additions & 1 deletion src/keyframed/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from sortedcontainers import SortedDict
from typing import Tuple, Optional, Union, Dict, Callable

import numpy as np
import torch

from .interpolation import (
bisect_left_keyframe,
INTERPOLATORS,
Expand Down Expand Up @@ -97,7 +100,15 @@ def __init__(
interpolator_arguments=None,
):
self.t=t
self.value=value
#self.value=value
### <chatgpt>
if isinstance(value, np.ndarray):
self.value = np.array(value) # Ensure a copy of the array is stored
elif isinstance(value, torch.Tensor):
self.value = value.clone() # Ensure a clone of the tensor is stored
else:
self.value = value
### </chatgpt>
self.interpolation_method=interpolation_method
if interpolator_arguments is None:
interpolator_arguments = {}
Expand All @@ -110,6 +121,13 @@ def interpolator_arguments(self):
return {}

def __eq__(self, other) -> bool:
### <chatgpt>
if isinstance(self.value, (np.ndarray, torch.Tensor)) and isinstance(other, (np.ndarray, torch.Tensor)):
if isinstance(self.value, np.ndarray):
return np.array_equal(self.value, np.array(other))
else:
return torch.equal(self.value, torch.tensor(other))
### </chatgpt>
return self.value == other
def __repr__(self) -> str:
#d = f"Keyframe(t={self.t}, value={self.value}, interpolation_method='{self.interpolation_method}')"
Expand All @@ -119,6 +137,15 @@ def _to_dict(self, *args, **kwargs) -> dict:
d = {'t':self.t, 'value':self.value, 'interpolation_method':self.interpolation_method}
if self.interpolator_arguments:
d['interpolator_arguments'] = self.interpolator_arguments
### <chatgpt>
# Ensure the representation of numpy arrays and tensors are handled correctly
if isinstance(self.value, np.ndarray):
d['value'] = self.value.tolist()
elif isinstance(self.value, torch.Tensor):
d['value'] = self.value.numpy().tolist()
else:
d['value'] = self.value
### </chatgpt>
return d
def _to_tuple(self, *args, **kwags):
if not self.interpolator_arguments:
Expand Down
15 changes: 14 additions & 1 deletion src/keyframed/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import math
from numbers import Number
from typing import Callable

import numpy as np
import torch

def bisect_left_keyframe(k: Number, curve:'Curve', *args, **kargs) -> 'Keyframe':
"""
Expand Down Expand Up @@ -65,6 +66,18 @@ def linear(k, curve, *args, **kargs):
right = bisect_right_keyframe(k, curve)
x0, x1 = left.t, right.t
y0, y1 = left.value, right.value

### <chatgpt>
# Handle both NumPy arrays and PyTorch tensors
if isinstance(y0, (np.ndarray, torch.Tensor)) and isinstance(y1, (np.ndarray, torch.Tensor)):
d = x1 - x0
t = (x1 - k) / d
if isinstance(y0, np.ndarray):
return t * y0 + (1 - t) * y1
else:
return t * y0 + (1 - t) * y1
### </chatgpt>

d = x1-x0
t = (x1-k)/d
outv = t*y0 + (1-t)*y1
Expand Down

0 comments on commit 55ac5fe

Please sign in to comment.