Skip to content

Commit

Permalink
change torch and numpy to optional dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
dmarx committed Dec 13, 2023
1 parent d6d345f commit 98b0940
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 20 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
'omegaconf',
'matplotlib',
#############
'numpy',
'torch'
#'numpy',
#'torch'
],
extras_require={
'dev': [
Expand Down
64 changes: 48 additions & 16 deletions src/keyframed/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,35 @@
from sortedcontainers import SortedDict
from typing import Tuple, Optional, Union, Dict, Callable

import numpy as np
import torch

from .interpolation import (
bisect_left_keyframe,
INTERPOLATORS,
)
from .utils import id_generator, DictValuesArithmeticFriendly

def is_torch_tensor(obj):
try:
import torch
return isinstance(obj, torch.Tensor)
except ImportError:
pass
return False

def is_numpy_ndarray(obj):
try:
import numpy as np
return isinstance(obj, np.ndarray)
except ImportError:
pass
return False

def numpy_array_equal(a,b):
import numpy as np
return np.array_equal(a,b)

def torch_isequal(a,b):
import torch
return torch.equal(a,b)

# workhorse of Curve.__init__, should probably attach it as an instance method on Curve
def ensure_sorteddict_of_keyframes(
Expand All @@ -32,7 +52,8 @@ def ensure_sorteddict_of_keyframes(
sorteddict = curve
elif isinstance(curve, dict):
sorteddict = SortedDict(curve)
elif isinstance(curve, (Number, np.ndarray, torch.Tensor)):
#elif isinstance(curve, (Number, np.ndarray, torch.Tensor)):
elif (isinstance(curve, Number) or is_numpy_ndarray(curve) or is_torch_tensor(curve)):
sorteddict = SortedDict({0:Keyframe(t=0,value=curve, interpolation_method=default_interpolation, interpolator_arguments=default_interpolator_args)})
elif (isinstance(curve, list) or isinstance(curve, tuple)):
d_ = {}
Expand Down Expand Up @@ -81,7 +102,8 @@ def ensure_sorteddict_of_keyframes(
implied_interpolation = kf.interpolation_method
implied_interpolator_args = kf.interpolator_arguments
d_[k] = kf
elif isinstance(v, (Number, np.ndarray, torch.Tensor)):
#elif isinstance(v, (Number, np.ndarray, torch.Tensor)):
elif (isinstance(v, Number) or is_numpy_ndarray(v) or is_torch_tensor(v)):
d_[k] = Keyframe(t=k,value=v, interpolation_method=implied_interpolation, interpolator_arguments=implied_interpolator_args)
else:
raise NotImplementedError
Expand All @@ -102,9 +124,12 @@ def __init__(
self.t=t
#self.value=value
### <chatgpt>
if isinstance(value, np.ndarray):
self.value = np.array(value) # Ensure a copy of the array is stored
elif isinstance(value, torch.Tensor):
#if isinstance(value, np.ndarray):
if is_numpy_ndarray(value):
#self.value = np.array(value) # Ensure a copy of the array is stored
self.value = deepcopy(value)
#elif isinstance(value, torch.Tensor):
elif is_torch_tensor(value):
self.value = value.clone() # Ensure a clone of the tensor is stored
else:
self.value = value
Expand All @@ -122,13 +147,18 @@ def interpolator_arguments(self):

def __eq__(self, other) -> bool:
### <chatgpt>
if isinstance(self.value, (np.ndarray, torch.Tensor)) and isinstance(other, (np.ndarray, torch.Tensor)):
if isinstance(self.value, np.ndarray):
return np.array_equal(self.value, np.array(other))
else:
return torch.equal(self.value, torch.tensor(other))
#if isinstance(self.value, (np.ndarray, torch.Tensor)) and isinstance(other, (np.ndarray, torch.Tensor)):
# if isinstance(self.value, np.ndarray):
# return np.array_equal(self.value, np.array(other))
# else:
# return torch.equal(self.value, torch.tensor(other))
### </chatgpt>
return self.value == other
if is_numpy_ndarray(self.value):
return numpy_array_equal(self.value, other)
elif is_torch_tensor(self.value):
return torch_isequal(self.value, other)
else:
return self.value == other
def __repr__(self) -> str:
#d = f"Keyframe(t={self.t}, value={self.value}, interpolation_method='{self.interpolation_method}')"
d = self.to_dict()
Expand All @@ -139,9 +169,11 @@ def _to_dict(self, *args, **kwargs) -> dict:
d['interpolator_arguments'] = self.interpolator_arguments
### <chatgpt>
# Ensure the representation of numpy arrays and tensors are handled correctly
if isinstance(self.value, np.ndarray):
#if isinstance(self.value, np.ndarray):
if is_numpy_ndarray(self.value):
d['value'] = self.value.tolist()
elif isinstance(self.value, torch.Tensor):
#elif isinstance(self.value, torch.Tensor):
elif is_torch_tensor(self.value):
d['value'] = self.value.numpy().tolist()
else:
d['value'] = self.value
Expand Down
4 changes: 2 additions & 2 deletions src/keyframed/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math
from numbers import Number
from typing import Callable
import numpy as np
import torch
#import numpy as np
#import torch
from functools import partial

def bisect_left_keyframe(k: Number, curve:'Curve', *args, **kargs) -> 'Keyframe':
Expand Down

0 comments on commit 98b0940

Please sign in to comment.