Add Docker environment & web demo #2

Open · wants to merge 3 commits into main
1 change: 1 addition & 0 deletions README.md
@@ -8,6 +8,7 @@ Y-tech, Kuaishou Technology


### [Project page](https://onion-liu.github.io/BlendGAN) | [Paper](https://arxiv.org/abs/2110.11728)
[![Replicate](https://replicate.com/onion-liu/blendgan/badge)](https://replicate.com/onion-liu/blendgan)

Abstract: *Generative Adversarial Networks (GANs) have made a dramatic leap in high-fidelity image synthesis and stylized face generation. Recently, a layer-swapping mechanism has been developed to improve the stylization performance. However, this method is incapable of fitting arbitrary styles in a single model and requires hundreds of style-consistent training images for each style. To address the above issues, we propose BlendGAN for arbitrary stylized face generation by leveraging a flexible blending strategy and a generic artistic dataset. Specifically, we first train a self-supervised style encoder on the generic artistic dataset to extract the representations of arbitrary styles. In addition, a weighted blending module (WBM) is proposed to blend face and style representations implicitly and control the arbitrary stylization effect. By doing so, BlendGAN can gracefully fit arbitrary styles in a unified model while avoiding case-by-case preparation of style-consistent training images. To this end, we also present a novel large-scale artistic face dataset AAHQ. Extensive experiments demonstrate that BlendGAN outperforms state-of-the-art methods in terms of visual quality and style diversity for both latent-guided and reference-guided stylized face synthesis.*

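For intuition only, the sketch below illustrates the kind of weighted latent blending the abstract describes: per-layer face latents mixed with a projected style embedding from a given layer index onward. This is a hypothetical illustration, not the repository's actual WBM; the class name, shapes, and learnable weights are assumptions.

```python
# Minimal, hypothetical sketch of weighted latent blending (NOT the official WBM
# from this repository): face latents are mixed with a projected style embedding
# from layer `add_weight_index` onward.
import torch
import torch.nn as nn


class WeightedBlendSketch(nn.Module):
    def __init__(self, n_layers: int = 18, latent_dim: int = 512):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(n_layers))   # one blend weight per layer
        self.style_proj = nn.Linear(latent_dim, latent_dim)  # map style embedding into latent space

    def forward(self, face_latents, style_embed, add_weight_index=6):
        # face_latents: (B, n_layers, latent_dim), style_embed: (B, latent_dim)
        style = self.style_proj(style_embed).unsqueeze(1)     # (B, 1, latent_dim)
        alpha = torch.sigmoid(self.weights).view(1, -1, 1)    # per-layer weights in (0, 1)
        mask = torch.zeros_like(alpha)
        mask[:, add_weight_index:, :] = 1.0                   # blend only the later layers
        blend = alpha * mask
        return face_latents * (1.0 - blend) + style * blend


# toy usage: 18 latent layers correspond to a 1024x1024 StyleGAN2 generator
mixed = WeightedBlendSketch()(torch.randn(1, 18, 512), torch.randn(1, 512))
```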
20 changes: 20 additions & 0 deletions cog.yaml
@@ -0,0 +1,20 @@
build:
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - "ninja-build"
  python_packages:
    - "ipython==7.21.0"
    - "torch==1.7.1"
    - "torchvision==0.8.2"
    - "numpy==1.19.4"
    - "tqdm==4.54.1"
    - "opencv-python==4.4.0.46"
    - "scipy==1.7.2"
    - "cmake==3.22.0"
  run:
    - pip install dlib

predict: "predict.py:Predictor"
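The `run` step installs dlib via pip, presumably because it builds from source and needs `cmake` and `ninja-build`, which are pinned above; `predict:` points Cog at the `Predictor` class defined in predict.py below. A hypothetical sanity check for the pinned environment (not part of this PR; it could be run inside the built image, e.g. via `cog run python`) might look like:

```python
# Hypothetical sanity check for the environment pinned in cog.yaml above
# (not part of the pull request).
import cv2
import dlib
import numpy
import scipy
import torch
import torchvision

print("torch", torch.__version__, "| torchvision", torchvision.__version__)  # expect 1.7.1 / 0.8.2
print("opencv", cv2.__version__, "| numpy", numpy.__version__, "| scipy", scipy.__version__)
print("dlib", dlib.__version__)
print("CUDA available:", torch.cuda.is_available())  # cog.yaml requests gpu: true
```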
63 changes: 63 additions & 0 deletions predict.py
@@ -0,0 +1,63 @@
import torch
import tempfile
import cv2
import random
import numpy as np
from pathlib import Path
from cog import BasePredictor, Path, Input  # cog's Path intentionally shadows pathlib.Path

from ffhq_dataset.gen_aligned_image import FaceAlign
from model import Generator
from psp_encoder.psp_encoders import PSPEncoder
from utils import ten2cv, cv2ten


class Predictor(BasePredictor):
    def setup(self):
        size = 1024
        latent = 512
        n_mlp = 8
        self.device = 'cuda'
        checkpoint = torch.load('pretrained_models/blendgan.pt')
        model_dict = checkpoint['g_ema']

        # BlendGAN generator (exponential moving average weights)
        self.g_ema = Generator(size, latent, n_mlp, channel_multiplier=2).to(self.device)
        self.g_ema.load_state_dict(model_dict)
        self.g_ema.eval()
        # pSp encoder that maps an aligned face photo to latent codes
        self.psp_encoder = PSPEncoder('pretrained_models/psp_encoder.pt', output_size=1024).to(self.device)
        self.psp_encoder.eval()
        self.fa = FaceAlign()

    def predict(
        self,
        source: Path = Input(
            description="source facial image, it will be aligned and resized to 1024x1024 first",
        ),
        style: Path = Input(
            description="style reference facial image, it will be aligned and resized to 1024x1024 first",
        ),
    ) -> Path:
        # fix all random seeds so the demo output is deterministic
        seed = 0
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        add_weight_index = 6  # blending index passed to the generator
        # face alignment
        source_img = cv2.imread(str(source))
        style_img = cv2.imread(str(style))
        source_img_crop = self.fa.get_crop_image(source_img)
        style_img_crop = self.fa.get_crop_image(style_img)
        source_img_ten = cv2ten(source_img_crop, self.device)
        style_img_ten = cv2ten(style_img_crop, self.device)
        with torch.no_grad():
            # style embedding from the reference image, latent codes from the source face
            sample_style = self.g_ema.get_z_embed(style_img_ten)
            sample_in = self.psp_encoder(source_img_ten)
            img_out_ten, _ = self.g_ema([sample_in], z_embed=sample_style, add_weight_index=add_weight_index,
                                        input_is_latent=True, return_latents=False, randomize_noise=False)
            img_out = ten2cv(img_out_ten)
        out = img_out
        out_path = Path(tempfile.mkdtemp()) / "out.png"
        cv2.imwrite(str(out_path), out)
        return out_path
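A hypothetical local smoke test (not part of this PR) could drive the `Predictor` directly, assuming the pretrained weights are present under `pretrained_models/` and a CUDA GPU is available; the image paths below are placeholders:

```python
# Hypothetical smoke test for predict.py (not part of the pull request).
# Assumes pretrained_models/blendgan.pt and pretrained_models/psp_encoder.pt
# exist and a CUDA device is available; image paths are placeholders.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads the generator and pSp encoder onto the GPU

out_path = predictor.predict(
    source="examples/source_face.jpg",  # photo to be stylized
    style="examples/style_face.jpg",    # artistic reference face
)
print("stylized output written to:", out_path)
```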