diff --git a/docs/full/lib/training.rst b/docs/full/lib/training.rst index ff8c3cde04..c23c617583 100644 --- a/docs/full/lib/training.rst +++ b/docs/full/lib/training.rst @@ -15,6 +15,15 @@ augmentation module :undoc-members: :show-inheritance: +cache module +============ + +.. automodule:: lib.training.cache + :members: + :undoc-members: + :show-inheritance: + + generator module ================ diff --git a/lib/align/detected_face.py b/lib/align/detected_face.py index 04ada9b568..eec3ca4fbf 100644 --- a/lib/align/detected_face.py +++ b/lib/align/detected_face.py @@ -106,6 +106,7 @@ def __init__(self, self._landmarks_xy = landmarks_xy self.thumbnail: Optional[np.ndarray] = None self.mask = {} if mask is None else mask + self._training_masks: Optional[Tuple[bytes, Tuple[int, int, int]]] = None self._aligned: Optional[AlignedFace] = None logger.trace("Initialized %s", self.__class__.__name__) # type: ignore @@ -172,59 +173,84 @@ def add_mask(self, fsmask.add(mask, affine_matrix, interpolator) self.mask[name] = fsmask - def get_landmark_mask(self, size, area, - aligned=True, centering="face", dilation=0, blur_kernel=0, as_zip=False): - """ Obtain a single channel mask based on the face's landmark points. + def get_landmark_mask(self, + area: Literal["eye", "face", "mouth"], + blur_kernel: int, + dilation: int) -> np.ndarray: + """ Add a :class:`LandmarksMask` to this detected face + + Landmark based masks are generated from face Aligned Face landmark points. An aligned + face must be loaded. As the data is coming from the already aligned face, no further mask + cropping is required. Parameters ---------- - size: int or tuple - The size of the aligned mask to retrieve. Should be an `int` if an aligned face is - being requested, or a ('height', 'width') shape tuple if a full frame is being - requested - area: ["mouth", "eyes"] + area: ["face", "mouth", "eye"] The type of mask to obtain. `face` is a full face mask the others are masks for those specific areas - aligned: bool, optional - ``True`` if the returned mask should be for an aligned face. ``False`` if a full frame - mask should be returned. Default ``True`` - centering: ["legacy", "face", "head"], optional - Only used if `aligned`=``True``. The centering for the landmarks based mask. Should be - the same as the centering used for the extracted face that this mask will be applied - to. "legacy" places the nose in the center of the image (the original method for - aligning). "face" aligns for the nose to be in the center of the face (top to bottom) - but the center of the skull for left to right. "head" aligns for the center of the - skull (in 3D space) being the center of the extracted image, with the crop holding the - full head. Default: `"face"` - dilation: int, optional + blur_kernel: int + The size of the kernel for blurring the mask edges + dilation: int The amount of dilation to apply to the mask. `0` for none. Default: `0` - blur_kernel: int, optional - The kernel size for applying gaussian blur to apply to the mask. `0` for none. - Default: `0` - as_zip: bool, optional - ``True`` if the mask should be returned zipped otherwise ``False`` Returns ------- - :class:`numpy.ndarray` or zipped array - The mask as a single channel image of the given :attr:`size` dimension. If - :attr:`as_zip` is ``True`` then the :class:`numpy.ndarray` will be contained within a - zipped container + :class:`numpy.ndarray` + The generated landmarks mask for the selected area """ # TODO Face mask generation from landmarks - logger.trace("size: %s, area: %s, aligned: %s, dilation: %s, blur_kernel: %s, as_zip: %s", - size, area, aligned, dilation, blur_kernel, as_zip) - areas = dict(mouth=[slice(48, 60)], eyes=[slice(36, 42), slice(42, 48)]) - if aligned: - face = AlignedFace(self.landmarks_xy, centering=centering, size=size) - landmarks = face.landmarks - size = (size, size) - else: - landmarks = self.landmarks_xy - points = [landmarks[zone] for zone in areas[area]] # pylint:disable=unsubscriptable-object - mask = _LandmarksMask(size, points, dilation=dilation, blur_kernel=blur_kernel) - retval = mask.get(as_zip=as_zip) - return retval + logger.trace("area: %s, dilation: %s", area, dilation) # type: ignore + areas = dict(mouth=[slice(48, 60)], eye=[slice(36, 42), slice(42, 48)]) + points = [self.aligned.landmarks[zone] + for zone in areas[area]] + + lmmask = LandmarksMask(points, + storage_size=self.aligned.size, + storage_centering=self.aligned.centering, + dilation=dilation) + lmmask.set_blur_and_threshold(blur_kernel=blur_kernel) + lmmask.generate_mask( + self.aligned.adjusted_matrix, + self.aligned.interpolators[1]) + return lmmask.mask + + def store_training_masks(self, + masks: List[Optional[np.ndarray]], + delete_masks: bool = False) -> None: + """ Concatenate and compress the given training masks and store for retrieval. + + Parameters + ---------- + masks: list + A list of training mask. Must be all be uint-8 3D arrays of the same size in + 0-255 range + delete_masks: bool, optional + ``True`` to delete any of the :class:`Mask` objects owned by this detected face. Use to + free up unrequired memory usage. Default: ``False`` + """ + if delete_masks: + del self.mask + self.mask = {} + + valid = [msk for msk in masks if msk is not None] + if not valid: + return + combined = np.concatenate(valid, axis=-1) + self._training_masks = (compress(combined), combined.shape) + + def get_training_masks(self) -> Optional[np.ndarray]: + """ Obtain the decompressed combined training masks. + + Returns + ------- + :class:`numpy.ndarray` + A 3D array containing the decompressed training masks as uint8 in 0-255 range if + training masks are present otherwise ``None`` + """ + if not self._training_masks: + return None + return np.frombuffer(decompress(self._training_masks[0]), + dtype="uint8").reshape(self._training_masks[1]) def to_alignment(self) -> AlignmentFileDict: """ Return the detected face formatted for an alignments file @@ -412,77 +438,6 @@ def load_aligned(self, is_legacy=is_aligned and is_legacy) -class _LandmarksMask(): # pylint:disable=too-few-public-methods - """ Create a single channel mask from aligned landmark points. - - size: tuple - The (height, width) shape tuple that the mask should be returned as - points: list - A list of landmark points that correspond to the given shape tuple to create - the mask. Each item in the list should be a :class:`numpy.ndarray` that a filled - convex polygon will be created from - dilation: int, optional - The amount of dilation to apply to the mask. `0` for none. Default: `0` - blur_kernel: int, optional - The kernel size for applying gaussian blur to apply to the mask. `0` for none. Default: `0` - """ - def __init__(self, size, points, dilation=0, blur_kernel=0): - logger.trace("Initializing: %s: (size: %s, points: %s, dilation: %s, blur_kernel: %s)", - self.__class__.__name__, size, points, dilation, blur_kernel) - self._size = size - self._points = points - self._dilation = dilation - self._blur_kernel = blur_kernel - self._mask = None - logger.trace("Initialized: %s", self.__class__.__name__) - - def get(self, as_zip=False): - """ Obtain the mask. - - Parameters - ---------- - as_zip: bool, optional - ``True`` if the mask should be returned zipped otherwise ``False`` - - Returns - ------- - :class:`numpy.ndarray` or zipped array - The mask as a single channel image of the given :attr:`size` dimension. If - :attr:`as_zip` is ``True`` then the :class:`numpy.ndarray` will be contained within a - zipped container - """ - if not np.any(self._mask): - self._generate_mask() - retval = compress(self._mask) if as_zip else self._mask - logger.trace("as_zip: %s, retval type: %s", as_zip, type(retval)) - return retval - - def _generate_mask(self): - """ Generate the mask. - - Creates the mask applying any requested dilation and blurring and assigns to - :attr:`_mask` - - Returns - ------- - :class:`numpy.ndarray` - The mask as a single channel image of the given :attr:`size` dimension. - """ - mask = np.zeros((self._size) + (1, ), dtype="float32") - for landmarks in self._points: - lms = np.rint(landmarks).astype("int") - cv2.fillConvexPoly(mask, cv2.convexHull(lms), 1.0, lineType=cv2.LINE_AA) - if self._dilation != 0: - mask = cv2.dilate(mask, - cv2.getStructuringElement(cv2.MORPH_ELLIPSE, - (self._dilation, self._dilation)), - iterations=1) - if self._blur_kernel != 0: - mask = BlurMask("gaussian", mask, self._blur_kernel).blurred - logger.trace("mask: (shape: %s, dtype: %s)", mask.shape, mask.dtype) - self._mask = (mask * 255.0).astype("uint8") - - class Mask(): """ Face Mask information and convenience methods @@ -805,6 +760,79 @@ def from_dict(self, mask_dict: MaskAlignmentsFileDict) -> None: for k, v in mask_dict.items()}) +class LandmarksMask(Mask): + """ Create a single channel mask from aligned landmark points. + + Landmarks masks are created on the fly, so the stored centering and size should be the same as + the aligned face that the mask will be applied to. As the masks are created on the fly, blur + + dilation is applied to the mask at creation (prior to compression) rather than after + decompression when requested. + + Note + ---- + Threshold is not used for Landmarks mask as the mask is binary + + Parameters + ---------- + points: list + A list of landmark points that correspond to the given storage_size to create + the mask. Each item in the list should be a :class:`numpy.ndarray` that a filled + convex polygon will be created from + storage_size: int, optional + The size (in pixels) that the compressed mask should be stored at. Default: 128. + storage_centering, str (optional): + The centering to store the mask at. One of `"legacy"`, `"face"`, `"head"`. + Default: `"face"` + dilation: int, optional + The amount of dilation to apply to the mask. `0` for none. Default: `0` + """ + def __init__(self, + points: List[np.ndarray], + storage_size: int = 128, + storage_centering: "CenteringType" = "face", + dilation: int = 0) -> None: + super().__init__(storage_size=storage_size, storage_centering=storage_centering) + self._points = points + self._dilation = dilation + + @property + def mask(self) -> np.ndarray: + """ :class:`numpy.ndarray`: Overrides the default mask property, creating the processed + mask at first call and compressing it. The decompressed mask is returned from this + property. """ + return self.stored_mask + + def generate_mask(self, affine_matrix: np.ndarray, interpolator: int) -> None: + """ Generate the mask. + + Creates the mask applying any requested dilation and blurring and assigns compressed mask + to :attr:`_mask` + + Parameters + ---------- + affine_matrix: :class:`numpy.ndarray` + The transformation matrix required to transform the mask to the original frame. + interpolator, int: + The CV2 interpolator required to transform this mask to it's original frame + """ + mask = np.zeros((self.stored_size, self.stored_size, 1), dtype="float32") + for landmarks in self._points: + lms = np.rint(landmarks).astype("int") + cv2.fillConvexPoly(mask, cv2.convexHull(lms), 1.0, lineType=cv2.LINE_AA) + if self._dilation != 0: + mask = cv2.dilate(mask, + cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (self._dilation, self._dilation)), + iterations=1) + if self._blur_kernel != 0 and self._blur_type is not None: + mask = BlurMask(self._blur_type, + mask, + self._blur_kernel, + passes=self._blur_passes).blurred + logger.trace("mask: (shape: %s, dtype: %s)", mask.shape, mask.dtype) # type: ignore + self.add(mask, affine_matrix, interpolator) + + class BlurMask(): # pylint:disable=too-few-public-methods """ Factory class to return the correct blur object for requested blur type. diff --git a/lib/cli/launcher.py b/lib/cli/launcher.py index 1171a43b68..15fb3c27d9 100644 --- a/lib/cli/launcher.py +++ b/lib/cli/launcher.py @@ -42,6 +42,7 @@ def _import_script(self) -> Callable: class: Faceswap Script The uninitialized script from the faceswap scripts folder. """ + self._set_environment_variables() self._test_for_tf_version() self._test_for_gui() cmd = os.path.basename(sys.argv[0]) @@ -51,6 +52,17 @@ def _import_script(self) -> Callable: script = getattr(module, self._command.title()) return script + def _set_environment_variables(self) -> None: + """ Set the number of threads that numexpr can use and TF environment variables. """ + # Allocate a decent number of threads to numexpr to suppress warnings + cpu_count = os.cpu_count() + allocate = cpu_count - cpu_count // 3 if cpu_count is not None else 1 + os.environ["NUMEXPR_MAX_THREADS"] = str(max(1, allocate)) + + # Ensure tensorflow doesn't pin all threads to one core when using Math Kernel Library + os.environ["TF_MIN_GPU_MULTIPROCESSOR_COUNT"] = "4" + os.environ["KMP_AFFINITY"] = "disabled" + def _test_for_tf_version(self) -> None: """ Check that the required Tensorflow version is installed. @@ -63,9 +75,6 @@ def _test_for_tf_version(self) -> None: min_ver = 2.7 max_ver = 2.9 try: - # Ensure tensorflow doesn't pin all threads to one core when using Math Kernel Library - os.environ["TF_MIN_GPU_MULTIPROCESSOR_COUNT"] = "4" - os.environ["KMP_AFFINITY"] = "disabled" import tensorflow as tf # noqa pylint:disable=import-outside-toplevel,unused-import except ImportError as err: if "DLL load failed while importing" in str(err): diff --git a/lib/image.py b/lib/image.py index c7dc31b244..6b72339f2b 100644 --- a/lib/image.py +++ b/lib/image.py @@ -356,26 +356,20 @@ def read_image_batch(filenames, with_metadata=False): >>> images = read_image_batch(image_filenames) """ logger.trace("Requested batch: '%s'", filenames) - executor = futures.ThreadPoolExecutor() - with executor: + batch = [None for _ in range(len(filenames))] + if with_metadata: + meta = [None for _ in range(len(filenames))] + + with futures.ThreadPoolExecutor() as executor: images = {executor.submit(read_image, filename, - raise_error=True, with_metadata=with_metadata): filename - for filename in filenames} - batch = [None for _ in range(len(filenames))] - if with_metadata: - meta = [None for _ in range(len(filenames))] - # There is no guarantee that the same filename will not be passed through multiple times - # (and when shuffle is true this can definitely happen), so we can't just call - # filenames.index(). - return_indices = {filename: [idx for idx, fname in enumerate(filenames) - if fname == filename] - for filename in set(filenames)} + raise_error=True, with_metadata=with_metadata): idx + for idx, filename in enumerate(filenames)} for future in futures.as_completed(images): - return_idx = return_indices[images[future]].pop() + ret_idx = images[future] if with_metadata: - batch[return_idx], meta[return_idx] = future.result() + batch[ret_idx], meta[ret_idx] = future.result() else: - batch[return_idx] = future.result() + batch[ret_idx] = future.result() batch = np.array(batch) retval = (batch, meta) if with_metadata else batch diff --git a/lib/training/__init__.py b/lib/training/__init__.py index 6b0d296658..b697eeff0f 100644 --- a/lib/training/__init__.py +++ b/lib/training/__init__.py @@ -3,4 +3,4 @@ associated objects. """ from .augmentation import ImageAugmentation # noqa -from .generator import TrainingDataGenerator # noqa +from .generator import PreviewDataGenerator, TrainingDataGenerator # noqa diff --git a/lib/training/augmentation.py b/lib/training/augmentation.py index 5b0eff2309..d5e44a800a 100644 --- a/lib/training/augmentation.py +++ b/lib/training/augmentation.py @@ -1,16 +1,67 @@ #!/usr/bin/env python3 """ Processes the augmentation of images for feeding into a Faceswap model. """ +from dataclasses import dataclass import logging +from typing import Tuple, TYPE_CHECKING import cv2 +import numexpr as ne import numpy as np from scipy.interpolate import griddata from lib.image import batch_convert_color +if TYPE_CHECKING: + from plugins.train.trainer._base import ConfigType + logger = logging.getLogger(__name__) # pylint: disable=invalid-name +@dataclass +class AugConstants: + """ Dataclass for holding constants for Image Augmentation. + + Paramaters + ---------- + clahe_base_contrast: int + The base number for Contrast Limited Adaptive Histogram Equalization + clahe_chance: float + Probability to perform Contrast Limited Adaptive Histogram Equilization + clahe_max_size: int + Maximum clahe window size + lab_adjust: np.ndarray + Adjustment amounts for L*A*B augmentation + transform_rotation: int + Rotation range for transformations + transform_zoom: float + Zoom range for transformations + transform_shift: float + Shift range for transformations + warp_maps: :class:`numpy.ndarray` + The stacked (x, y) mappings for image warping + warp_pads: tuple + The padding to apply for image warping + warp_slices: slice + The slices for extracting a warped image + warp_lm_edge_anchors: :class:`numpy.ndarray` + The edge anchors for landmark based warping + warp_lm_grids: :class:`numpy.ndarray` + The grids for landmark based warping + """ + clahe_base_contrast: int + clahe_chance: float + clahe_max_size: int + lab_adjust: np.ndarray + transform_rotation: int + transform_zoom: float + transform_shift: float + warp_maps: np.ndarray + warp_pad: Tuple[int, int] + warp_slices: slice + warp_lm_edge_anchors: np.ndarray + warp_lm_grids: np.ndarray + + class ImageAugmentation(): """ Performs augmentation on batches of training images. @@ -18,191 +69,81 @@ class ImageAugmentation(): ---------- batchsize: int The number of images that will be fed through the augmentation functions at once. - is_display: bool - Whether the images being fed through will be used for Preview or Time-lapse. Disables - the "warp" augmentation for these images. - input_size: int - The expected input size for the model. It is assumed that the input to the model is always - a square image. This is the size, in pixels, of the `width` and the `height` of the input - to the model. - output_shapes: list - A list of tuples defining the output shapes from the model, in the order that the outputs - are returned. The tuples should be in (`height`, `width`, `channels`) format. - coverage_ratio: float - The ratio of the training image to be trained on. Dictates how much of the image will be - cropped out. E.G: a coverage ratio of 0.625 will result in cropping a 160px box from a - 256px image (:math:`256 * 0.625 = 160`) + processing_size: int + The largest input or output size of the model. This is the size that images are processed + at. config: dict The configuration `dict` generated from :file:`config.train.ini` containing the trainer plugin configuration options. - - Attributes - ---------- - initialized: bool - Flag to indicate whether :class:`ImageAugmentation` has been initialized with the training - image size in order to cache certain augmentation operations (see :func:`initialize`) - is_display: bool - Flag to indicate whether these augmentations are for time-lapses/preview images (``True``) - or standard training data (``False``) """ - def __init__(self, batchsize, is_display, input_size, output_shapes, coverage_ratio, config): - logger.debug("Initializing %s: (batchsize: %s, is_display: %s, input_size: %s, " - "output_shapes: %s, coverage_ratio: %s, config: %s)", - self.__class__.__name__, batchsize, is_display, input_size, output_shapes, - coverage_ratio, config) - - self.initialized = False - self.is_display = is_display - - # Set on first image load from initialize - self._training_size = 0 - self._constants = None - + def __init__(self, + batchsize: int, + processing_size: int, + config: "ConfigType") -> None: + logger.debug("Initializing %s: (batchsize: %s, processing_size: %s, " + "config: %s)", + self.__class__.__name__, batchsize, processing_size, config) + + self._processing_size = processing_size self._batchsize = batchsize self._config = config - # Transform and Warp args - self._input_size = input_size - self._output_sizes = [shape[1] for shape in output_shapes if shape[2] == 3] - logger.debug("Output sizes: %s", self._output_sizes) + # Warp args - self._coverage_ratio = coverage_ratio - self._scale = 5 # Normal random variable scale + self._warp_scale = 5 / 256 * self._processing_size # Normal random variable scale + self._warp_lm_scale = 2 / 256 * self._processing_size # Normal random variable scale + self._constants = self._get_constants() logger.debug("Initialized %s", self.__class__.__name__) - def initialize(self, training_size): + def _get_constants(self) -> AugConstants: """ Initializes the caching of constants for use in various image augmentations. - The training image size is not known prior to loading the images from disk and commencing - training, so it cannot be set in the :func:`__init__` method. When the first training batch - is loaded this function should be called to initialize the class and perform various - calculations based on this input size to cache certain constants for image augmentation - calculations. + Returns + ------- + dict + Cached constants that are used for various augmentations + """ + logger.debug("Initializing constants.") - Parameters - ---------- - training_size: int - The size of the training images stored on disk that are to be fed into - :class:`ImageAugmentation`. The training images should always be square and of the - same size. This is the size, in pixels, of the `width` and the `height` of the - training images. - """ - logger.debug("Initializing constants. training_size: %s", training_size) - self._training_size = training_size - coverage = int(self._training_size * self._coverage_ratio // 2) * 2 + # Transform + tform_shift = (int(self._config.get("shift_range", 5)) / 100) * self._processing_size # Color Aug - clahe_base_contrast = training_size // 128 - # Target Images - tgt_slices = slice(self._training_size // 2 - coverage // 2, - self._training_size // 2 + coverage // 2) + amount_l = int(self._config.get("color_lightness", 30)) / 100 + amount_ab = int(self._config.get("color_ab", 8)) / 100 + lab_adjust = np.array([amount_l, amount_ab, amount_ab], dtype="float32") # Random Warp - warp_range_ = np.linspace(self._training_size // 2 - coverage // 2, - self._training_size // 2 + coverage // 2, 5, dtype='float32') - warp_mapx = np.broadcast_to(warp_range_, (self._batchsize, 5, 5)).astype("float32") + warp_range = np.linspace(0, self._processing_size, 5, dtype='float32') + warp_mapx = np.broadcast_to(warp_range, (self._batchsize, 5, 5)).astype("float32") warp_mapy = np.broadcast_to(warp_mapx[0].T, (self._batchsize, 5, 5)).astype("float32") - - warp_pad = int(1.25 * self._input_size) - warp_slices = slice(warp_pad // 10, -warp_pad // 10) + warp_pad = int(1.25 * self._processing_size) # Random Warp Landmarks - p_mx = self._training_size - 1 - p_hf = (self._training_size // 2) - 1 + p_mx = self._processing_size - 1 + p_hf = (self._processing_size // 2) - 1 edge_anchors = np.array([(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0), (p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)]).astype("int32") edge_anchors = np.broadcast_to(edge_anchors, (self._batchsize, 8, 2)) - grids = np.mgrid[0:p_mx:complex(self._training_size), 0:p_mx:complex(self._training_size)] - - self._constants = dict(clahe_base_contrast=clahe_base_contrast, - tgt_slices=tgt_slices, - warp_mapx=warp_mapx, - warp_mapy=warp_mapy, - warp_pad=warp_pad, - warp_slices=warp_slices, - warp_lm_edge_anchors=edge_anchors, - warp_lm_grids=grids) - self.initialized = True - logger.debug("Initialized constants: %s", {k: str(v) if isinstance(v, np.ndarray) else v - for k, v in self._constants.items()}) - - # <<< TARGET IMAGES >>> # - def get_targets(self, batch): - """ Returns the target images, and masks, if required. - - Parameters - ---------- - batch: :class:`numpy.ndarray` - This should be a 4+-dimensional array of training images in the format (`batchsize`, - `height`, `width`, `channels`). Targets should be requested after performing image - transformations but prior to performing warps. - - The 4th channel should be the mask. Any channels above the 4th should be any additional - masks that are requested. - - Returns - ------- - dict - The following keys will be within the returned dictionary: - - * **targets** (`list`) - A list of 4-dimensional :class:`numpy.ndarray` s in the \ - order and size of each output of the model as defined in :attr:`output_shapes`. The \ - format of these arrays will be (`batchsize`, `height`, `width`, `3`). **NB:** \ - masks are not included in the `targets` list. If masks are to be included in the \ - output they will be returned as their own item from the `masks` key. - - * **masks** (:class:`numpy.ndarray`) - A 4-dimensional array containing the target \ - masks in the format (`batchsize`, `height`, `width`, `1`). - """ - logger.trace("Compiling targets: batch shape: %s", batch.shape) - slices = self._constants["tgt_slices"] - target_batch = [np.array([cv2.resize(image[slices, slices, :], - (size, size), - cv2.INTER_AREA) - for image in batch], dtype='float32') / 255. - for size in self._output_sizes] - logger.trace("Target image shapes: %s", - [tgt_images.shape for tgt_images in target_batch]) - - retval = self._separate_target_mask(target_batch) - logger.trace("Final targets: %s", - {k: v.shape if isinstance(v, np.ndarray) else [img.shape for img in v] - for k, v in retval.items()}) - return retval - - @staticmethod - def _separate_target_mask(target_batch): - """ Return the batch and the batch of final masks - - Parameters - ---------- - target_batch: list - List of 4 dimension :class:`numpy.ndarray` objects resized the model outputs. - The 4th channel of the array contains the face mask, any additional channels after - this are additional masks (e.g. eye mask and mouth mask) - - Returns - ------- - dict: - The targets and the masks separated into their own items. The targets are a list of - 3 channel, 4 dimensional :class:`numpy.ndarray` objects sized for each output from the - model. The masks are a :class:`numpy.ndarray` of the final output size. Any additional - masks(e.g. eye and mouth masks) will be collated together into a :class:`numpy.ndarray` - of the final output size. The number of channels will be the number of additional - masks available - """ - logger.trace("target_batch shapes: %s", [tgt.shape for tgt in target_batch]) - retval = dict(targets=[batch[..., :3] for batch in target_batch], - masks=target_batch[-1][..., 3][..., None]) - if target_batch[-1].shape[-1] > 4: - retval["additional_masks"] = target_batch[-1][..., 4:] - logger.trace("returning: %s", {k: v.shape if isinstance(v, np.ndarray) else [tgt.shape - for tgt in v] - for k, v in retval.items()}) + grids = np.mgrid[0: p_mx: complex(self._processing_size), # type: ignore + 0: p_mx: complex(self._processing_size)] # type: ignore + retval = AugConstants(clahe_base_contrast=max(2, self._processing_size // 128), + clahe_chance=int(self._config.get("color_clahe_chance", 50)) / 100, + clahe_max_size=int(self._config.get("color_clahe_max_size", 4)), + lab_adjust=lab_adjust, + transform_rotation=int(self._config.get("rotation_range", 10)), + transform_zoom=int(self._config.get("zoom_amount", 5)) / 100, + transform_shift=tform_shift, + warp_maps=np.stack((warp_mapx, warp_mapy), axis=1), + warp_pad=(warp_pad, warp_pad), + warp_slices=slice(warp_pad // 10, -warp_pad // 10), + warp_lm_edge_anchors=edge_anchors, + warp_lm_grids=grids) + logger.debug("Initialized constants: %s", retval) return retval # <<< COLOR AUGMENTATION >>> # - def color_adjust(self, batch): + def color_adjust(self, batch: np.ndarray) -> np.ndarray: """ Perform color augmentation on the passed in batch. The color adjustment parameters are set in :file:`config.train.ini` @@ -219,49 +160,44 @@ def color_adjust(self, batch): A 4-dimensional array of the same shape as :attr:`batch` with color augmentation applied. """ - if not self.is_display: - logger.trace("Augmenting color") - batch = batch_convert_color(batch, "BGR2LAB") - batch = self._random_clahe(batch) - batch = self._random_lab(batch) - batch = batch_convert_color(batch, "LAB2BGR") + logger.trace("Augmenting color") # type: ignore + batch = batch_convert_color(batch, "BGR2LAB") + self._random_lab(batch) + self._random_clahe(batch) + batch = batch_convert_color(batch, "LAB2BGR") return batch - def _random_clahe(self, batch): + def _random_clahe(self, batch: np.ndarray) -> None: """ Randomly perform Contrast Limited Adaptive Histogram Equalization on a batch of images """ - base_contrast = self._constants["clahe_base_contrast"] + base_contrast = self._constants.clahe_base_contrast batch_random = np.random.rand(self._batchsize) - indices = np.where(batch_random < self._config.get("color_clahe_chance", 50) / 100)[0] + indices = np.where(batch_random < self._constants.clahe_chance)[0] if not np.any(indices): - return batch - - grid_bases = np.rint(np.random.uniform(0, - self._config.get("color_clahe_max_size", 4), - size=indices.shape[0])).astype("uint8") - contrast_adjustment = (grid_bases * (base_contrast // 2)) - grid_sizes = contrast_adjustment + base_contrast - logger.trace("Adjusting Contrast. Grid Sizes: %s", grid_sizes) + return + grid_bases = np.random.randint(self._constants.clahe_max_size + 1, + size=indices.shape[0], + dtype="uint8") + grid_sizes = (grid_bases * (base_contrast // 2)) + base_contrast + logger.trace("Adjusting Contrast. Grid Sizes: %s", grid_sizes) # type: ignore clahes = [cv2.createCLAHE(clipLimit=2.0, # pylint: disable=no-member tileGridSize=(grid_size, grid_size)) for grid_size in grid_sizes] for idx, clahe in zip(indices, clahes): - batch[idx, :, :, 0] = clahe.apply(batch[idx, :, :, 0]) - return batch + batch[idx, :, :, 0] = clahe.apply(batch[idx, :, :, 0], ) - def _random_lab(self, batch): + def _random_lab(self, batch: np.ndarray) -> None: """ Perform random color/lightness adjustment in L*a*b* color space on a batch of images """ - amount_l = self._config.get("color_lightness", 30) / 100 - amount_ab = self._config.get("color_ab", 8) / 100 - adjust = np.array([amount_l, amount_ab, amount_ab], dtype="float32") - randoms = ( - (np.random.rand(self._batchsize, 1, 1, 3).astype("float32") * (adjust * 2)) - adjust) - logger.trace("Random LAB adjustments: %s", randoms) - + randoms = np.random.uniform(-self._constants.lab_adjust, + self._constants.lab_adjust, + size=(self._batchsize, 1, 1, 3)).astype("float32") + logger.trace("Random LAB adjustments: %s", randoms) # type: ignore + # Iterating through the images and channels is much faster than numpy.where and slightly + # faster than numexpr.where. for image, rand in zip(batch, randoms): for idx in range(rand.shape[-1]): adjustment = rand[:, :, idx] @@ -269,10 +205,9 @@ def _random_lab(self, batch): image[:, :, idx] = ((255 - image[:, :, idx]) * adjustment) + image[:, :, idx] else: image[:, :, idx] = image[:, :, idx] * (1 + adjustment) - return batch # <<< IMAGE AUGMENTATION >>> # - def transform(self, batch): + def transform(self, batch: np.ndarray): """ Perform random transformation on the passed in batch. The transformation parameters are set in :file:`config.train.ini` @@ -282,47 +217,36 @@ def transform(self, batch): batch: :class:`numpy.ndarray` The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, `channels`) and in `BGR` format. - - Returns - ---------- - :class:`numpy.ndarray` - A 4-dimensional array of the same shape as :attr:`batch` with transformation applied. """ - if self.is_display: - return batch - logger.trace("Randomly transforming image") - rotation_range = self._config.get("rotation_range", 10) - zoom_range = self._config.get("zoom_amount", 5) / 100 - shift_range = self._config.get("shift_range", 5) / 100 - - rotation = np.random.uniform(-rotation_range, - rotation_range, + logger.trace("Randomly transforming image") # type: ignore + + rotation = np.random.uniform(-self._constants.transform_rotation, + self._constants.transform_rotation, size=self._batchsize).astype("float32") - scale = np.random.uniform(1 - zoom_range, - 1 + zoom_range, + scale = np.random.uniform(1 - self._constants.transform_zoom, + 1 + self._constants.transform_zoom, size=self._batchsize).astype("float32") - tform = np.random.uniform( - -shift_range, - shift_range, - size=(self._batchsize, 2)).astype("float32") * self._training_size + tform = np.random.uniform(-self._constants.transform_shift, + self._constants.transform_shift, + size=(self._batchsize, 2)).astype("float32") mats = np.array( - [cv2.getRotationMatrix2D((self._training_size // 2, self._training_size // 2), + [cv2.getRotationMatrix2D((self._processing_size // 2, self._processing_size // 2), rot, scl) for rot, scl in zip(rotation, scale)]).astype("float32") mats[..., 2] += tform - batch = np.array([cv2.warpAffine(image, - mat, - (self._training_size, self._training_size), - borderMode=cv2.BORDER_REPLICATE) - for image, mat in zip(batch, mats)]) + for image, mat in zip(batch, mats): + cv2.warpAffine(image, + mat, + (self._processing_size, self._processing_size), + dst=image, + borderMode=cv2.BORDER_REPLICATE) - logger.trace("Randomly transformed image") - return batch + logger.trace("Randomly transformed image") # type: ignore - def random_flip(self, batch): + def random_flip(self, batch: np.ndarray): """ Perform random horizontal flipping on the passed in batch. The probability of flipping an image is set in :file:`config.train.ini` @@ -332,21 +256,15 @@ def random_flip(self, batch): batch: :class:`numpy.ndarray` The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, `channels`) and in `BGR` format. - - Returns - ---------- - :class:`numpy.ndarray` - A 4-dimensional array of the same shape as :attr:`batch` with transformation applied. """ - if not self.is_display: - logger.trace("Randomly flipping image") - randoms = np.random.rand(self._batchsize) - indices = np.where(randoms > self._config.get("random_flip", 50) / 100)[0] - batch[indices] = batch[indices, :, ::-1] - logger.trace("Randomly flipped %s images of %s", len(indices), self._batchsize) - return batch - - def warp(self, batch, to_landmarks=False, **kwargs): + logger.trace("Randomly flipping image") # type: ignore + randoms = np.random.rand(self._batchsize) + indices = np.where(randoms > int(self._config.get("random_flip", 50)) / 100)[0] + batch[indices] = batch[indices, :, ::-1] + logger.trace("Randomly flipped %s images of %s", # type: ignore + len(indices), self._batchsize) + + def warp(self, batch: np.ndarray, to_landmarks: bool = False, **kwargs) -> np.ndarray: """ Perform random warping on the passed in batch by one of two methods. Parameters @@ -367,43 +285,71 @@ def warp(self, batch, to_landmarks=False, **kwargs): * **batch_dst_points** (:class:`numpy.ndarray`) - A batch of randomly chosen closest \ match destination faces landmarks. This is a 3-dimensional array in the shape \ (`batchsize`, `68`, `2`). + Returns ---------- :class:`numpy.ndarray` A 4-dimensional array of the same shape as :attr:`batch` with warping applied. """ if to_landmarks: - return self._random_warp_landmarks(batch, **kwargs).astype("float32") / 255.0 - return self._random_warp(batch).astype("float32") / 255.0 + return self._random_warp_landmarks(batch, **kwargs) + return self._random_warp(batch) - def _random_warp(self, batch): - """ Randomly warp the input batch """ - logger.trace("Randomly warping batch") - mapx = self._constants["warp_mapx"] - mapy = self._constants["warp_mapy"] - pad = self._constants["warp_pad"] - slices = self._constants["warp_slices"] + def _random_warp(self, batch: np.ndarray) -> np.ndarray: + """ Randomly warp the input batch + Parameters + ---------- + batch: :class:`numpy.ndarray` + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `3`) and in `BGR` format. + + Returns + ---------- + :class:`numpy.ndarray` + A 4-dimensional array of the same shape as :attr:`batch` with warping applied. + """ + logger.trace("Randomly warping batch") # type: ignore + slices = self._constants.warp_slices rands = np.random.normal(size=(self._batchsize, 2, 5, 5), - scale=self._scale).astype("float32") - batch_maps = np.stack((mapx, mapy), axis=1) + rands - batch_interp = np.array([[cv2.resize(map_, (pad, pad))[slices, slices] for map_ in maps] + scale=self._warp_scale).astype("float32") + batch_maps = ne.evaluate("m + r", local_dict=dict(m=self._constants.warp_maps, r=rands)) + batch_interp = np.array([[cv2.resize(map_, self._constants.warp_pad)[slices, slices] + for map_ in maps] for maps in batch_maps]) warped_batch = np.array([cv2.remap(image, interp[0], interp[1], cv2.INTER_LINEAR) for image, interp in zip(batch, batch_interp)]) - logger.trace("Warped image shape: %s", warped_batch.shape) + logger.trace("Warped image shape: %s", warped_batch.shape) # type: ignore return warped_batch - def _random_warp_landmarks(self, batch, batch_src_points, batch_dst_points): - """ From dfaker. Warp the image to a similar set of landmarks from the opposite side """ - logger.trace("Randomly warping landmarks") - edge_anchors = self._constants["warp_lm_edge_anchors"] - grids = self._constants["warp_lm_grids"] - slices = self._constants["tgt_slices"] + def _random_warp_landmarks(self, + batch: np.ndarray, + batch_src_points: np.ndarray, + batch_dst_points: np.ndarray) -> np.ndarray: + """ From dfaker. Warp the image to a similar set of landmarks from the opposite side + + batch: :class:`numpy.ndarray` + The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, + `3`) and in `BGR` format. + batch_src_points :class:`numpy.ndarray` + A batch of 68 point landmarks for the source faces. This is a 3-dimensional array in + the shape (`batchsize`, `68`, `2`). + batch_dst_points :class:`numpy.ndarray` + A batch of randomly chosen closest match destination faces landmarks. This is a + 3-dimensional array in the shape (`batchsize`, `68`, `2`). + + Returns + ---------- + :class:`numpy.ndarray` + A 4-dimensional array of the same shape as :attr:`batch` with warping applied. + """ + logger.trace("Randomly warping landmarks") # type: ignore + edge_anchors = self._constants.warp_lm_edge_anchors + grids = self._constants.warp_lm_grids batch_dst = (batch_dst_points + np.random.normal(size=batch_dst_points.shape, - scale=2.0)) + scale=self._warp_lm_scale)) face_cores = [cv2.convexHull(np.concatenate([src[17:], dst[17:]], axis=0)) for src, dst in zip(batch_src_points.astype("int32"), @@ -418,14 +364,14 @@ def _random_warp_landmarks(self, batch, batch_src_points, batch_dst_points): for src, dst, face_core in zip(batch_src[:, :18, :], batch_dst[:, :18, :], face_cores)] - batch_src = [np.delete(src, idxs, axis=0) for idxs, src in zip(rem_indices, batch_src)] - batch_dst = [np.delete(dst, idxs, axis=0) for idxs, dst in zip(rem_indices, batch_dst)] + lbatch_src = [np.delete(src, idxs, axis=0) for idxs, src in zip(rem_indices, batch_src)] + lbatch_dst = [np.delete(dst, idxs, axis=0) for idxs, dst in zip(rem_indices, batch_dst)] grid_z = np.array([griddata(dst, src, (grids[0], grids[1]), method="linear") - for src, dst in zip(batch_src, batch_dst)]) + for src, dst in zip(lbatch_src, lbatch_dst)]) maps = grid_z.reshape((self._batchsize, - self._training_size, - self._training_size, + self._processing_size, + self._processing_size, 2)).astype("float32") warped_batch = np.array([cv2.remap(image, map_[..., 1], @@ -433,33 +379,5 @@ def _random_warp_landmarks(self, batch, batch_src_points, batch_dst_points): cv2.INTER_LINEAR, cv2.BORDER_TRANSPARENT) for image, map_ in zip(batch, maps)]) - warped_batch = np.array([cv2.resize(image[slices, slices, :], - (self._input_size, self._input_size), - cv2.INTER_AREA) - for image in warped_batch]) - logger.trace("Warped batch shape: %s", warped_batch.shape) + logger.trace("Warped batch shape: %s", warped_batch.shape) # type: ignore return warped_batch - - def skip_warp(self, batch): - """ Returns the images resized and cropped for feeding the model, if warping has been - disabled. - - Parameters - ---------- - batch: :class:`numpy.ndarray` - The batch should be a 4-dimensional array of shape (`batchsize`, `height`, `width`, - `3`) and in `BGR` format. - - Returns - ------- - :class:`numpy.ndarray` - The given batch cropped and resized for feeding the model - """ - logger.trace("Compiling skip warp images: batch shape: %s", batch.shape) - slices = self._constants["tgt_slices"] - retval = np.array([cv2.resize(image[slices, slices, :], - (self._input_size, self._input_size), - cv2.INTER_AREA) - for image in batch], dtype='float32') / 255. - logger.trace("feed batch shape: %s", retval.shape) - return retval diff --git a/lib/training/cache.py b/lib/training/cache.py new file mode 100644 index 0000000000..94014ef410 --- /dev/null +++ b/lib/training/cache.py @@ -0,0 +1,507 @@ +#!/usr/bin/env python3 +""" Holds the data cache for training data generators """ +import logging +import os +import sys + +from threading import Lock +from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING + +import cv2 +import numpy as np +from tqdm import tqdm + +from lib.align import DetectedFace +from lib.align.aligned_face import CenteringType +from lib.image import read_image_batch, read_image_meta_batch +from lib.utils import FaceswapError + +if sys.version_info < (3, 8): + from typing_extensions import get_args, Literal +else: + from typing import get_args, Literal + +if TYPE_CHECKING: + from .generator import ConfigType + from lib.align.alignments import PNGHeaderAlignmentsDict, PNGHeaderDict + +logger = logging.getLogger(__name__) + +_FACE_CACHES: Dict[str, "_Cache"] = {} + + +def get_cache(side: Literal["a", "b"], + filenames: Optional[List[str]] = None, + config: Optional["ConfigType"] = None, + size: Optional[int] = None, + coverage_ratio: Optional[float] = None) -> "_Cache": + """ Obtain a :class:`_Cache` object for the given side. If the object does not pre-exist then + create it. + + Parameters + ---------- + side: str + `"a"` or `"b"`. The side of the model to obtain the cache for + filenames: list + The filenames of all the images. This can either be the full path or the base name. If the + full paths are passed in, they are stripped to base name for use as the cache key. Must be + passed for the first call of this function for each side. For subsequent calls this + parameter is ignored. Default: ``None`` + config: dict, optional + The user selected training configuration options. Must be passed for the first call of this + function for each side. For subsequent calls this parameter is ignored. Default: ``None`` + size: int, optional + The largest output size of the model. Must be passed for the first call of this function + for each side. For subsequent calls this parameter is ignored. Default: ``None`` + coverage_ratio: float: optional + The coverage ratio that the model is using. Must be passed for the first call of this + function for each side. For subsequent calls this parameter is ignored. Default: ``None`` + + Returns + ------- + :class:`_Cache` + The face meta information cache for the requested side + """ + if not _FACE_CACHES.get(side): + assert config is not None, ("config must be provided for first call to cache") + assert filenames is not None, ("filenames must be provided for first call to cache") + assert size is not None, ("size must be provided for first call to cache") + assert coverage_ratio is not None, ("coverage_ratio must be provided for first call to " + "cache") + logger.debug("Creating cache. side: %s, size: %s, coverage_ratio: %s", + side, size, coverage_ratio) + _FACE_CACHES[side] = _Cache(filenames, config, size, coverage_ratio) + return _FACE_CACHES[side] + + +def _check_reset(face_cache: "_Cache") -> bool: + """ Check whether a given cache needs to be reset because a face centering change has been + detected in the other cache. + + Parameters + ---------- + face_cache: :class:`_Cache` + The cache object that is checking whether it should reset + + Returns + ------- + bool + ``True`` if the given object should reset the cache, otherwise ``False`` + """ + check_cache = next((cache for cache in _FACE_CACHES.values() if cache != face_cache), None) + retval = False if check_cache is None else check_cache.check_reset() + return retval + + +class _Cache(): + """ A thread safe mechanism for collecting and holding face meta information (masks, " + "alignments data etc.) for multiple :class:`TrainingDataGenerator`s. + + Each side may have up to 3 generators (training, preview and time-lapse). To conserve VRAM + these need to share access to the same face information for the images they are processing. + + As the cache is populated at run-time, thread safe writes are required for the first epoch. + Following that, the cache is only used for reads, which is thread safe intrinsically. + + It would probably be quicker to set locks on each individual face, but for code complexity + reasons, and the fact that the lock is only taken up during cache population, and it should + only be being read multiple times on save iterations, we lock the whole cache during writes. + + Parameters + ---------- + filenames: list + The filenames of all the images. This can either be the full path or the base name. If the + full paths are passed in, they are stripped to base name for use as the cache key. + config: dict + The user selected training configuration options + size: int + The largest output size of the model + coverage_ratio: float + The coverage ratio that the model is using. + """ + def __init__(self, + filenames: List[str], + config: "ConfigType", + size: int, + coverage_ratio: float) -> None: + logger.debug("Initializing: %s (filenames: %s, size: %s, coverage_ratio: %s)", + self.__class__.__name__, len(filenames), size, coverage_ratio) + self._lock = Lock() + self._cache_info = dict(cache_full=False, has_reset=False) + self._partially_loaded: List[str] = [] + + self._image_count = len(filenames) + self._cache: Dict[str, DetectedFace] = {} + self._aligned_landmarks: Dict[str, np.ndarray] = {} + self._extract_version = 0.0 + self._size = size + + assert config["centering"] in get_args(CenteringType) + self._centering: CenteringType = cast(CenteringType, config["centering"]) + self._config = config + self._coverage_ratio = coverage_ratio + + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def cache_full(self) -> bool: + """bool: ``True`` if the cache has been fully populated. ``False`` if there are items still + to be cached. """ + if self._cache_info["cache_full"]: + return self._cache_info["cache_full"] + with self._lock: + return self._cache_info["cache_full"] + + @property + def aligned_landmarks(self) -> Dict[str, np.ndarray]: + """ dict: The filename as key, aligned landmarks as value. """ + # Note: Aligned landmarks are only used for warp-to-landmarks, so this can safely populate + # all of the aligned landmarks for the entire cache. + if not self._aligned_landmarks: + with self._lock: + # For Warp-To-Landmarks a race condition can occur where this is referenced from + # the opposite side prior to it being populated, so block on a lock. + self._aligned_landmarks = {key: face.aligned.landmarks + for key, face in self._cache.items()} + return self._aligned_landmarks + + @property + def size(self) -> int: + """ int: The pixel size of the cropped aligned face """ + return self._size + + def check_reset(self) -> bool: + """ Check whether this cache has been reset due to a face centering change, and reset the + flag if it has. + + Returns + ------- + bool + ``True`` if the cache has been reset because of a face centering change due to + legacy alignments, otherwise ``False``. """ + retval = self._cache_info["has_reset"] + if retval: + logger.debug("Resetting 'has_reset' flag") + self._cache_info["has_reset"] = False + return retval + + def get_items(self, filenames: List[str]) -> List[DetectedFace]: + """ Obtain the cached items for a list of filenames. The returned list is in the same order + as the provided filenames. + + Parameters + ---------- + filenames: list + A list of image filenames to obtain the cached data for + + Returns + ------- + list + List of DetectedFace objects holding the cached metadata. The list returns in the same + order as the filenames received + """ + return [self._cache[os.path.basename(filename)] for filename in filenames] + + def cache_metadata(self, filenames: List[str]) -> np.ndarray: + """ Obtain the batch with metadata for items that need caching and cache DetectedFace + objects to :attr:`_cache`. + + Parameters + ---------- + filenames: list + List of full paths to image file names + + Returns + ------- + :class:`numpy.ndarray` + The batch of face images loaded from disk + """ + keys = [os.path.basename(filename) for filename in filenames] + with self._lock: + if _check_reset(self): + self._reset_cache(False) + + needs_cache = [filename + for filename, key in zip(filenames, keys) + if key not in self._cache or key in self._partially_loaded] + logger.trace("Needs cache: %s", needs_cache) # type: ignore + + if not needs_cache: + # Don't bother reading the metadata if no images in this batch need caching + logger.debug("All metadata already cached for: %s", keys) + return read_image_batch(filenames) + + batch, metadata = read_image_batch(filenames, with_metadata=True) + + if len(batch.shape) == 1: + folder = os.path.dirname(filenames[0]) + details = [ + f"{key} ({f'{img.shape[1]}px' if isinstance(img, np.ndarray) else type(img)})" + for key, img in zip(keys, batch)] + msg = (f"There are mismatched image sizes in the folder '{folder}'. All training " + "images for each side must have the same dimensions.\nThe batch that " + f"failed contains the following files:\n{details}.") + raise FaceswapError(msg) + + # Populate items into cache + for filename in needs_cache: + key = os.path.basename(filename) + meta = metadata[filenames.index(filename)] + + # Version Check + self._validate_version(meta, filename) + if self._partially_loaded: # Faces already loaded for Warp-to-landmarks + self._partially_loaded.remove(key) + detected_face = self._cache[key] + else: + detected_face = self._load_detected_face(filename, meta["alignments"]) + + self._prepare_masks(filename, detected_face) + self._cache[key] = detected_face + + # Update the :attr:`cache_full` attribute + cache_full = not self._partially_loaded and len(self._cache) == self._image_count + if cache_full: + logger.verbose("Cache filled: '%s'", os.path.dirname(filenames[0])) # type: ignore + self._cache_info["cache_full"] = cache_full + + return batch + + def pre_fill(self, filenames: List[str], side: Literal["a", "b"]) -> None: + """ When warp to landmarks is enabled, the cache must be pre-filled, as each side needs + access to the other side's alignments. + + Parameters + ---------- + filenames: list + The list of full paths to the images to load the metadata from + side: str + `"a"` or `"b"`. The side of the model being cached. Used for info output + """ + with self._lock: + for filename, meta in tqdm(read_image_meta_batch(filenames), + desc=f"WTL: Caching Landmarks ({side.upper()})", + total=len(filenames), + leave=False): + if "itxt" not in meta or "alignments" not in meta["itxt"]: + raise FaceswapError(f"Invalid face image found. Aborting: '{filename}'") + + meta = meta["itxt"] + key = os.path.basename(filename) + # Version Check + self._validate_version(meta, filename) + detected_face = self._load_detected_face(filename, meta["alignments"]) + self._cache[key] = detected_face + self._partially_loaded.append(key) + + def _validate_version(self, png_meta: "PNGHeaderDict", filename: str) -> None: + """ Validate that there are not a mix of v1.0 extracted faces and v2.x faces. + + Parameters + ---------- + png_meta: dict + The information held within the Faceswap PNG Header + filename: str + The full path to the file being validated + + Raises + ------ + FaceswapError + If a version 1.0 face appears in a 2.x set or vice versa + """ + alignment_version = png_meta["source"]["alignments_version"] + + if not self._extract_version: + logger.debug("Setting initial extract version: %s", alignment_version) + self._extract_version = alignment_version + if alignment_version == 1.0 and self._centering != "legacy": + self._reset_cache(True) + return + + if (self._extract_version == 1.0 and alignment_version > 1.0) or ( + alignment_version == 1.0 and self._extract_version > 1.0): + raise FaceswapError("Mixing legacy and full head extracted facesets is not supported. " + "The following folder contains a mix of extracted face types: " + f"'{os.path.dirname(filename)}'") + + self._extract_version = min(alignment_version, self._extract_version) + + def _reset_cache(self, set_flag: bool) -> None: + """ In the event that a legacy extracted face has been seen, and centering is not legacy + the cache will need to be reset for legacy centering. + + Parameters + ---------- + set_flag: bool + ``True`` if the flag should be set to indicate that the cache is being reset because of + a legacy face set/centering mismatch. ``False`` if the cache is being reset because it + has detected a reset flag from the opposite cache. + """ + if set_flag: + logger.warning("You are using legacy extracted faces but have selected '%s' centering " + "which is incompatible. Switching centering to 'legacy'", + self._centering) + self._config["centering"] = "legacy" + self._centering = "legacy" + self._cache = {} + self._cache_info["cache_full"] = False + if set_flag: + self._cache_info["has_reset"] = True + + def _load_detected_face(self, + filename: str, + alignments: "PNGHeaderAlignmentsDict") -> DetectedFace: + """ Load a :class:`DetectedFace` object and load its associated `aligned` property. + + Parameters + ---------- + filename: str + The file path for the current image + alignments: dict + The alignments for a single face, extracted from a PNG header + + Returns + ------- + :class:`lib.align.DetectedFace` + The loaded Detected Face object + """ + detected_face = DetectedFace() + detected_face.from_png_meta(alignments) + detected_face.load_aligned(None, + size=self._size, + centering=self._centering, + coverage_ratio=self._coverage_ratio, + is_aligned=True, + is_legacy=self._extract_version == 1.0) + logger.trace("Cached aligned face for: %s", filename) # type: ignore + return detected_face + + def _prepare_masks(self, filename: str, detected_face: DetectedFace) -> None: + """ Prepare the masks required from training, and compile into a single compressed array + + Parameters + ---------- + filename: str + The file path for the current image + detected_face: :class:`lib.align.DetectedFace` + The detected face object that holds the masks + """ + masks = [(self._get_face_mask(filename, detected_face))] + for area in get_args(Literal["eye", "mouth"]): + masks.append(self._get_localized_mask(filename, detected_face, area)) + + detected_face.store_training_masks(masks, delete_masks=True) + logger.trace("Stored masks for filename: %s)", filename) # type: ignore + + def _get_face_mask(self, filename: str, detected_face: DetectedFace) -> Optional[np.ndarray]: + """ Obtain the training sized face mask from the :class:`DetectedFace` for the requested + mask type. + + Parameters + ---------- + filename: str + The file path for the current image + detected_face: :class:`lib.align.DetectedFace` + The detected face object that holds the masks + + Raises + ------ + FaceswapError + If the requested mask type is not available an error is returned along with a list + of available masks + """ + if not self._config["penalized_mask_loss"] and not self._config["learn_mask"]: + return None + + if not self._config["mask_type"]: + logger.debug("No mask selected. Not validating") + return None + + if self._config["mask_type"] not in detected_face.mask: + raise FaceswapError( + f"You have selected the mask type '{self._config['mask_type']}' but at least one " + "face does not contain the selected mask.\n" + f"The face that failed was: '{filename}'\n" + f"The masks that exist for this face are: {list(detected_face.mask)}") + + mask = detected_face.mask[str(self._config["mask_type"])] + mask.set_blur_and_threshold(blur_kernel=int(self._config["mask_blur_kernel"]), + threshold=int(self._config["mask_threshold"])) + + pose = detected_face.aligned.pose + mask.set_sub_crop(pose.offset[mask.stored_centering], + pose.offset[self._centering], + self._centering, + self._coverage_ratio) + face_mask = mask.mask + if self._size != face_mask.shape[0]: + interpolator = cv2.INTER_CUBIC if mask.stored_size < self._size else cv2.INTER_AREA + face_mask = cv2.resize(face_mask, + (self._size, self._size), + interpolation=interpolator)[..., None] + + logger.trace("Obtained face mask for: %s %s", filename, face_mask.shape) # type: ignore + return face_mask + + def _get_localized_mask(self, + filename: str, + detected_face: DetectedFace, + area: Literal["eye", "mouth"]) -> Optional[np.ndarray]: + """ Obtain a localized mask for the given area if it is required for training. + + Parameters + ---------- + filename: str + The file path for the current image + detected_face: :class:`lib.align.DetectedFace` + The detected face object that holds the masks + area: str + `"eye"` or `"mouth"`. The area of the face to obtain the mask for + """ + if not self._config["penalized_mask_loss"] or int(self._config[f"{area}_multiplier"]) <= 1: + return None + mask = detected_face.get_landmark_mask(area, self._size // 16, self._size // 32) + logger.trace("Caching localized '%s' mask for: %s %s", # type: ignore + area, filename, mask.shape) + return mask + + +class RingBuffer(): # pylint: disable=too-few-public-methods + """ Rolling buffer for holding training/preview batches + + Parameters + ---------- + batch_size: int + The batch size to create the buffer for + image_shape: tuple + The height/width/channels shape of a single image in the batch + buffer_size: int, optional + The number of arrays to hold in the rolling buffer. Default: `2` + dtype: str, optional + The datatype to create the buffer as. Default: `"uint8"` + """ + def __init__(self, + batch_size: int, + image_shape: Tuple[int, int, int], + buffer_size: int = 2, + dtype: str = "uint8") -> None: + logger.debug("Initializing: %s (batch_size: %s, image_shape: %s, buffer_size: %s, " + "dtype: %s", self.__class__.__name__, batch_size, image_shape, buffer_size, + dtype) + self._max_index = buffer_size - 1 + self._index = 0 + self._buffer = [np.empty((batch_size, *image_shape), dtype=dtype) + for _ in range(buffer_size)] + logger.debug("Initialized: %s", self.__class__.__name__) # type: ignore + + def __call__(self) -> np.ndarray: + """ Obtain the next array from the ring buffer + + Returns + ------- + :class:`np.ndarray` + A pre-allocated numpy array from the buffer + """ + retval = self._buffer[self._index] + self._index += 1 if self._index < self._max_index else -self._max_index + return retval diff --git a/lib/training/generator.py b/lib/training/generator.py index 22b668b994..ad0ae0ecb8 100644 --- a/lib/training/generator.py +++ b/lib/training/generator.py @@ -3,179 +3,207 @@ import logging import os +import sys +from concurrent import futures from random import shuffle, choice -from threading import Lock -from zlib import decompress +from typing import cast, Dict, Generator, List, Tuple, TYPE_CHECKING, Union -import numpy as np import cv2 -from tqdm import tqdm -from lib.align import AlignedFace, DetectedFace, get_centered_size -from lib.image import read_image_batch, read_image_meta_batch +import numpy as np +import numexpr as ne +from lib.align import AlignedFace, DetectedFace +from lib.align.aligned_face import CenteringType +from lib.image import read_image_batch from lib.multithreading import BackgroundGenerator from lib.utils import FaceswapError from . import ImageAugmentation +from .cache import get_cache, RingBuffer -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -_FACE_CACHES = dict() - - -def _get_cache(side, filenames, config): - """ Obtain a :class:`_Cache` object for the given side. If the object does not pre-exist then - create it. - - Parameters - ---------- - side: str - `"a"` or `"b"`. The side of the model to obtain the cache for - filenames: list - The filenames of all the images. This can either be the full path or the base name. If the - full paths are passed in, they are stripped to base name for use as the cache key. - config: dict - The user selected training configuration options - - Returns - ------- - :class:`_Cache` - The face meta information cache for the requested side - """ - if not _FACE_CACHES.get(side): - logger.debug("Creating cache. Side: %s", side) - _FACE_CACHES[side] = _Cache(filenames, config) - return _FACE_CACHES[side] - - -def _check_reset(face_cache): - """ Check whether a given cache needs to be reset because a face centering change has been - detected in the other cache. - - Parameters - ---------- - face_cache: :class:`_Cache` - The cache object that is checking whether it should reset - - Returns - ------- - bool - ``True`` if the given object should reset the cache, otherwise ``False`` - """ - check_cache = next((cache for cache in _FACE_CACHES.values() if cache != face_cache), None) - retval = check_cache if check_cache is None else check_cache.check_reset() - return retval +if sys.version_info < (3, 8): + from typing_extensions import get_args, Literal +else: + from typing import get_args, Literal +if TYPE_CHECKING: + from plugins.train.model._base import ModelBase + from .cache import _Cache -class _Cache(): - """ A thread safe mechanism for collecting and holding face meta information (masks, " - "alignments data etc.) for multiple :class:`TrainingDataGenerator`s. +logger = logging.getLogger(__name__) +ConfigType = Dict[str, Union[bool, int, float, str]] # TODO Dataclass +BatchType = Tuple[np.ndarray, List[np.ndarray]] - Each side may have up to 3 generators (training, preview and time-lapse). To conserve VRAM - these need to share access to the same face information for the images they are processing. - As the cache is populated at run-time, thread safe writes are required for the first epoch. - Following that, the cache is only used for reads, which is thread safe intrinsically. +class DataGenerator(): + """ Parent class for Training and Preview Data Generators. - It would probably be quicker to set locks on each individual face, but for code complexity - reasons, and the fact that the lock is only taken up during cache population, and it should - only be being read multiple times on save iterations, we lock the whole cache during writes. + This class is called from :mod:`plugins.train.trainer._base` and launches a background + iterator that compiles augmented data, target data and sample data. Parameters ---------- - filenames: list - The filenames of all the images. This can either be the full path or the base name. If the - full paths are passed in, they are stripped to base name for use as the cache key. + model: :class:`~plugins.train.model.ModelBase` + The model that this data generator is feeding config: dict - The user selected training configuration options + The configuration `dict` generated from :file:`config.train.ini` containing the trainer + plugin configuration options. + side: {'a' or 'b'} + The side of the model that this iterator is for. + images: list + A list of image paths that will be used to compile the final augmented data from. + batch_size: int + The batch size for this iterator. Images will be returned in :class:`numpy.ndarray` + objects of this size from the iterator. """ - def __init__(self, filenames, config): - self._lock = Lock() - self._cache = {os.path.basename(filename): dict(cached=False) for filename in filenames} - self._aligned_landmarks = None - self._partial_load = False - self._cache_full = False - self._extract_version = None - self._has_reset = False - self._size = None - - self._centering = config["centering"] + def __init__(self, + config: ConfigType, + model: "ModelBase", + side: Literal["a", "b"], + images: List[str], + batch_size: int) -> None: + logger.debug("Initializing %s: (model: %s, side: %s, images: %s , " # type: ignore + "batch_size: %s, config: %s)", self.__class__.__name__, model.name, side, + len(images), batch_size, config) self._config = config + self._side = side + self._images = images + self._batch_size = batch_size + self._process_size = max([model.input_shape[1]] + [img[1] + for img in model.output_shapes[0]]) + self._output_sizes = [shape[0] for shape in model.output_shapes[0] if shape[-1] != 1] + self._coverage_ratio = model.coverage_ratio + self._color_order = model.color_order.lower() + self._use_mask = self._config["mask_type"] and (self._config["penalized_mask_loss"] or + self._config["learn_mask"]) + + self._validate_samples() + self._buffer = RingBuffer(batch_size, + (self._process_size, self._process_size, self._total_channels), + dtype="uint8") + self._face_cache: "_Cache" = get_cache(side, + filenames=images, + config=self._config, + size=self._process_size, + coverage_ratio=self._coverage_ratio) + logger.debug("Initialized %s", self.__class__.__name__) @property - def cache_full(self): - """bool: ``True`` if the cache has been fully populated. ``False`` if there are items still - to be cached. """ - if self._cache_full: - return self._cache_full - with self._lock: - return self._cache_full - - @property - def partially_loaded(self): - """ bool: ``True`` if the cache has been partially loaded for Warp To Landmarks otherwise - ``False`` """ - if self._partial_load: - return self._partial_load - with self._lock: - return self._partial_load + def _total_channels(self) -> int: + """int: The total number of channels, including mask channels that the target image + should hold. """ + channels = 3 + if self._config["mask_type"] and (self._config["learn_mask"] or + self._config["penalized_mask_loss"]): + channels += 1 + + mults = [area for area in ["eye", "mouth"] if int(self._config[f"{area}_multiplier"]) > 1] + if self._config["penalized_mask_loss"] and mults: + channels += len(mults) + return channels + + def minibatch_ab(self, do_shuffle: bool = True) -> Generator[BatchType, None, None]: + """ A Background iterator to return augmented images, samples and targets. - @property - def extract_version(self): - """ float: The alignments file version used to extract the faces. """ - return self._extract_version + The exit point from this class and the sole attribute that should be referenced. Called + from :mod:`plugins.train.trainer._base`. Returns an iterator that yields images for + training, preview and time-lapses. - @property - def aligned_landmarks(self): - """ dict: The filename as key, aligned landmarks as value """ - if self._aligned_landmarks is None: - with self._lock: - # For Warp-To-Landmarks a race condition can occur where this is referenced from - # the opposite side prior to it being populated, so block on a lock. - self._aligned_landmarks = {key: val["aligned_face"].landmarks - for key, val in self._cache.items()} - return self._aligned_landmarks + Parameters + ---------- + do_shuffle: bool, optional + Whether data should be shuffled prior to loading from disk. If true, each time the full + list of filenames are processed, the data will be reshuffled to make sure they are not + returned in the same order. Default: ``True`` - @property - def crop_size(self): - """ int: The pixel size of the cropped aligned face """ - return self._size + Yields + ------ + feed: list + 4-dimensional array of faces to feed the training the model (:attr:`x` parameter for + :func:`keras.models.model.train_on_batch`.). The array returned is in the format + (`batch size`, `height`, `width`, `channels`). + targets: list + List of 4-dimensional :class:`numpy.ndarray` objects in the order and size of each + output of the model. The format of these arrays will be (`batch size`, `height`, + `width`, `x`). This is the :attr:`y` parameter for + :func:`keras.models.model.train_on_batch`. The number of channels here will vary. + The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. eye/mouth masks) + """ + logger.debug("do_shuffle: %s", do_shuffle) + args = (do_shuffle, ) + batcher = BackgroundGenerator(self._minibatch, thread_count=1, args=args) + return batcher.iterator() - def check_reset(self): - """ Check whether this cache has been reset due to a face centering change, and reset the - flag if it has. + # << INTERNAL METHODS >> # + def _validate_samples(self) -> None: + """ Ensures that the total number of images within :attr:`images` is greater or equal to + the selected :attr:`batch_size`. - Returns - ------- - bool - ``True`` if the cache has been reset because of a face centering change due to - legacy alignments, otherwise ``False``. """ - retval = self._has_reset - if retval: - logger.debug("Resetting 'has_reset' flag") - self._has_reset = False - return retval + Raises + ------ + :class:`FaceswapError` + If the number of images loaded is smaller than the selected batch size + """ + length = len(self._images) + msg = ("Number of images is lower than batch-size (Note that too few images may lead to " + f"bad training). # images: {length}, batch-size: {self._batch_size}") + try: + assert length >= self._batch_size, msg + except AssertionError as err: + msg += ("\nYou should increase the number of images in your training set or lower " + "your batch-size.") + raise FaceswapError(msg) from err - def get_items(self, filenames): - """ Obtain the cached items for a list of filenames. The returned list is in the same order - as the provided filenames. + def _minibatch(self, do_shuffle: bool) -> Generator[BatchType, None, None]: + """ A generator function that yields the augmented, target and sample images for the + current batch on the current side. Parameters ---------- - filenames: list - A list of image filenames to obtain the cached data for + do_shuffle: bool, optional + Whether data should be shuffled prior to loading from disk. If true, each time the full + list of filenames are processed, the data will be reshuffled to make sure they are not + returned in the same order. Default: ``True`` - Returns - ------- - list - List of dictionaries containing the cached metadata. The list returns in the same order - as the filenames received + Yields + ------ + feed: list + 4-dimensional array of faces to feed the training the model (:attr:`x` parameter for + :func:`keras.models.model.train_on_batch`.). The array returned is in the format + (`batch size`, `height`, `width`, `channels`). + targets: list + List of 4-dimensional :class:`numpy.ndarray` objects in the order and size of each + output of the model. The format of these arrays will be (`batch size`, `height`, + `width`, `x`). This is the :attr:`y` parameter for + :func:`keras.models.model.train_on_batch`. The number of channels here will vary. + The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. eye/mouth masks) """ - return [self._cache[os.path.basename(filename)] for filename in filenames] + logger.debug("Loading minibatch generator: (image_count: %s, do_shuffle: %s)", + len(self._images), do_shuffle) - def cache_metadata(self, filenames): - """ Obtain the batch with metadata for items that need caching and cache them to - :attr:`_cache`. + def _img_iter(imgs): + """ Infinite iterator for recursing through image list and reshuffling at each epoch""" + while True: + if do_shuffle: + shuffle(imgs) + for img in imgs: + yield img + + img_iter = _img_iter(self._images[:]) + while True: + img_paths = [next(img_iter) # pylint:disable=stop-iteration-return + for _ in range(self._batch_size)] + retval = self._process_batch(img_paths) + yield retval + + def _get_images_with_meta(self, filenames: List[str]) -> Tuple[np.ndarray, List[DetectedFace]]: + """ Obtain the raw face images with associated :class:`DetectedFace` objects for this + batch. + + If this is the first time a face has been loaded, then it's meta data is extracted + from the png header and added to :attr:`_face_cache`. Parameters ---------- @@ -184,250 +212,161 @@ def cache_metadata(self, filenames): Returns ------- - :class:`numpy.ndarray` - The batch of face images loaded from disk + raw_faces: :class:`numpy.ndarray` + The full sized batch of training images for the given filenames + list + Batch of :class:`~lib.align.DetectedFace` objects for the given filename including the + aligned face objects for the model output size """ - keys = [os.path.basename(filename) for filename in filenames] - with self._lock: - if _check_reset(self): - self._reset_cache(False) - - needs_cache = [filename - for filename, key in zip(filenames, keys) - if not self._cache[key]["cached"]] - logger.trace("Needs cache: %s", needs_cache) - - if not needs_cache: - # Don't bother reading the metadata if no images in this batch need caching - logger.debug("All metadata already cached for: %s", keys) - return read_image_batch(filenames) - - batch, metadata = read_image_batch(filenames, with_metadata=True) - - if len(batch.shape) == 1: - folder = os.path.dirname(filenames[0]) - details = [ - "{0} ({1})".format( - key, f"{img.shape[1]}px" if isinstance(img, np.ndarray) else type(img)) - for key, img in zip(keys, batch)] - msg = (f"There are mismatched image sizes in the folder '{folder}'. All training " - "images for each side must have the same dimensions.\nThe batch that " - f"failed contains the following files:\n{details}.") - raise FaceswapError(msg) - - # Populate items into cache - for filename in needs_cache: - key = os.path.basename(filename) - meta = metadata[filenames.index(filename)] - - # Version Check - self._validate_version(meta, filename) - if self._partial_load: # Faces already loaded for Warp-to-landmarks - detected_face = self._cache[key]["detected_face"] - else: - detected_face = self._add_aligned_face(filename, - meta["alignments"], - batch.shape[1]) - - self._add_mask(filename, detected_face) - for area in ("eye", "mouth"): - self._add_localized_mask(filename, detected_face, area) - - self._cache[key]["cached"] = True - # Update the :attr:`cache_full` attribute - cache_full = all(item["cached"] for item in self._cache.values()) - if cache_full: - logger.verbose("Cache filled: '%s'", os.path.dirname(filenames[0])) - self._cache_full = cache_full - - return batch - - def pre_fill(self, filenames, side): - """ When warp to landmarks is enabled, the cache must be pre-filled, as each side needs - access to the other side's alignments. + if not self._face_cache.cache_full: + raw_faces = self._face_cache.cache_metadata(filenames) + else: + raw_faces = read_image_batch(filenames) + + detected_faces = self._face_cache.get_items(filenames) + logger.trace("filenames: %s, raw_faces: '%s', detected_faces: %s", # type: ignore + filenames, raw_faces.shape, len(detected_faces)) + return raw_faces, detected_faces + + def _crop_to_coverage(self, + filenames: List[str], + images: np.ndarray, + detected_faces: List[DetectedFace], + batch: np.ndarray) -> None: + """ Crops the training image out of the full extract image based on the centering and + coveage used in the user's configuration settings. + + If legacy extract images are being used then this just returns the extracted batch with + their corresponding landmarks. + + Uses thread pool execution for about a 33% speed increase @ 64 batch size Parameters ---------- filenames: list - The list of full paths to the images to load the metadata from - side: str - `"a"` or `"b"`. The side of the model being cached. Used for info output + The list of filenames that correspond to this batch + images: :class:`numpy.ndarray` + The batch of faces that have been loaded from disk + detected_faces: list + The list of :class:`lib.align.DetectedFace` items corresponding to the batch + batch: :class:`np.ndarray` + The pre-allocated array to hold this batch """ - with self._lock: - for filename, meta in tqdm(read_image_meta_batch(filenames), - desc="WTL: Caching Landmarks ({})".format(side.upper()), - total=len(filenames), - leave=False): - if "itxt" not in meta or "alignments" not in meta["itxt"]: - raise FaceswapError(f"Invalid face image found. Aborting: '{filename}'") - - size = meta["width"] - meta = meta["itxt"] - # Version Check - self._validate_version(meta, filename) - detected_face = self._add_aligned_face(filename, meta["alignments"], size) - self._cache[os.path.basename(filename)]["detected_face"] = detected_face - self._partial_load = True - - def _validate_version(self, png_meta, filename): - """ Validate that there are not a mix of v1.0 extracted faces and v2.x faces. + logger.trace("Cropping training images info: (filenames: %s, side: '%s')", # type: ignore + filenames, self._side) - Parameters - ---------- - png_meta: dict - The information held within the Faceswap PNG Header - filename: str - The full path to the file being validated + with futures.ThreadPoolExecutor() as executor: + proc = {executor.submit(face.aligned.extract_face, img): idx + for idx, (face, img) in enumerate(zip(detected_faces, images))} - Raises - ------ - FaceswapError - If a version 1.0 face appears in a 2.x set or vice versa - """ - alignment_version = png_meta["source"]["alignments_version"] - - if not self._extract_version: - logger.debug("Setting initial extract version: %s", alignment_version) - self._extract_version = alignment_version - if alignment_version == 1.0 and self._centering != "legacy": - self._reset_cache(True) - return + for future in futures.as_completed(proc): + batch[proc[future], ..., :3] = future.result() - if (self._extract_version == 1.0 and alignment_version > 1.0) or ( - alignment_version == 1.0 and self._extract_version > 1.0): - raise FaceswapError("Mixing legacy and full head extracted facesets is not supported. " - "The following folder contains a mix of extracted face types: " - "{}".format(os.path.dirname(filename))) + def _apply_mask(self, detected_faces: List[DetectedFace], batch: np.ndarray) -> None: + """ Applies the masks to the 4th channel of the batch. - self._extract_version = min(alignment_version, self._extract_version) + If the configuration options `eye_multiplier` and/or `mouth_multiplier` are greater than 1 + then these masks are applied to the final channels of the batch respectively. - def _reset_cache(self, set_flag): - """ In the event that a legacy extracted face has been seen, and centering is not legacy - the cache will need to be reset for legacy centering. + If masks are not being used then this function returns having done nothing Parameters ---------- - set_flag: bool - ``True`` if the flag should be set to indicate that the cache is being reset because of - a legacy face set/centering mismatch. ``False`` if the cache is being reset because it - has detected a reset flag from the opposite cache. + detected_face: list + The list of :class:`~lib.align.DetectedFace` objects corresponding to the batch + batch: :class:`numpy.ndarray` + The preallocated array to apply masks to + side: str + '"a"' or '"b"' the side that is being processed """ - if set_flag: - logger.warning("You are using legacy extracted faces but have selected '%s' centering " - "which is incompatible. Switching centering to 'legacy'", - self._centering) - self._config["centering"] = "legacy" - self._centering = "legacy" - self._cache = {key: dict(cached=False) for key in self._cache} - self._cache_full = False - self._size = None - if set_flag: - self._has_reset = True - - def _add_aligned_face(self, filename, alignments, image_size): - """ Add a :class:`lib.align.AlignedFace` object to the cache. + if not self._use_mask: + return + + masks = np.array([face.get_training_masks() for face in detected_faces]) + batch[..., 3:] = masks + + logger.trace("side: %s, masks: %s, batch: %s", # type: ignore + self._side, masks.shape, batch.shape) + + def _process_batch(self, filenames: List[str]) -> BatchType: + """ Prepares data for feeding through subclassed methods. + + If this is the first time a face has been loaded, then it's meta data is extracted from the + png header and added to :attr:`_face_cache` Parameters ---------- - filename: str - The file path for the current image - alignments: dict - The alignments for a single face, extracted from a PNG header - image_size: int - The pixel size of the image loaded from disk + filenames: list + List of full paths to image file names for a single batch Returns ------- - :class:`lib.align.DetectedFace` - The Detected Face object that was used to create the Aligned Face + list + 4-dimensional array of faces to feed the training the model. + list + List of 4-dimensional :class:`numpy.ndarray`. The number of channels here will vary. + The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. eye/mouth masks) """ - if self._size is None: - self._size = get_centered_size("legacy" if self._extract_version == 1.0 else "head", - self._centering, - image_size) - - detected_face = DetectedFace() - detected_face.from_png_meta(alignments) + raw_faces, detected_faces = self._get_images_with_meta(filenames) + batch = self._buffer() + self._crop_to_coverage(filenames, raw_faces, detected_faces, batch) + self._apply_mask(detected_faces, batch) - aligned_face = AlignedFace(detected_face.landmarks_xy, - centering=self._centering, - size=self._size, - is_aligned=True) - logger.trace("Caching aligned face for: %s", filename) - self._cache[os.path.basename(filename)]["aligned_face"] = aligned_face - return detected_face + return self.process_batch(filenames, raw_faces, detected_faces, batch) - def _add_mask(self, filename, detected_face): - """ Load the mask to the cache if a mask is required for training. + def process_batch(self, + filenames: List[str], + images: np.ndarray, + detected_faces: List[DetectedFace], + batch: np.ndarray) -> BatchType: + """ Override for processing the batch for the current generator. Parameters ---------- - filename: str - The file path for the current image - detected_face: :class:`lib.align.DetectedFace` - The detected face object that holds the masks + filenames: list + List of full paths to image file names for a single batch + images: :class:`numpy.ndarray` + The batch of faces corresponding to the filenames + detected_faces: list + List of :class:`~lib.align.DetectedFace` objects with aligned data and masks loaded for + the current batch + batch: :class:`numpy.ndarray` + The pre-allocated batch with images and masks populated for the selected coverage and + centering - Raises - ------ - FaceswapError - If the requested mask type is not available an error is returned along with a list - of available masks + Returns + ------- + list + 4-dimensional array of faces to feed the training the model. + list + List of 4-dimensional :class:`numpy.ndarray`. The number of channels here will vary. + The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. eye/mouth masks) """ - if not self._config["penalized_mask_loss"] and not self._config["learn_mask"]: - return - - if not self._config["mask_type"]: - logger.debug("No mask selected. Not validating") - return - - if self._config["mask_type"] not in detected_face.mask: - raise FaceswapError( - "You have selected the mask type '{}' but at least one face does not contain the " - "selected mask.\nThe face that failed was: '{}'\nThe masks that exist for this " - "face are: {}".format( - self._config["mask_type"], filename, list(detected_face.mask))) - - key = os.path.basename(filename) - mask = detected_face.mask[self._config["mask_type"]] - mask.set_blur_and_threshold(blur_kernel=self._config["mask_blur_kernel"], - threshold=self._config["mask_threshold"]) + raise NotImplementedError() - pose = self._cache[key]["aligned_face"].pose - mask.set_sub_crop(pose.offset[mask.stored_centering], - pose.offset[self._centering], - self._centering) + def _set_color_order(self, batch) -> None: + """ Set the color order correctly for the model's input type. - logger.trace("Caching mask for: %s", filename) - self._cache[key]["mask"] = mask + batch: :class:`numpy.ndarray` + The pre-allocated batch with images in the first 3 channels in BGR order + """ + if self._color_order == "rgb": + batch[..., :3] = batch[..., [2, 1, 0]] - def _add_localized_mask(self, filename, detected_face, area): - """ Load a localized mask to the cache for the given area if it is required for training. + def _to_float32(self, in_array: np.ndarray) -> np.ndarray: + """ Cast an UINT8 array in 0-255 range to float32 in 0.0-1.0 range. - Parameters - ---------- - filename: str - The file path for the current image - detected_face: :class:`lib.align.DetectedFace` - The detected face object that holds the masks - area: str - `"eye"` or `"mouth"`. The area of the face to obtain the mask for + in_array: :class:`numpy.ndarray` + The input uint8 array """ - if not self._config["penalized_mask_loss"] or self._config[f"{area}_multiplier"] <= 1: - return - key = "eyes" if area == "eye" else area - - logger.trace("Caching localized '%s' mask for: %s", key, filename) - self._cache[os.path.basename(filename)][f"mask_{key}"] = detected_face.get_landmark_mask( - self._size, - key, - aligned=True, - centering=self._centering, - dilation=self._size // 32, - blur_kernel=self._size // 16, - as_zip=True) + return ne.evaluate("x / c", + local_dict=dict(x=in_array, c=np.float32(255)), + casting="unsafe") -class TrainingDataGenerator(): # pylint:disable=too-few-public-methods +class TrainingDataGenerator(DataGenerator): # pylint:disable=too-few-public-methods """ A Training Data Generator for compiling data for feeding to a model. This class is called from :mod:`plugins.train.trainer._base` and launches a background @@ -435,406 +374,327 @@ class TrainingDataGenerator(): # pylint:disable=too-few-public-methods Parameters ---------- - model_input_size: int - The expected input size for the model. It is assumed that the input to the model is always - a square image. This is the size, in pixels, of the `width` and the `height` of the input - to the model. - model_output_shapes: list - A list of tuples defining the output shapes from the model, in the order that the outputs - are returned. The tuples should be in (`height`, `width`, `channels`) format. - coverage_ratio: float - The ratio of the training image to be trained on. Dictates how much of the image will be - cropped out. E.G: a coverage ratio of 0.625 will result in cropping a 160px box from a - 256px image (:math:`256 * 0.625 = 160`). - color_order: ["rgb", "bgr"] - The color order that the model expects as input - augment_color: bool - ``True`` if color is to be augmented, otherwise ``False`` - no_flip: bool - ``True`` if the image shouldn't be randomly flipped as part of augmentation, otherwise - ``False`` - no_warp: bool - ``True`` if the image shouldn't be warped as part of augmentation, otherwise ``False`` - warp_to_landmarks: bool - ``True`` if the random warp method should warp to similar landmarks from the other side, - ``False`` if the standard random warp method should be used. - face_cache: dict - A thread safe dictionary containing a cache of information relating to all faces being - trained on + model: :class:`~plugins.train.model.ModelBase` + The model that this data generator is feeding config: dict The configuration `dict` generated from :file:`config.train.ini` containing the trainer plugin configuration options. + side: {'a' or 'b'} + The side of the model that this iterator is for. + images: list + A list of image paths that will be used to compile the final augmented data from. + batch_size: int + The batch size for this iterator. Images will be returned in :class:`numpy.ndarray` + objects of this size from the iterator. """ - def __init__(self, model_input_size, model_output_shapes, coverage_ratio, color_order, - augment_color, no_flip, no_warp, warp_to_landmarks, config): - logger.debug("Initializing %s: (model_input_size: %s, model_output_shapes: %s, " - "coverage_ratio: %s, color_order: %s, augment_color: %s, no_flip: %s, " - "no_warp: %s, warp_to_landmarks: %s, config: %s)", - self.__class__.__name__, model_input_size, model_output_shapes, - coverage_ratio, color_order, augment_color, no_flip, no_warp, - warp_to_landmarks, config) - self._config = config - self._model_input_size = model_input_size - self._model_output_shapes = model_output_shapes - self._coverage_ratio = coverage_ratio - self._color_order = color_order.lower() - self._augment_color = augment_color - self._no_flip = no_flip - self._warp_to_landmarks = warp_to_landmarks - self._no_warp = no_warp - - # Batchsize and processing class are set when this class is called by a feeder - # from lib.training_data - self._batchsize = 0 - self._face_cache = None - self._nearest_landmarks = dict() - self._processing = None - logger.debug("Initialized %s", self.__class__.__name__) + def __init__(self, + config: ConfigType, + model: "ModelBase", + side: Literal["a", "b"], + images: List[str], + batch_size: int) -> None: + super().__init__(config, model, side, images, batch_size) + self._augment_color = not model.command_line_arguments.no_augment_color + self._no_flip = model.command_line_arguments.no_flip + self._no_warp = model.command_line_arguments.no_warp + self._warp_to_landmarks = (not self._no_warp + and model.command_line_arguments.warp_to_landmarks) + self._model_input_size = model.input_shape[1] - def minibatch_ab(self, images, batchsize, side, - do_shuffle=True, is_preview=False, is_timelapse=False): - """ A Background iterator to return augmented images, samples and targets. + if self._warp_to_landmarks: + self._face_cache.pre_fill(images, side) + self._processing = ImageAugmentation(batch_size, + self._process_size, + self._config) + self._nearest_landmarks: Dict[str, Tuple[str, ...]] = {} + logger.debug("Initialized %s", self.__class__.__name__) - The exit point from this class and the sole attribute that should be referenced. Called - from :mod:`plugins.train.trainer._base`. Returns an iterator that yields images for - training, preview and time-lapses. + def _create_targets(self, batch: np.ndarray) -> List[np.ndarray]: + """ Compile target images, with masks, for the model output sizes. Parameters ---------- - images: list - A list of image paths that will be used to compile the final augmented data from. - batchsize: int - The batchsize for this iterator. Images will be returned in :class:`numpy.ndarray` - objects of this size from the iterator. - side: {'a' or 'b'} - The side of the model that this iterator is for. - do_shuffle: bool, optional - Whether data should be shuffled prior to loading from disk. If true, each time the full - list of filenames are processed, the data will be reshuffled to make sure they are not - returned in the same order. Default: ``True`` - is_preview: bool, optional - Indicates whether this iterator is generating preview images. If ``True`` then certain - augmentations will not be performed. Default: ``False`` - is_timelapse: bool optional - Indicates whether this iterator is generating time-lapse images. If ``True``, then - certain augmentations will not be performed. Default: ``False`` + batch: :class:`numpy.ndarray` + This should be a 4-dimensional array of training images in the format (`batch size`, + `height`, `width`, `channels`). Targets should be requested after performing image + transformations but prior to performing warps. The 4th channel should be the mask. + Any channels above the 4th should be any additional area masks (e.g. eye/mouth) that + are required. - Yields - ------ - dict - The following items are contained in each `dict` yielded from this iterator: - - * **feed** (:class:`numpy.ndarray`) - The feed for the model. The array returned is \ - in the format (`batchsize`, `height`, `width`, `channels`). This is the :attr:`x` \ - parameter for :func:`keras.models.model.train_on_batch`. - - * **targets** (`list`) - A list of 4-dimensional :class:`numpy.ndarray` objects in \ - the order and size of each output of the model as defined in \ - :attr:`model_output_shapes`. the format of these arrays will be (`batchsize`, \ - `height`, `width`, `3`). This is the :attr:`y` parameter for \ - :func:`keras.models.model.train_on_batch` **NB:** masks are not included in the \ - `targets` list. If required for feeding into the Keras model, they will need to be \ - added to this list in :mod:`plugins.train.trainer._base` from the `masks` key. - - * **masks** (:class:`numpy.ndarray`) - A 4-dimensional array containing the target \ - masks in the format (`batchsize`, `height`, `width`, `1`). - - * **samples** (:class:`numpy.ndarray`) - A 4-dimensional array containing the samples \ - for feeding to the model's predict function for generating preview and time-lapse \ - samples. The array will be in the format (`batchsize`, `height`, `width`, \ - `channels`). **NB:** This item will only exist in the `dict` if :attr:`is_preview` \ - or :attr:`is_timelapse` is ``True`` + Returns + ------- + list + List of 4-dimensional target images, at all model output sizes, with masks compiled + into channels 4+ for each output size """ - logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, " - "is_preview, %s, is_timelapse: %s)", len(images), batchsize, side, do_shuffle, - is_preview, is_timelapse) - self._batchsize = batchsize - self._face_cache = _get_cache(side, images, self._config) - self._processing = ImageAugmentation(batchsize, - is_preview or is_timelapse, - self._model_input_size, - self._model_output_shapes, - self._coverage_ratio, - self._config) - - if self._warp_to_landmarks and not self._face_cache.partially_loaded: - self._face_cache.pre_fill(images, side) - - args = (images, side, do_shuffle, batchsize) - batcher = BackgroundGenerator(self._minibatch, thread_count=2, args=args) - return batcher.iterator() - - # << INTERNAL METHODS >> # - def _validate_samples(self, data): - """ Ensures that the total number of images within :attr:`images` is greater or equal to - the selected :attr:`batchsize`. Raises an exception if this is not the case. """ - length = len(data) - msg = ("Number of images is lower than batch-size (Note that too few " - "images may lead to bad training). # images: {}, " - "batch-size: {}".format(length, self._batchsize)) - try: - assert length >= self._batchsize, msg - except AssertionError as err: - msg += ("\nYou should increase the number of images in your training set or lower " - "your batch-size.") - raise FaceswapError(msg) from err - - def _minibatch(self, images, side, do_shuffle, batchsize): - """ A generator function that yields the augmented, target and sample images. - see :func:`minibatch_ab` for more details on the output. """ - logger.debug("Loading minibatch generator: (image_count: %s, side: '%s', do_shuffle: %s)", - len(images), side, do_shuffle) - self._validate_samples(images) - - def _img_iter(imgs): - while True: - if do_shuffle: - shuffle(imgs) - for img in imgs: - yield img - - img_iter = _img_iter(images) - while True: - img_paths = [next(img_iter) for _ in range(batchsize)] - yield self._process_batch(img_paths, side) - - logger.debug("Finished minibatch generator: (side: '%s')", side) + logger.trace("Compiling targets: batch shape: %s", batch.shape) # type: ignore + if len(self._output_sizes) == 1 and self._output_sizes[0] == self._process_size: + # Rolling buffer here makes next to no difference, so just create array on the fly + retval = [self._to_float32(batch)] + else: + retval = [self._to_float32(np.array([cv2.resize(image, (size, size), cv2.INTER_AREA) + for image in batch])) + for size in self._output_sizes] + logger.trace("Processed targets: %s", [t.shape for t in retval]) # type: ignore + return retval - def _process_batch(self, filenames, side): + def process_batch(self, + filenames: List[str], + images: np.ndarray, + detected_faces: List[DetectedFace], + batch: np.ndarray) -> BatchType: """ Performs the augmentation and compiles target images and samples. - If this is the first time a face has been loaded, then it's meta data is extracted from the - png header and added to :attr:`_face_cache` - - See - :func:`minibatch_ab` for more details on the output. - Parameters ---------- filenames: list - List of full paths to image file names - side: str - The side of the model being trained on (`a` or `b`) - """ - logger.trace("Process batch: (filenames: '%s', side: '%s')", filenames, side) - - if not self._face_cache.cache_full: - batch = self._face_cache.cache_metadata(filenames) - else: - batch = read_image_batch(filenames) - - cache = self._face_cache.get_items(filenames) - batch, landmarks = self._crop_to_center(filenames, cache, batch, side) - batch = self._apply_mask(filenames, cache, batch, side) - processed = dict() - - # Initialize processing training size on first image - if not self._processing.initialized: - self._processing.initialize(batch.shape[1]) + List of full paths to image file names for a single batch + images: :class:`numpy.ndarray` + The batch of faces corresponding to the filenames + detected_faces: list + List of :class:`~lib.align.DetectedFace` objects with aligned data and masks loaded for + the current batch + batch: :class:`numpy.ndarray` + The pre-allocated batch with images and masks populated for the selected coverage and + centering - # Get Landmarks prior to manipulating the image - if self._warp_to_landmarks: - batch_dst_pts = self._get_closest_match(filenames, side, landmarks) - warp_kwargs = dict(batch_src_points=landmarks, batch_dst_points=batch_dst_pts) - else: - warp_kwargs = dict() + Returns + ------- + feed: list + 4-dimensional array of faces to feed the training the model (:attr:`x` parameter for + :func:`keras.models.model.train_on_batch`.). The array returned is in the format + (`batch size`, `height`, `width`, `channels`). + targets: list + List of 4-dimensional :class:`numpy.ndarray` objects in the order and size of each + output of the model. The format of these arrays will be (`batch size`, `height`, + `width`, `x`). This is the :attr:`y` parameter for + :func:`keras.models.model.train_on_batch`. The number of channels here will vary. + The first 3 channels are (rgb/bgr). The 4th channel is the face mask. Any subsequent + channels are area masks (e.g. eye/mouth masks) + """ + logger.trace("Process training: (side: '%s', filenames: '%s', images: %s, " # type:ignore + "batch: %s, detected_faces: %s)", self._side, filenames, images.shape, + batch.shape, len(detected_faces)) # Color Augmentation of the image only if self._augment_color: batch[..., :3] = self._processing.color_adjust(batch[..., :3]) # Random Transform and flip - batch = self._processing.transform(batch) + self._processing.transform(batch) + if not self._no_flip: - batch = self._processing.random_flip(batch) + self._processing.random_flip(batch) # Switch color order for RGB models - if self._color_order == "rgb": - batch[..., :3] = batch[..., [2, 1, 0]] - - # Add samples to output if this is for display - if self._processing.is_display: - processed["samples"] = batch[..., :3].astype("float32") / 255.0 + self._set_color_order(batch) # Get Targets - processed.update(self._processing.get_targets(batch)) + targets = self._create_targets(batch) - # Random Warp # TODO change masks to have a input mask and a warped target mask - if self._no_warp: - processed["feed"] = [self._processing.skip_warp(batch[..., :3])] + # TODO Look at potential for applying mask on input + # Random Warp + if self._warp_to_landmarks: + landmarks = np.array([face.aligned.landmarks for face in detected_faces]) + batch_dst_pts = self._get_closest_match(filenames, landmarks) + warp_kwargs = dict(batch_src_points=landmarks, batch_dst_points=batch_dst_pts) else: - processed["feed"] = [self._processing.warp(batch[..., :3], - self._warp_to_landmarks, - **warp_kwargs)] + warp_kwargs = {} + + warped = batch[..., :3] if self._no_warp else self._processing.warp( + batch[..., :3], + self._warp_to_landmarks, + **warp_kwargs) + + if self._model_input_size != self._process_size: + feed = self._to_float32(np.array([cv2.resize(image, + (self._model_input_size, + self._model_input_size), + cv2.INTER_AREA) + for image in warped])) + else: + feed = self._to_float32(warped) - logger.trace("Processed batch: (filenames: %s, side: '%s', processed: %s)", - filenames, - side, - {k: v.shape if isinstance(v, np.ndarray) else[i.shape for i in v] - for k, v in processed.items()}) - return processed + logger.trace("Processed batch: (filenames: %s, side: '%s', " # type: ignore + "feed: %s, targets: %s)", filenames, self._side, + [f.shape for f in feed], [t.shape for t in targets]) - def _crop_to_center(self, filenames, cache, batch, side): - """ Crops the training image out of the full extract image based on the centering used in - the user's configuration settings. + return feed, targets - If legacy extract images are being used then this just returns the extracted batch with - their corresponding landmarks. + def _get_closest_match(self, filenames: List[str], batch_src_points: np.ndarray) -> np.ndarray: + """ Only called if the :attr:`_warp_to_landmarks` is ``True``. Gets the closest + matched 68 point landmarks from the opposite training set. Parameters ---------- filenames: list - The list of filenames that correspond to this batch - cache: list - The list of cached items (aligned faces, masks etc.) corresponding to the batch - batch: :class:`numpy.ndarray` - The batch of faces that have been loaded from disk - side: str - '"a"' or '"b"' the side that is being processed + Filenames for current batch + batch_src_points: :class:`np.ndarray` + The source landmarks for the current batch Returns ------- - batch: :class:`numpy.ndarray` - The centered faces cropped out of the loaded batch - landmarks: :class:`numpy.ndarray` - The aligned landmarks for this batch. NB: The aligned landmarks do not directly - correspond to the size of the extracted face. They are scaled to the source training - image, not the sub-image. + :class:`np.ndarray` + Randomly selected closest matches from the other side's landmarks + """ + logger.trace("Retrieving closest matched landmarks: (filenames: '%s', " # type: ignore + "src_points: '%s')", filenames, batch_src_points) + lm_side: Literal["a", "b"] = "a" if self._side == "b" else "b" + other_cache = get_cache(lm_side) + landmarks = other_cache.aligned_landmarks + + try: + closest_matches = [self._nearest_landmarks[os.path.basename(filename)] + for filename in filenames] + except KeyError: + # Resize mismatched training image size landmarks + sizes = {side: cache.size for side, cache in zip((self._side, lm_side), + (self._face_cache, other_cache))} + if len(set(sizes.values())) > 1: + scale = sizes[self._side] / sizes[lm_side] + landmarks = {key: lms * scale for key, lms in landmarks.items()} + closest_matches = self._cache_closest_matches(filenames, batch_src_points, landmarks) + + batch_dst_points = np.array([landmarks[choice(fname)] for fname in closest_matches]) + logger.trace("Returning: (batch_dst_points: %s)", batch_dst_points.shape) # type: ignore + return batch_dst_points + + def _cache_closest_matches(self, + filenames: List[str], + batch_src_points: np.ndarray, + landmarks: Dict[str, np.ndarray]) -> List[Tuple[str, ...]]: + """ Cache the nearest landmarks for this batch + + Parameters + ---------- + filenames: list + Filenames for current batch + batch_src_points: :class:`np.ndarray` + The source landmarks for the current batch + landmarks: dict + The destination landmarks with associated filenames - Raises - ------ - FaceswapError - If Alignment information is not available for any of the images being loaded in - the batch """ - logger.trace("Cropping training images info: (filenames: %s, side: '%s')", filenames, side) - aligned = [item["aligned_face"] for item in cache] + logger.trace("Caching closest matches") # type:ignore + dst_landmarks = list(landmarks.items()) + dst_points = np.array([lm[1] for lm in dst_landmarks]) + batch_closest_matches: List[Tuple[str, ...]] = [] - if self._face_cache.extract_version == 1.0: - # Legacy extract. Don't crop, just return batch with landmarks - return batch, np.array([face.landmarks for face in aligned]) + for filename, src_points in zip(filenames, batch_src_points): + closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10] + closest_matches = tuple(dst_landmarks[i][0] for i in closest) + self._nearest_landmarks[os.path.basename(filename)] = closest_matches + batch_closest_matches.append(closest_matches) + logger.trace("Cached closest matches") # type:ignore + return batch_closest_matches - landmarks = np.array([face.landmarks for face in aligned]) - cropped = np.array([align.extract_face(img) for align, img in zip(aligned, batch)]) - return cropped, landmarks - def _apply_mask(self, filenames, cache, batch, side): - """ Applies the mask to the 4th channel of the image. If masks are not being used - applies a dummy all ones mask. +class PreviewDataGenerator(DataGenerator): + """ Generator for compiling images for generating previews. - If the configuration options `eye_multiplier` and/or `mouth_multiplier` are greater than 1 - then these masks are applied to the final channels of the batch respectively. + This class is called from :mod:`plugins.train.trainer._base` and launches a background + iterator that compiles sample preview data for feeding the model's predict function and for + display. + + Parameters + ---------- + model: :class:`~plugins.train.model.ModelBase` + The model that this data generator is feeding + config: dict + The configuration `dict` generated from :file:`config.train.ini` containing the trainer + plugin configuration options. + side: {'a' or 'b'} + The side of the model that this iterator is for. + images: list + A list of image paths that will be used to compile the final images. + batch_size: int + The batch size for this iterator. Images will be returned in :class:`numpy.ndarray` + objects of this size from the iterator. + """ + def _create_samples(self, + images: np.ndarray, + detected_faces: List[DetectedFace]) -> List[np.ndarray]: + """ Compile the 'sample' images. These are the 100% coverage images which hold the model + output in the preview window. Parameters ---------- - filenames: list - The list of filenames that correspond to this batch - cache: list - The list of cached items (aligned faces, masks etc.) corresponding to the batch - batch: :class:`numpy.ndarray` - The batch of faces that have been loaded from disk - side: str - '"a"' or '"b"' the side that is being processed + images: :class:`numpy.ndarray` + The original batch of images as loaded from disk. + detected_faces: list + List of :class:`~lib.align.DetectedFace` for the current batch Returns ------- - :class:`numpy.ndarray` - The batch with masks applied to the final channels + list + List of 4-dimensional target images, at final model output size """ - logger.trace("Input filenames: %s, batch shape: %s, side: %s", - filenames, batch.shape, side) - size = batch.shape[1] - - for key in ("mask", "mask_eyes", "mask_mouth"): - lookup = cache[0].get(key) - if lookup is None and key != "mask": - continue - - if lookup is None and key == "mask": - logger.trace("Creating dummy masks. side: %s", side) - masks = np.ones_like(batch[..., :1], dtype=batch.dtype) - else: - logger.trace("Obtaining masks for batch. (key: %s side: %s)", key, side) - - masks = np.array([self._get_mask(item[key], size) - for item in cache], dtype=batch.dtype) - masks = self._resize_masks(size, masks) - logger.trace("masks: (key: %s, shape: %s)", key, masks.shape) - batch = np.concatenate((batch, masks), axis=-1) - logger.trace("Output batch shape: %s, side: %s", batch.shape, side) - return batch - - @classmethod - def _get_mask(cls, item, size): - """ Decompress zipped eye and mouth masks, or return the stored mask + logger.trace("Compiling samples: images shape: %s, detected_faces: %s ", # type: ignore + images.shape, len(detected_faces)) + output_size = self._output_sizes[-1] + full_size = 2 * int(np.rint((output_size / self._coverage_ratio) / 2)) + + assert self._config["centering"] in get_args(CenteringType) + retval = np.empty((full_size, full_size, 3), dtype="float32") + retval = self._to_float32(np.array([AlignedFace(face.landmarks_xy, + image=images[idx], + centering=cast(CenteringType, + self._config["centering"]), + size=full_size, + dtype="uint8", + is_aligned=True).face + for idx, face in enumerate(detected_faces)])) + + logger.trace("Processed samples: %s", retval.shape) # type: ignore + return [retval] + + def process_batch(self, + filenames: List[str], + images: np.ndarray, + detected_faces: List[DetectedFace], + batch: np.ndarray) -> BatchType: + """ Creates the full size preview images and the sub-cropped images for feeding the model's + predict function. Parameters ---------- - item: :class:`lib.align.Mask` or `bytes` - Either a stored face mask object or a zipped eye or mouth mask - size: int - The size of the stored eye or mouth mask for reshaping + filenames: list + List of full paths to image file names for a single batch + images: :class:`numpy.ndarray` + The batch of faces corresponding to the filenames + detected_faces: list + List of :class:`~lib.align.DetectedFace` objects with aligned data and masks loaded for + the current batch + batch: :class:`numpy.ndarray` + The pre-allocated batch with images and masks populated for the selected coverage and + centering Returns ------- - class:`numpy.ndarray` - The decompressed mask + feed: list + List of 4-dimensional :class:`numpy.ndarray` objects at model input size for feeding + the model's predict function. The first 3 channels are (rgb/bgr). The 4th channel is + the face mask. + samples: list + 4-dimensional array containing the 100% coverage images at the model's centering for + for generating previews. The array returned is in the format + (`batch size`, `height`, `width`, `channels`). """ - if isinstance(item, bytes): - retval = np.frombuffer(decompress(item), dtype="uint8").reshape(size, size, 1) - else: - retval = item.mask - return retval + logger.trace("Process preview: (side: '%s', filenames: '%s', images: %s, " # type:ignore + "batch: %s, detected_faces: %s)", self._side, filenames, images.shape, + batch.shape, len(detected_faces)) - @classmethod - def _resize_masks(cls, target_size, masks): - """ Resize the masks to the target size """ - logger.trace("target size: %s, masks shape: %s", target_size, masks.shape) - mask_size = masks.shape[1] - if target_size == mask_size: - logger.trace("Mask and targets the same size. Not resizing") - return masks - interpolator = cv2.INTER_CUBIC if mask_size < target_size else cv2.INTER_AREA - masks = np.array([cv2.resize(mask, - (target_size, target_size), - interpolation=interpolator)[..., None] - for mask in masks]) - logger.trace("Resized masks: %s", masks.shape) - return masks - - def _get_closest_match(self, filenames, side, batch_src_points): - """ Only called if the :attr:`_warp_to_landmarks` is ``True``. Gets the closest - matched 68 point landmarks from the opposite training set. """ - logger.trace("Retrieving closest matched landmarks: (filenames: '%s', src_points: '%s'", - filenames, batch_src_points) - lm_side = "a" if side == "b" else "b" - landmarks = _FACE_CACHES[lm_side].aligned_landmarks - - closest_matches = [self._nearest_landmarks.get(os.path.basename(filename)) - for filename in filenames] - if None in closest_matches: - # Resize mismatched training image size landmarks - sizes = {side: cache.crop_size for side, cache in _FACE_CACHES.items()} - if len(set(sizes.values())) > 1: - scale = sizes[side] / sizes[lm_side] - landmarks = {key: lms * scale for key, lms in landmarks.items()} - closest_matches = self._cache_closest_matches(filenames, batch_src_points, landmarks) + self._set_color_order(batch) # Switch color order for RGB models - batch_dst_points = np.array([landmarks[choice(fname)] for fname in closest_matches]) - logger.trace("Returning: (batch_dst_points: %s)", batch_dst_points.shape) - return batch_dst_points + if not self._use_mask: + mask = np.zeros_like(batch[..., 0])[..., None] + 255 + batch = np.concatenate([batch, mask], axis=-1) - def _cache_closest_matches(self, filenames, batch_src_points, landmarks): - """ Cache the nearest landmarks for this batch """ - logger.trace("Caching closest matches") - dst_landmarks = list(landmarks.items()) - dst_points = np.array([lm[1] for lm in dst_landmarks]) - batch_closest_matches = list() + feed = self._to_float32(batch[..., :4]) - for filename, src_points in zip(filenames, batch_src_points): - closest = (np.mean(np.square(src_points - dst_points), axis=(1, 2))).argsort()[:10] - closest_matches = tuple(dst_landmarks[i][0] for i in closest) - self._nearest_landmarks[os.path.basename(filename)] = closest_matches - batch_closest_matches.append(closest_matches) - logger.trace("Cached closest matches") - return batch_closest_matches + samples = self._create_samples(images, detected_faces) + + logger.trace("Processed batch: (filenames: %s, side: '%s', " # type: ignore + "feed: %s, targets: %s, samples: %s)", filenames, self._side, + [f.shape for f in feed], [t.shape for t in samples]) + return feed, samples diff --git a/plugins/train/model/_base/model.py b/plugins/train/model/_base/model.py index ef0cf4f6a2..c7bab46641 100644 --- a/plugins/train/model/_base/model.py +++ b/plugins/train/model/_base/model.py @@ -204,11 +204,12 @@ def model_name(self) -> str: return self.name @property - def output_shapes(self) -> List[List[Tuple]]: + def output_shapes(self) -> List[List[Tuple[int, int, int]]]: """ list: A list of list of shape tuples for the outputs of the model with the batch dimension removed. The outer list contains 2 sub-lists (one for each side "a" and "b"). The inner sub-lists contain the output shapes for that side. """ - shapes = [tuple(K.int_shape(output)[-3:]) for output in self.model.outputs] + shapes: List[Tuple[int, int, int]] = [tuple(K.int_shape(output)[-3:]) # type: ignore + for output in self.model.outputs] return [shapes[:len(shapes) // 2], shapes[len(shapes) // 2:]] @property @@ -477,7 +478,7 @@ def _rewrite_plaid_outputs(self) -> None: self.model.output_names, new_names) self.model.output_names = new_names - def _legacy_mapping(self) -> Optional[dict]: # pylint:disable=no-self-use + def _legacy_mapping(self) -> Optional[dict]: """ The mapping of separate model files to single model layers for transferring of legacy weights. diff --git a/plugins/train/trainer/_base.py b/plugins/train/trainer/_base.py index 52605fcec8..a1ae3d5526 100644 --- a/plugins/train/trainer/_base.py +++ b/plugins/train/trainer/_base.py @@ -7,10 +7,11 @@ with "original" unique code split out to the original plugin. """ -# pylint:disable=too-many-lines import logging import os +import sys import time +from typing import Callable, cast, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union import cv2 import numpy as np @@ -19,14 +20,23 @@ from tensorflow.python.framework import ( # pylint:disable=no-name-in-module errors_impl as tf_errors) -from lib.training import TrainingDataGenerator +from lib.training import PreviewDataGenerator, TrainingDataGenerator +from lib.training.generator import BatchType, ConfigType, DataGenerator from lib.utils import FaceswapError, get_backend, get_folder, get_image_paths, get_tf_version from plugins.train._config import Config +if TYPE_CHECKING: + from plugins.train.model._base import ModelBase + +if sys.version_info < (3, 8): + from typing_extensions import get_args, Literal +else: + from typing import get_args, Literal + logger = logging.getLogger(__name__) # pylint: disable=invalid-name -def _get_config(plugin_name, configfile=None): +def _get_config(plugin_name: str, configfile: Optional[str] = None) -> ConfigType: """ Return the configuration for the requested trainer. Parameters @@ -39,8 +49,8 @@ def _get_config(plugin_name, configfile=None): Returns ------- - :class:`lib.config.FaceswapConfig` - The configuration file for the requested plugin + dict + The configuration dictionary for the requested plugin """ return Config(plugin_name, configfile=configfile).config_dict @@ -65,7 +75,11 @@ class TrainerBase(): from the default :file:`.config.train.ini` file. """ - def __init__(self, model, images, batch_size, configfile): + def __init__(self, + model: "ModelBase", + images: Dict[Literal["a", "b"], List[str]], + batch_size: int, + configfile: Optional[str]) -> None: logger.debug("Initializing %s: (model: '%s', batch_size: %s)", self.__class__.__name__, model, batch_size) self._model = model @@ -81,12 +95,12 @@ def __init__(self, model, images, batch_size, configfile): self._samples = _Samples(self._model, self._model.coverage_ratio) self._timelapse = _Timelapse(self._model, self._model.coverage_ratio, - self._config.get("preview_images", 14), + int(self._config.get("preview_images", 14)), self._feeder, self._images) logger.debug("Initialized %s", self.__class__.__name__) - def _get_config(self, configfile): + def _get_config(self, configfile: Optional[str]) -> ConfigType: """ Get the saved training config options. Override any global settings with the setting provided from the model's saved config. @@ -111,7 +125,7 @@ def _get_config(self, configfile): config[key] = new_val return config - def _set_tensorboard(self): + def _set_tensorboard(self) -> tf.keras.callbacks.TensorBoard: """ Set up Tensorboard callback for logging loss. Bypassed if command line option "no-logs" has been selected. @@ -122,7 +136,7 @@ def _set_tensorboard(self): Tensorboard object for the the current training session. """ if self._model.state.current_session["no_logs"]: - logger.verbose("TensorBoard logging disabled") + logger.verbose("TensorBoard logging disabled") # type: ignore return None logger.debug("Enabling TensorBoard Logging") @@ -140,14 +154,16 @@ def _set_tensorboard(self): embeddings_metadata=None) tensorboard.set_model(self._model.model) tensorboard.on_train_begin(0) - logger.verbose("Enabled TensorBoard Logging") + logger.verbose("Enabled TensorBoard Logging") # type: ignore return tensorboard - def toggle_mask(self): + def toggle_mask(self) -> None: """ Toggle the mask overlay on or off based on user input. """ self._samples.toggle_mask_display() - def train_one_step(self, viewer, timelapse_kwargs): + def train_one_step(self, + viewer: Optional[Callable[[np.ndarray, str], None]], + timelapse_kwargs: Optional[Dict[str, str]]) -> None: """ Running training on a batch of images for each side. Triggered from the training cycle in :class:`scripts.train.Train`. @@ -171,7 +187,7 @@ def train_one_step(self, viewer, timelapse_kwargs): Parameters ---------- - viewer: :func:`scripts.train.Train._show` + viewer: :func:`scripts.train.Train._show` or ``None`` The function that will display the preview image timelapse_kwargs: dict The keyword arguments for generating time-lapse previews. If a time-lapse preview is @@ -179,16 +195,19 @@ def train_one_step(self, viewer, timelapse_kwargs): the keys being `input_a`, `input_b`, `output`. """ self._model.state.increment_iterations() - logger.trace("Training one step: (iteration: %s)", self._model.iterations) - do_preview = viewer is not None + logger.trace("Training one step: (iteration: %s)", self._model.iterations) # type: ignore snapshot_interval = self._model.command_line_arguments.snapshot_interval do_snapshot = (snapshot_interval != 0 and self._model.iterations - 1 >= snapshot_interval and (self._model.iterations - 1) % snapshot_interval == 0) model_inputs, model_targets = self._feeder.get_batch() + if get_backend() == "amd": # Expand out AMD inputs + targets + model_inputs = [inp for side in model_inputs for inp in side] # type: ignore + model_targets = [tgt for side in model_targets for tgt in side] # type: ignore + try: - loss = self._model.model.train_on_batch(model_inputs, y=model_targets) + loss: List[float] = self._model.model.train_on_batch(model_inputs, y=model_targets) except tf_errors.ResourceExhaustedError as err: msg = ("You do not have enough GPU memory available to train the selected model at " "the selected settings. You can try a number of things:" @@ -220,23 +239,11 @@ def train_one_step(self, viewer, timelapse_kwargs): self._log_tensorboard(loss) loss = self._collate_and_store_loss(loss[1:]) self._print_loss(loss) - if do_snapshot: self._model.snapshot() + self._update_viewers(viewer, timelapse_kwargs) - if do_preview: - self._feeder.generate_preview(do_preview) - self._samples.images = self._feeder.compile_sample(None) - samples = self._samples.show_sample() - if samples is not None: - viewer(samples, - "Training - 'S': Save Now. 'R': Refresh Preview. 'M': Toggle Mask. 'F': " - "Toggle Screen Fit-Actual Size. 'ENTER': Save and Quit") - - if timelapse_kwargs: - self._timelapse.output_timelapse(timelapse_kwargs) - - def _log_tensorboard(self, loss): + def _log_tensorboard(self, loss: List[float]) -> None: """ Log current loss to Tensorboard log files Parameters @@ -246,7 +253,7 @@ def _log_tensorboard(self, loss): """ if not self._tensorboard: return - logger.trace("Updating TensorBoard log") + logger.trace("Updating TensorBoard log") # type: ignore logs = {log[0]: log[1] for log in zip(self._model.state.loss_names, loss)} @@ -262,7 +269,7 @@ def _log_tensorboard(self, loss): else: self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs) - def _collate_and_store_loss(self, loss): + def _collate_and_store_loss(self, loss: List[float]) -> List[float]: """ Collate the loss into totals for each side. The losses are summed into a total for each side. Loss totals are added to @@ -273,12 +280,13 @@ def _collate_and_store_loss(self, loss): Parameters ---------- loss: list - The list of loss ``floats`` for this iteration. + The list of loss ``floats`` for each side this iteration (excluding total combined + loss) Returns ------- list - List of 2 ``floats`` which is the total loss for each side + List of 2 ``floats`` which is the total loss for each side (eg sum of face + mask loss) Raises ------ @@ -294,10 +302,10 @@ def _collate_and_store_loss(self, loss): split = len(loss) // 2 combined_loss = [sum(loss[:split]), sum(loss[split:])] self._model.add_history(combined_loss) - logger.trace("original loss: %s, comibed_loss: %s", loss, combined_loss) + logger.trace("original loss: %s, combined_loss: %s", loss, combined_loss) # type: ignore return combined_loss - def _print_loss(self, loss): + def _print_loss(self, loss: List[float]) -> None: """ Outputs the loss for the current iteration to the console. Parameters @@ -316,7 +324,32 @@ def _print_loss(self, loss): logger.warning("Swallowed OS Error caused by Tensorflow distributed training. output " "line: %s, error: %s", output, str(err)) - def clear_tensorboard(self): + def _update_viewers(self, + viewer: Optional[Callable[[np.ndarray, str], None]], + timelapse_kwargs: Optional[Dict[str, str]]) -> None: + """ Update the preview viewer and timelapse output + + Parameters + ---------- + viewer: :func:`scripts.train.Train._show` or ``None`` + The function that will display the preview image + timelapse_kwargs: dict + The keyword arguments for generating time-lapse previews. If a time-lapse preview is + not required then this should be ``None``. Otherwise all values should be full paths + the keys being `input_a`, `input_b`, `output`. + """ + if viewer is not None: + self._samples.images = self._feeder.generate_preview() + samples = self._samples.show_sample() + if samples is not None: + viewer(samples, + "Training - 'S': Save Now. 'R': Refresh Preview. 'M': Toggle Mask. 'F': " + "Toggle Screen Fit-Actual Size. 'ENTER': Save and Quit") + + if timelapse_kwargs: + self._timelapse.output_timelapse(timelapse_kwargs) + + def clear_tensorboard(self) -> None: """ Stop Tensorboard logging. Tensorboard logging needs to be explicitly shutdown on training termination. Called from @@ -342,77 +375,84 @@ class _Feeder(): config: :class:`lib.config.FaceswapConfig` The configuration for this trainer """ - def __init__(self, images, model, batch_size, config): + def __init__(self, + images: Dict[Literal["a", "b"], List[str]], + model: 'ModelBase', + batch_size: int, + config: ConfigType) -> None: logger.debug("Initializing %s: num_images: %s, batch_size: %s, config: %s)", - self.__class__.__name__, len(images), batch_size, config) + self.__class__.__name__, {k: len(v) for k, v in images.items()}, batch_size, + config) self._model = model self._images = images + self._batch_size = batch_size self._config = config - self._target = {} - self._samples = {} - self._masks = {} - - self._feeds = {side: self._load_generator(idx).minibatch_ab(images[side], batch_size, side) - for idx, side in enumerate(("a", "b"))} + self._feeds = {side: self._load_generator(side, False).minibatch_ab() + for side in get_args(Literal["a", "b"])} self._display_feeds = dict(preview=self._set_preview_feed(), timelapse={}) logger.debug("Initialized %s:", self.__class__.__name__) - def _load_generator(self, output_index): + def _load_generator(self, + side: Literal["a", "b"], + is_display: bool, + batch_size: Optional[int] = None, + images: Optional[List[str]] = None) -> DataGenerator: """ Load the :class:`~lib.training_data.TrainingDataGenerator` for this feeder. Parameters ---------- - output_index: int - The output index from the model to get output shapes for + side: ["a", "b"] + The side of the model to load the generator for + is_display: bool + ``True`` if the generator is for creating preview/time-lapse images. ``False`` if it is + for creating training images + batch_size: int, optional + If ``None`` then the batch size selected in command line arguments is used, otherwise + the batch size provided here is used. + images: list, optional. Default: ``None`` + If provided then this will be used as the list of images for the generator. If ``None`` + then the training folder images for the side will be used. Default: ``None`` Returns ------- :class:`~lib.training_data.TrainingDataGenerator` The training data generator """ - logger.debug("Loading generator") - input_size = self._model.model.input_shape[output_index][1] - output_shapes = self._model.output_shapes[output_index] - logger.debug("input_size: %s, output_shapes: %s", input_size, output_shapes) - generator = TrainingDataGenerator(input_size, - output_shapes, - self._model.coverage_ratio, - self._model.color_order, - not self._model.command_line_arguments.no_augment_color, - self._model.command_line_arguments.no_flip, - self._model.command_line_arguments.no_warp, - self._model.command_line_arguments.warp_to_landmarks, - self._config) - return generator - - def _set_preview_feed(self): + logger.debug("Loading generator, side: %s, is_display: %s, batch_size: %s", + side, is_display, batch_size) + generator = PreviewDataGenerator if is_display else TrainingDataGenerator + retval = generator(self._config, + self._model, + side, + self._images[side] if images is None else images, + self._batch_size if batch_size is None else batch_size) + return retval + + def _set_preview_feed(self) -> Dict[Literal["a", "b"], Generator[BatchType, None, None]]: """ Set the preview feed for this feeder. - Creates a generator from :class:`lib.training_data.TrainingDataGenerator` specifically + Creates a generator from :class:`lib.training_data.PreviewDataGenerator` specifically for previews for the feeder. Returns ------- dict - The side ("a" or "b") as key, :class:`~lib.training_data.TrainingDataGenerator` as + The side ("a" or "b") as key, :class:`~lib.training_data.PreviewDataGenerator` as value. """ - retval = {} - for idx, side in enumerate(("a", "b")): + retval: Dict[Literal["a", "b"], Generator[BatchType, None, None]] = {} + for side in get_args(Literal["a", "b"]): logger.debug("Setting preview feed: (side: '%s')", side) - preview_images = self._config.get("preview_images", 14) + preview_images = int(self._config.get("preview_images", 14)) preview_images = min(max(preview_images, 2), 16) batchsize = min(len(self._images[side]), preview_images) - retval[side] = self._load_generator(idx).minibatch_ab(self._images[side], - batchsize, - side, - do_shuffle=True, - is_preview=True) - logger.debug("Set preview feed. Batchsize: %s", batchsize) + retval[side] = self._load_generator(side, + True, + batch_size=batchsize).minibatch_ab() return retval - def get_batch(self): + def get_batch(self) -> Tuple[List[List[np.ndarray]], ...]: """ Get the feed data and the targets for each training side for feeding into the model's train function. @@ -423,110 +463,79 @@ def get_batch(self): model_targets: list The targets for the model for each side A and B """ - model_inputs = [] - model_targets = [] + model_inputs: List[List[np.ndarray]] = [] + model_targets: List[List[np.ndarray]] = [] for side in ("a", "b"): - batch = next(self._feeds[side]) - side_inputs = batch["feed"] - side_targets = self._compile_mask_targets(batch["targets"], - batch["masks"], - batch.get("additional_masks", None)) - if self._model.config["learn_mask"]: - side_targets = side_targets + [batch["masks"]] - logger.trace("side: %s, input_shapes: %s, target_shapes: %s", - side, [i.shape for i in side_inputs], [i.shape for i in side_targets]) - if get_backend() == "amd": - model_inputs.extend(side_inputs) - model_targets.extend(side_targets) - else: - model_inputs.append(side_inputs) - model_targets.append(side_targets) - return model_inputs, model_targets + side_feed, side_targets = next(self._feeds[side]) + if self._model.config["learn_mask"]: # Add the face mask as it's own target + side_targets += [side_targets[-1][..., 3][..., None]] + logger.trace("side: %s, input_shapes: %s, target_shapes: %s", # type: ignore + side, side_feed.shape, [i.shape for i in side_targets]) + model_inputs.append([side_feed]) + model_targets.append(side_targets) - def _compile_mask_targets(self, targets, masks, additional_masks): - """ Compile the masks into the targets for penalized loss and for targeted learning. + return model_inputs, model_targets - Penalized loss expects the target mask to be included for all outputs in the 4th channel - of the targets. Any additional masks are placed into subsequent channels for extraction - by the relevant loss functions. + def generate_preview(self, + is_timelapse: bool = False) -> Dict[Literal["a", "b"], List[np.ndarray]]: + """ Generate the images for preview window or timelapse Parameters ---------- - targets: list - The targets for the model, with the mask as the final entry in the list - masks: list - The masks for the model - additional_masks: list or ``None`` - Any additional masks for the model, or ``None`` if no additional masks are required + is_timelapse, bool, optional + ``True`` if preview is to be generated for a Timelapse otherwise ``False``. + Default: ``False`` Returns ------- - list - The targets for the model with the mask compiled into the 4th channel. The original - mask is still output as the final item in the list - """ - if not self._model.config["penalized_mask_loss"] and additional_masks is None: - logger.trace("No masks to compile. Returning targets") - return targets - - if not self._model.config["penalized_mask_loss"] and additional_masks is not None: - masks = additional_masks - elif additional_masks is not None: - masks = np.concatenate((masks, additional_masks), axis=-1) - - for idx, tgt in enumerate(targets): - tgt_dim = tgt.shape[1] - if tgt_dim == masks.shape[1]: - add_masks = masks - else: - add_masks = np.array([cv2.resize(mask, (tgt_dim, tgt_dim)) - for mask in masks]) - if add_masks.ndim == 3: - add_masks = add_masks[..., None] - targets[idx] = np.concatenate((tgt, add_masks), axis=-1) - logger.trace("masks added to targets: %s", [tgt.shape for tgt in targets]) - return targets - - def generate_preview(self, do_preview): - """ Generate the preview images. - - Parameters - ---------- - do_preview: bool - Whether the previews should be generated. ``True`` if they should ``False`` if they - should not be generated, in which case currently stored previews should be deleted. + dict + Dictionary for side A and B of list of numpy arrays corresponding to the + samples, targets and masks for this preview """ - if not do_preview: - self._samples = {} - self._target = {} - self._masks = {} - return - logger.debug("Generating preview") - for side in ("a", "b"): - batch = next(self._display_feeds["preview"][side]) - self._samples[side] = batch["samples"] - self._target[side] = batch["targets"][-1] - self._masks[side] = batch["masks"] - - def compile_sample(self, batch_size, samples=None, images=None, masks=None): + logger.debug("Generating preview (is_timelapse: %s)", is_timelapse) + + batchsizes: List[int] = [] + feed: Dict[Literal["a", "b"], np.ndarray] = {} + samples: Dict[Literal["a", "b"], np.ndarray] = {} + masks: Dict[Literal["a", "b"], np.ndarray] = {} + + # MyPy can't recurse into nested dicts to get the type :( + iterator = cast(Dict[Literal["a", "b"], Generator[BatchType, None, None]], + self._display_feeds["timelapse" if is_timelapse else "preview"]) + for side in get_args(Literal["a", "b"]): + side_feed, side_samples = next(iterator[side]) + batchsizes.append(len(side_samples[0])) + samples[side] = side_samples[0] + feed[side] = side_feed[..., :3] + masks[side] = side_feed[..., 3][..., None] + + logger.debug("Generated samples: is_timelapse: %s, images: %s", is_timelapse, + {key: {k: v.shape for k, v in item.items()} + for key, item + in zip(("feed", "samples", "sides"), (feed, samples, masks))}) + return self.compile_sample(min(batchsizes), feed, samples, masks) + + def compile_sample(self, + image_count: int, + feed: Dict[Literal["a", "b"], np.ndarray], + samples: Dict[Literal["a", "b"], np.ndarray], + masks: Dict[Literal["a", "b"], np.ndarray] + ) -> Dict[Literal["a", "b"], List[np.ndarray]]: """ Compile the preview samples for display. Parameters ---------- - batch_size: int - The requested batch size for each training iterations - samples: dict, optional - Dictionary for side "a", "b" of :class:`numpy.ndarray`. The sample images that should - be used for creating the preview. If ``None`` then the samples will be generated from - the internal random image generator. Default: ``None`` - images: dict, optional - Dictionary for side "a", "b" of :class:`numpy.ndarray`. The target images that should - be used for creating the preview. If ``None`` then the targets will be generated from - the internal random image generator. Default: ``None`` - masks: dict, optional + image_count: int + The number of images to limit the sample output to. + feed: dict + Dictionary for side "a", "b" of :class:`numpy.ndarray`. The images that should be fed + into the model for obtaining a prediction + samples: dict + Dictionary for side "a", "b" of :class:`numpy.ndarray`. The 100% coverage target images + that should be used for creating the preview. + masks: dict Dictionary for side "a", "b" of :class:`numpy.ndarray`. The masks that should be used - for creating the preview. If ``None`` then the masks will be generated from the - internal random image generator. Default: ``None`` + for creating the preview. Returns ------- @@ -534,65 +543,46 @@ def compile_sample(self, batch_size, samples=None, images=None, masks=None): The list of samples, targets and masks as :class:`numpy.ndarrays` for creating a preview image """ - num_images = self._config.get("preview_images", 14) - num_images = min(batch_size, num_images) if batch_size is not None else num_images - retval = {} - for side in ("a", "b"): + num_images = min(image_count, int(self._config.get("preview_images", 14))) + retval: Dict[Literal["a", "b"], List[np.ndarray]] = {} + for side in get_args(Literal["a", "b"]): logger.debug("Compiling samples: (side: '%s', samples: %s)", side, num_images) - side_images = images[side] if images is not None else self._target[side] - side_masks = masks[side] if masks is not None else self._masks[side] - side_samples = samples[side] if samples is not None else self._samples[side] - retval[side] = [side_samples[0:num_images], - side_images[0:num_images], - side_masks[0:num_images]] + retval[side] = [feed[side][0:num_images], + samples[side][0:num_images], + masks[side][0:num_images]] + logger.debug("Compiled Samples: %s", {k: [i.shape for i in v] for k, v in retval.items()}) return retval - def compile_timelapse_sample(self): - """ Compile the sample images for creating a time-lapse frame. - - Returns - ------- - dict - For sides "a" and "b"; The list of samples, targets and masks as - :class:`numpy.ndarrays` for creating a time-lapse frame - """ - batchsizes = [] - samples = {} - images = {} - masks = {} - for side in ("a", "b"): - batch = next(self._display_feeds["timelapse"][side]) - batchsizes.append(len(batch["samples"])) - samples[side] = batch["samples"] - images[side] = batch["targets"][-1] - masks[side] = batch["masks"] - batchsize = min(batchsizes) - sample = self.compile_sample(batchsize, samples=samples, images=images, masks=masks) - return sample - - def set_timelapse_feed(self, images, batch_size): + def set_timelapse_feed(self, + images: Dict[Literal["a", "b"], List[str]], + batch_size: int) -> None: """ Set the time-lapse feed for this feeder. - Creates a generator from :class:`lib.training_data.TrainingDataGenerator` specifically + Creates a generator from :class:`lib.training_data.PreviewDataGenerator` specifically for generating time-lapse previews for the feeder. Parameters ---------- - images: list - The list of full paths to the images for creating the time-lapse for this - :class:`_Feeder` + images: dict + The list of full paths to the images for creating the time-lapse for each side batch_size: int The number of images to be used to create the time-lapse preview. """ logger.debug("Setting time-lapse feed: (input_images: '%s', batch_size: %s)", images, batch_size) - for idx, side in enumerate(("a", "b")): - self._display_feeds["timelapse"][side] = self._load_generator(idx).minibatch_ab( - images[side][:batch_size], - batch_size, - side, - do_shuffle=False, - is_timelapse=True) + + # MyPy can't recurse into nested dicts to get the type :( + iterator = cast(Dict[Literal["a", "b"], Generator[BatchType, None, None]], + self._display_feeds["timelapse"]) + + for side in get_args(Literal["a", "b"]): + imgs = images[side] + logger.debug("Setting preview feed: (side: '%s', images: %s)", side, len(imgs)) + + iterator[side] = self._load_generator(side, + True, + batch_size=batch_size, + images=imgs).minibatch_ab() logger.debug("Set time-lapse feed: %s", self._display_feeds["timelapse"]) @@ -613,25 +603,25 @@ class _Samples(): # pylint:disable=too-few-public-methods dictionary should contain 2 keys ("a" and "b") with the values being the training images for generating samples corresponding to each side. """ - def __init__(self, model, coverage_ratio): + def __init__(self, model: "ModelBase", coverage_ratio: float) -> None: logger.debug("Initializing %s: model: '%s', coverage_ratio: %s)", self.__class__.__name__, model, coverage_ratio) self._model = model self._display_mask = model.config["learn_mask"] or model.config["penalized_mask_loss"] - self.images = {} + self.images: Dict[Literal["a", "b"], List[np.ndarray]] = {} self._coverage_ratio = coverage_ratio logger.debug("Initialized %s", self.__class__.__name__) - def toggle_mask_display(self): + def toggle_mask_display(self) -> None: """ Toggle the mask overlay on or off depending on user input. """ if not (self._model.config["learn_mask"] or self._model.config["penalized_mask_loss"]): return display_mask = not self._display_mask - print("\n") # Break to not garble loss output + print("") # Break to not garble loss output logger.info("Toggling mask display %s...", "on" if display_mask else "off") self._display_mask = display_mask - def show_sample(self): + def show_sample(self) -> np.ndarray: """ Compile a preview image. Returns @@ -640,49 +630,23 @@ def show_sample(self): A compiled preview image ready for display or saving """ logger.debug("Showing sample") - feeds = {} - figures = {} - headers = {} - for idx, side in enumerate(("a", "b")): - samples = self.images[side] - faces = samples[1] + feeds: Dict[Literal["a", "b"], np.ndarray] = {} + for idx, side in enumerate(get_args(Literal["a", "b"])): input_shape = self._model.model.input_shape[idx][1:] - if input_shape[0] / faces.shape[1] != 1.0: - feeds[side] = self._resize_sample(side, faces, input_shape[0]) + if input_shape[0] / self.images[side][0].shape[1] != 1.0: + feeds[side] = self._resize_sample(side, self.images[side][1], input_shape[0]) feeds[side] = feeds[side].reshape((-1, ) + input_shape) else: - feeds[side] = faces + feeds[side] = self.images[side][0] preds = self._get_predictions(feeds["a"], feeds["b"]) - - for side, samples in self.images.items(): - other_side = "a" if side == "b" else "b" - predictions = [preds[f"{side}_{side}"], - preds[f"{other_side}_{side}"]] - display = self._to_full_frame(side, samples, predictions) - headers[side] = self._get_headers(side, display[0].shape[1]) - figures[side] = np.stack([display[0], display[1], display[2], ], axis=1) - if self.images[side][0].shape[0] % 2 == 1: - figures[side] = np.concatenate([figures[side], - np.expand_dims(figures[side][0], 0)]) - - width = 4 - side_cols = width // 2 - if side_cols != 1: - headers = self._duplicate_headers(headers, side_cols) - - header = np.concatenate([headers["a"], headers["b"]], axis=1) - figure = np.concatenate([figures["a"], figures["b"]], axis=0) - height = int(figure.shape[0] / width) - figure = figure.reshape((width, height) + figure.shape[1:]) - figure = _stack_images(figure) - figure = np.concatenate((header, figure), axis=0) - - logger.debug("Compiled sample") - return np.clip(figure * 255, 0, 255).astype('uint8') + return self._compile_preview(preds) @classmethod - def _resize_sample(cls, side, sample, target_size): + def _resize_sample(cls, + side: Literal["a", "b"], + sample: np.ndarray, + target_size: int) -> np.ndarray: """ Resize a given image to the target size. Parameters @@ -710,15 +674,15 @@ def _resize_sample(cls, side, sample, target_size): logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape) return retval - def _get_predictions(self, feed_a, feed_b): + def _get_predictions(self, feed_a: np.ndarray, feed_b: np.ndarray) -> Dict[str, np.ndarray]: """ Feed the samples to the model and return predictions Parameters ---------- - feed_a: list - List of :class:`numpy.ndarray` of feed images for the "a" side - feed_a: list - List of :class:`numpy.ndarray` of feed images for the "b" side + feed_a: :class:`numpy.ndarray` + Feed images for the "a" side + feed_a: :class:`numpy.ndarray` + Feed images for the "b" side Returns ------- @@ -726,7 +690,7 @@ def _get_predictions(self, feed_a, feed_b): List of :class:`numpy.ndarray` of predictions received from the model """ logger.debug("Getting Predictions") - preds = {} + preds: Dict[str, np.ndarray] = {} standard = self._model.model.predict([feed_a, feed_b], verbose=0) swapped = self._model.model.predict([feed_b, feed_a], verbose=0) @@ -751,7 +715,51 @@ def _get_predictions(self, feed_a, feed_b): logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()}) return preds - def _to_full_frame(self, side, samples, predictions): + def _compile_preview(self, predictions: Dict[str, np.ndarray]) -> np.ndarray: + """ Compile predictions and images into the final preview image. + + Parameters + ---------- + predictions: dict + The predictions from the model + + Returns + ------- + :class:`numpy.ndarry` + A compiled preview image ready for display or saving + """ + figures: Dict[Literal["a", "b"], np.ndarray] = {} + headers: Dict[Literal["a", "b"], np.ndarray] = {} + + for side, samples in self.images.items(): + other_side = "a" if side == "b" else "b" + preds = [predictions[f"{side}_{side}"], + predictions[f"{other_side}_{side}"]] + display = self._to_full_frame(side, samples, preds) + headers[side] = self._get_headers(side, display[0].shape[1]) + figures[side] = np.stack([display[0], display[1], display[2], ], axis=1) + if self.images[side][1].shape[0] % 2 == 1: + figures[side] = np.concatenate([figures[side], + np.expand_dims(figures[side][0], 0)]) + + width = 4 + if width // 2 != 1: + headers = self._duplicate_headers(headers, width // 2) + + header = np.concatenate([headers["a"], headers["b"]], axis=1) + figure = np.concatenate([figures["a"], figures["b"]], axis=0) + height = int(figure.shape[0] / width) + figure = figure.reshape((width, height) + figure.shape[1:]) + figure = _stack_images(figure) + figure = np.concatenate((header, figure), axis=0) + + logger.debug("Compiled sample") + return np.clip(figure * 255, 0, 255).astype('uint8') + + def _to_full_frame(self, + side: Literal["a", "b"], + samples: List[np.ndarray], + predictions: List[np.ndarray]) -> List[np.ndarray]: """ Patch targets and prediction images into images of model output size. Parameters @@ -759,7 +767,7 @@ def _to_full_frame(self, side, samples, predictions): side: {"a" or "b"} The side that these samples are for samples: list - List of :class:`numpy.ndarray` of feed images and target images + List of :class:`numpy.ndarray` of feed images and sample images predictions: list List of :class: `numpy.ndarray` of predictions from the model @@ -770,7 +778,7 @@ def _to_full_frame(self, side, samples, predictions): """ logger.debug("side: '%s', number of sample arrays: %s, prediction.shapes: %s)", side, len(samples), [pred.shape for pred in predictions]) - full, faces = samples[:2] + faces, full = samples[:2] if self._model.color_order.lower() == "rgb": # Switch color order for RGB model display full = full[..., ::-1] @@ -786,7 +794,11 @@ def _to_full_frame(self, side, samples, predictions): return images - def _process_full(self, side, images, prediction_size, color): + def _process_full(self, + side: Literal["a", "b"], + images: np.ndarray, + prediction_size: int, + color: Tuple[int, int, int]) -> np.ndarray: """ Add a frame overlay to preview images indicating the region of interest. This applies the red border that appears in the preview images. @@ -828,7 +840,7 @@ def _process_full(self, side, images, prediction_size, color): return images @classmethod - def _compile_masked(cls, faces, masks): + def _compile_masked(cls, faces: List[np.ndarray], masks: np.ndarray) -> List[np.ndarray]: """ Add the mask to the faces for masked preview. Places an opaque red layer over areas of the face that are masked out. @@ -848,6 +860,7 @@ def _compile_masked(cls, faces, masks): """ orig_masks = np.tile(1 - np.rint(masks), 3) orig_masks[np.where((orig_masks == [1., 1., 1.]).all(axis=3))] = [0., 0., 1.] + masks3: Union[List[np.ndarray], np.ndarray] = [] if faces[-1].shape[-1] == 4: # Mask contained in alpha channel of predictions pred_masks = [np.tile(1 - np.rint(face[..., -1])[..., None], 3) for face in faces[-2:]] @@ -865,15 +878,15 @@ def _compile_masked(cls, faces, masks): return retval @classmethod - def _overlay_foreground(cls, backgrounds, foregrounds): + def _overlay_foreground(cls, backgrounds: np.ndarray, foregrounds: np.ndarray) -> np.ndarray: """ Overlay the preview images into the center of the background images Parameters ---------- - backgrounds: list - List of :class:`numpy.ndarray` background images for placing the preview images onto - backgrounds: list - List of :class:`numpy.ndarray` preview images for placing onto the background images + backgrounds: :class:`numpy.ndarray` + Background images for placing the preview images onto + backgrounds: :class:`numpy.ndarray` + Preview images for placing onto the background images Returns ------- @@ -888,7 +901,7 @@ def _overlay_foreground(cls, backgrounds, foregrounds): return backgrounds @classmethod - def _get_headers(cls, side, width): + def _get_headers(cls, side: Literal["a", "b"], width: int) -> np.ndarray: """ Set header row for the final preview frame Parameters @@ -906,12 +919,11 @@ def _get_headers(cls, side, width): logger.debug("side: '%s', width: %s", side, width) titles = ("Original", "Swap") if side == "a" else ("Swap", "Original") - side = side.upper() height = int(width / 4.5) total_width = width * 3 logger.debug("height: %s, total_width: %s", height, total_width) font = cv2.FONT_HERSHEY_SIMPLEX - texts = [f"{titles[0]} ({side})", + texts = [f"{titles[0]} ({side.upper()})", f"{titles[0]} > {titles[0]}", f"{titles[0]} > {titles[1]}"] scaling = (width / 144) * 0.45 @@ -936,20 +948,22 @@ def _get_headers(cls, side, width): return header_box @classmethod - def _duplicate_headers(cls, headers, columns): + def _duplicate_headers(cls, + headers: Dict[Literal["a", "b"], np.ndarray], + columns: int) -> Dict[Literal["a", "b"], np.ndarray]: """ Duplicate headers for the number of columns displayed for each side. Parameters ---------- - headers: :class:`numpy.ndarray` - The header to be duplicated + headers: dict + The headers to be duplicated for each side columns: int The number of columns that the header needs to be duplicated for Returns ------- - :class:`numpy.ndarray` - The original headers duplicated by the number of columns + :class:dict + The original headers duplicated by the number of columns for each side """ for side, header in headers.items(): duped = tuple(header for _ in range(columns)) @@ -971,12 +985,17 @@ class _Timelapse(): # pylint:disable=too-few-public-methods The amount to scale the final preview image by. Default: `1.0` image_count: int The number of preview images to be displayed in the time-lapse - feeder: dict - The :class:`_Feeder` for generating the time-lapse images. + feeder: :class:`_Feeder` + The feeder for generating the time-lapse images. image_paths: dict The full paths to the training images for each side of the model """ - def __init__(self, model, coverage_ratio, image_count, feeder, image_paths): + def __init__(self, + model: "ModelBase", + coverage_ratio: float, + image_count: int, + feeder: _Feeder, + image_paths: Dict[Literal["a", "b"], List[str]]) -> None: logger.debug("Initializing %s: model: %s, coverage_ratio: %s, image_count: %s, " "feeder: '%s', image_paths: %s)", self.__class__.__name__, model, coverage_ratio, image_count, feeder, len(image_paths)) @@ -985,10 +1004,10 @@ def __init__(self, model, coverage_ratio, image_count, feeder, image_paths): self._model = model self._feeder = feeder self._image_paths = image_paths - self._output_file = None + self._output_file = "" logger.debug("Initialized %s", self.__class__.__name__) - def _setup(self, input_a=None, input_b=None, output=None): + def _setup(self, input_a: str, input_b: str, output: str) -> None: """ Setup the time-lapse folder locations and the time-lapse feed. Parameters @@ -1002,15 +1021,15 @@ def _setup(self, input_a=None, input_b=None, output=None): default to the model folder """ logger.debug("Setting up time-lapse") - if output is None: + if not output: output = get_folder(os.path.join(str(self._model.model_dir), f"{self._model.name}_timelapse")) - self._output_file = str(output) + self._output_file = output logger.debug("Time-lapse output set to '%s'", self._output_file) # Rewrite paths to pull from the training images so mask and face data can be accessed - images = {} - for side, input_ in zip(("a", "b"), (input_a, input_b)): + images: Dict[Literal["a", "b"], List[str]] = {} + for side, input_ in zip(get_args(Literal["a", "b"]), (input_a, input_b)): training_path = os.path.dirname(self._image_paths[side][0]) images[side] = [os.path.join(training_path, os.path.basename(pth)) for pth in get_image_paths(input_)] @@ -1021,7 +1040,7 @@ def _setup(self, input_a=None, input_b=None, output=None): self._feeder.set_timelapse_feed(images, batchsize) logger.debug("Set up time-lapse") - def output_timelapse(self, timelapse_kwargs): + def output_timelapse(self, timelapse_kwargs: Dict[str, str]) -> None: """ Generate the time-lapse samples and output the created time-lapse to the specified output folder. @@ -1036,7 +1055,7 @@ def output_timelapse(self, timelapse_kwargs): self._setup(**timelapse_kwargs) logger.debug("Getting time-lapse samples") - self._samples.images = self._feeder.compile_timelapse_sample() + self._samples.images = self._feeder.generate_preview(is_timelapse=True) logger.debug("Got time-lapse samples: %s", {side: len(images) for side, images in self._samples.images.items()}) @@ -1049,7 +1068,7 @@ def output_timelapse(self, timelapse_kwargs): logger.debug("Created time-lapse: '%s'", filename) -def _stack_images(images): +def _stack_images(images: np.ndarray) -> np.ndarray: """ Stack images evenly for preview. Parameters diff --git a/requirements/_requirements_base.txt b/requirements/_requirements_base.txt index 4bcd487e76..581c0c54db 100644 --- a/requirements/_requirements_base.txt +++ b/requirements/_requirements_base.txt @@ -1,5 +1,6 @@ tqdm>=4.64 psutil>=5.9.0 +numexpr>=2.8.3 opencv-python>=4.6.0.0 pillow>=9.2.0 scikit-learn==1.0.2; python_version < '3.9' # AMD needs version 1.0.2 and 1.1.0 not available in Python 3.7 diff --git a/scripts/extract.py b/scripts/extract.py index 91f4fa1219..6176f2488e 100644 --- a/scripts/extract.py +++ b/scripts/extract.py @@ -6,7 +6,7 @@ import logging import os import sys -from typing import TYPE_CHECKING, Optional +from typing import List, Dict, TYPE_CHECKING, Optional from tqdm import tqdm @@ -71,7 +71,7 @@ def __init__(self, arguments: argparse.Namespace) -> None: min_size=self._args.min_size, normalize_method=normalization, re_feed=self._args.re_feed) - self._threads = [] + self._threads: List[MultiThread] = [] self._verify_output = False logger.debug("Initialized %s", self.__class__.__name__) @@ -101,13 +101,14 @@ def _set_skip_list(self) -> None: skip_list = [] for idx, filename in enumerate(self._images.file_list): if idx % self._skip_num != 0: - logger.trace("Adding image '%s' to skip list due to extract_every_n = %s", - filename, self._skip_num) + logger.trace("Adding image '%s' to skip list due to " # type: ignore + "extract_every_n = %s", filename, self._skip_num) skip_list.append(idx) # Items may be in the alignments file if skip-existing[-faces] is selected elif os.path.basename(filename) in self._alignments.data: self._existing_count += 1 - logger.trace("Removing image: '%s' due to previously existing", filename) + logger.trace("Removing image: '%s' due to previously existing", # type: ignore + filename) skip_list.append(idx) if self._existing_count != 0: logger.info("Skipping %s frames due to skip_existing/skip_existing_faces.", @@ -142,7 +143,7 @@ def _threaded_redirector(self, task: str, io_args: Optional[tuple] = None) -> No Any arguments that need to be provided to the background function """ logger.debug("Threading task: (Task: '%s')", task) - io_args = tuple() if io_args is None else (io_args, ) + io_args = tuple() if io_args is None else io_args func = getattr(self, f"_{task}") io_thread = MultiThread(func, *io_args, thread_count=1) io_thread.start() @@ -165,7 +166,7 @@ def _load(self) -> None: load_queue.put("EOF") logger.debug("Load Images: Complete") - def _reload(self, detected_faces: dict[str, ExtractMedia]) -> None: + def _reload(self, detected_faces: Dict[str, ExtractMedia]) -> None: """ Reload the images and pair to detected face When the extraction pipeline is running in serial mode, images are reloaded from disk, @@ -183,7 +184,7 @@ def _reload(self, detected_faces: dict[str, ExtractMedia]) -> None: if load_queue.shutdown.is_set(): logger.debug("Reload Queue: Stop signal received. Terminating") break - logger.trace("Reloading image: '%s'", filename) + logger.trace("Reloading image: '%s'", filename) # type: ignore extract_media = detected_faces.pop(filename, None) if not extract_media: logger.warning("Couldn't find faces for: %s", filename) @@ -231,8 +232,8 @@ def _run_extraction(self) -> None: if not is_final: logger.debug("Reloading images") - self._threaded_redirector("reload", detected_faces) - if not self._args.skip_saving_faces: + self._threaded_redirector("reload", (detected_faces, )) + if saver is not None: saver.close() def _check_thread_error(self) -> None: @@ -263,13 +264,13 @@ def _output_processing(self, extract_media: ExtractMedia, size: int) -> None: faces_count = len(extract_media.detected_faces) if faces_count == 0: - logger.verbose("No faces were detected in image: %s", + logger.verbose("No faces were detected in image: %s", # type: ignore os.path.basename(extract_media.filename)) if not self._verify_output and faces_count > 1: self._verify_output = True - def _output_faces(self, saver: ImagesSaver, extract_media: ExtractMedia) -> None: + def _output_faces(self, saver: Optional[ImagesSaver], extract_media: ExtractMedia) -> None: """ Output faces to save thread Set the face filename based on the frame name and put the face to the @@ -278,12 +279,12 @@ def _output_faces(self, saver: ImagesSaver, extract_media: ExtractMedia) -> None Parameters ---------- - saver: lib.images.ImagesSaver - The background saver for saving the image + saver: :class:`lib.images.ImagesSaver` or ``None`` + The background saver for saving the image or ``None`` if faces are not to be saved extract_media: :class:`~plugins.extract.pipeline.ExtractMedia` The output from :class:`~plugins.extract.Pipeline.Extractor` """ - logger.trace("Outputting faces for %s", extract_media.filename) + logger.trace("Outputting faces for %s", extract_media.filename) # type: ignore final_faces = [] filename = os.path.splitext(os.path.basename(extract_media.filename))[0] extension = ".png" @@ -299,7 +300,7 @@ def _output_faces(self, saver: ImagesSaver, extract_media: ExtractMedia) -> None source_frame_dims=extract_media.image_size)) image = encode_image(face.aligned.face, extension, metadata=meta) - if not self._args.skip_saving_faces: + if saver is not None: saver.save(output_filename, image) final_faces.append(face.to_alignment()) self._alignments.data[os.path.basename(extract_media.filename)] = dict(faces=final_faces) diff --git a/setup.cfg b/setup.cfg index 1dfb9bba15..f11334206b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,8 @@ exclude = .git, __pycache__ [mypy] [mypy-cv2.*] ignore_missing_imports = True +[mypy-fastcluster.*] +ignore_missing_imports = True [mypy-imageio.*] ignore_missing_imports = True [mypy-imageio_ffmpeg.*] @@ -16,6 +18,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-matplotlib.*] ignore_missing_imports = True +[mypy-numexpr.*] +ignore_missing_imports = True [mypy-pexpect.*] ignore_missing_imports = True [mypy-PIL.*]