diff --git a/agentMET4FOF/streams/base_streams.py b/agentMET4FOF/streams/base_streams.py index 92a0d5ed..52dae6ad 100644 --- a/agentMET4FOF/streams/base_streams.py +++ b/agentMET4FOF/streams/base_streams.py @@ -5,7 +5,8 @@ from time_series_metadata.scheme import MetaData import warnings -class DataStreamMET4FOF(): + +class DataStreamMET4FOF: """ Abstract class for creating datastreams. @@ -79,10 +80,10 @@ def __init__(self): self._current_sample_quantities: Union[List, DataFrame, np.ndarray] self._current_sample_target: Union[List, DataFrame, np.ndarray] self._current_sample_time: Union[List, DataFrame, np.ndarray] - self._sample_idx: int = 0 #current sample index - self._n_samples: int = 0 #total number of samples + self._sample_idx: int = 0 # current sample index + self._n_samples: int = 0 # total number of samples self._data_source_type: str = "function" - self._generator_function : Callable + self._generator_function: Callable self._generator_parameters: Dict = {} self.sfreq: int = 1 self._metadata: MetaData @@ -105,7 +106,10 @@ def randomize_data(self): np.random.shuffle(random_index) self._quantities = self._quantities[random_index] - if type(self._target).__name__ == "ndarray" or type(self._target).__name__ == "list": + if ( + type(self._target).__name__ == "ndarray" + or type(self._target).__name__ == "list" + ): self._target = self._target[random_index] elif type(self._target).__name__ == "DataFrame": self._target = self._target.iloc[random_index] @@ -119,13 +123,13 @@ def sample_idx(self): return self._sample_idx def set_metadata( - self, - device_id: str, - time_name: str, - time_unit: str, - quantity_names: Union[str, Tuple[str, ...]], - quantity_units: Union[str, Tuple[str, ...]], - misc: Optional[Any] = None + self, + device_id: str, + time_name: str, + time_unit: str, + quantity_names: Union[str, Tuple[str, ...]], + quantity_units: Union[str, Tuple[str, ...]], + misc: Optional[Any] = None, ): """Set the quantities metadata as a ``MetaData`` object @@ -154,7 +158,7 @@ def set_metadata( time_unit=time_unit, quantity_names=quantity_names, quantity_units=quantity_units, - misc=misc + misc=misc, ) def _default_generator_function(self, time): @@ -164,11 +168,11 @@ def _default_generator_function(self, time): ---------- time : Union[List, DataFrame, np.ndarray] """ - value = np.sin(2*np.pi*self.F*time) + value = np.sin(2 * np.pi * self.F * time) return value def set_generator_function( - self, generator_function: Callable = None, sfreq: int = None, **kwargs: Any + self, generator_function: Callable = None, sfreq: int = None, **kwargs: Any ): """ Sets the data source to a generator function. By default, this function resorts @@ -191,14 +195,14 @@ def set_generator_function( The generator function call for every sample will be supplied with the ``**generator_parameters``. """ - #save the kwargs into generator_parameters + # save the kwargs into generator_parameters self._generator_parameters = kwargs if sfreq is not None: self.sfreq = sfreq self._set_data_source_type("function") - #resort to default wave generator if one is not supplied + # resort to default wave generator if one is not supplied if generator_function is None: warnings.warn( "No uncertainty generator function specified. Setting to default (" @@ -214,21 +218,20 @@ def _next_sample_generator(self, batch_size: int = 1) -> Dict[str, np.ndarray]: """ Internal method for generating a batch of samples from the generator function. """ - time: np.ndarray = np.arange(self._sample_idx, self._sample_idx + batch_size, - 1)/self.sfreq + time: np.ndarray = ( + np.arange(self._sample_idx, self._sample_idx + batch_size, 1) / self.sfreq + ) self._sample_idx += batch_size - value: np.ndarray = self._generator_function( - time, **self._generator_parameters - ) + value: np.ndarray = self._generator_function(time, **self._generator_parameters) - return {'quantities': value, 'time': time} + return {"quantities": value, "time": time} def set_data_source( - self, - quantities: Union[List, DataFrame, np.ndarray]=None, - target: Optional[Union[List, DataFrame, np.ndarray]]=None, - time: Optional[Union[List, DataFrame, np.ndarray]]=None + self, + quantities: Union[List, DataFrame, np.ndarray] = None, + target: Optional[Union[List, DataFrame, np.ndarray]] = None, + time: Optional[Union[List, DataFrame, np.ndarray]] = None, ): """ This sets the data source by providing up to three iterables: ``quantities`` , @@ -269,10 +272,10 @@ def set_data_source( self._target = target self._time = time - #infer number of samples + # infer number of samples if type(self._quantities).__name__ == "list": self._n_samples = len(self._quantities) - elif type(self._quantities).__name__ == "DataFrame": #dataframe or numpy + elif type(self._quantities).__name__ == "DataFrame": # dataframe or numpy self._quantities = self._quantities.to_numpy() self._n_samples = self._quantities.shape[0] elif type(self._quantities).__name__ == "ndarray": @@ -312,9 +315,9 @@ def next_sample(self, batch_size: int = 1): 'target':current_sample_target}`` """ - if self._data_source_type == 'function': + if self._data_source_type == "function": return self._next_sample_generator(batch_size) - elif self._data_source_type == 'dataset': + elif self._data_source_type == "dataset": return self._next_sample_data_source(batch_size) def _next_sample_data_source( @@ -340,18 +343,23 @@ def _next_sample_data_source( self._sample_idx += batch_size try: - self._current_sample_quantities = self._quantities[self._sample_idx - batch_size:self._sample_idx] + self._current_sample_quantities = self._quantities[ + self._sample_idx - batch_size : self._sample_idx + ] - #if target is available + # if target is available if self._target is not None: - self._current_sample_target = self._target[self._sample_idx - batch_size:self._sample_idx] + self._current_sample_target = self._target[ + self._sample_idx - batch_size : self._sample_idx + ] else: self._current_sample_target = None - #if time is available + # if time is available if self._time is not None: - self._current_sample_time = self._time[self._sample_idx - batch_size - :self._sample_idx] + self._current_sample_time = self._time[ + self._sample_idx - batch_size : self._sample_idx + ] else: self._current_sample_time = None except IndexError: @@ -359,7 +367,11 @@ def _next_sample_data_source( self._current_sample_target = None self._current_sample_time = None - return {'time':self._current_sample_time, 'quantities': self._current_sample_quantities, 'target': self._current_sample_target} + return { + "time": self._current_sample_time, + "quantities": self._current_sample_quantities, + "target": self._current_sample_target, + } def reset(self): self._sample_idx = 0 @@ -380,13 +392,12 @@ def extract_x_y(message): Handle data structures of dictionary to extract features & target """ - if type(message['data']) == tuple: - x = message['data'][0] - y = message['data'][1] - elif type(message['data']) == dict: - x = message['data']['x'] - y = message['data']['y'] + if type(message["data"]) == tuple: + x = message["data"][0] + y = message["data"][1] + elif type(message["data"]) == dict: + x = message["data"]["x"] + y = message["data"]["y"] else: return 1 return x, y -