Skip to content

Commit

Permalink
Fix issue shivamswarnkar#2
Browse files Browse the repository at this point in the history
* images are generated and saved to different files (not one with all generated images)
  • Loading branch information
Kabanosk committed Mar 5, 2024
1 parent 55b6d06 commit 8baf2ea
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions DCGAN/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,32 @@
import numpy as np
from generator import Generator
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

def generate_images(args):
# set up device
device = torch.device('cuda:0'
if (torch.cuda.is_available() and args.ngpu>0)
else 'cpu')

# set up device
device = torch.device('cuda:0'
if (torch.cuda.is_available() and args.ngpu>0)
else 'cpu')
# load generator model
netG = Generator(args).to(device)
netG.load_state_dict(torch.load(args.netG))

# load generator model
netG = Generator(args).to(device)
netG.load_state_dict(torch.load(args.netG))
filename, ext = os.path.splitext(os.path.basename(args.output_path))

for i in tqdm(range(args.n)):
# create random noise
noise = torch.randn(1, args.nz, 1, 1, device=device)

# create random noise
noise = torch.randn(args.n, args.nz, 1, 1, device=device)
fake = netG(noise).detach().cpu()
img = vutils.make_grid(fake, padding=2, normalize=True)

# save image
plt.axis("off")
plt.imshow(np.transpose(img,(1,2,0)))
plt.savefig(args.output_path)
with torch.no_grad():
fake = netG(noise).detach().cpu()

img_np = np.transpose(vutils.make_grid(fake, padding=2, normalize=True).numpy(), (1, 2, 0))

new_filename = f"{filename}_{str(i).zfill(3)}{ext}"
new_output_path = os.path.join(os.path.dirname(args.output_path), new_filename)

# save image
plt.imsave(new_output_path, img_np)

0 comments on commit 8baf2ea

Please sign in to comment.