Hi,
You would need to save a lot of images of yourself to a folder and use a dataset loader like this one:
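```python
import glob

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    def __init__(self, folder_path, transforms_=None):
        # Chain the list of transforms into one callable
        # (fall back to a plain tensor conversion if none are given)
        self.transform = transforms.Compose(transforms_ or [transforms.ToTensor()])
        # Every file with an extension in folder_path, in a stable order
        self.files = sorted(glob.glob('%s/*.*' % folder_path))

    def __getitem__(self, index):
        # convert('RGB') so grayscale or RGBA files still yield 3 channels
        img = Image.open(self.files[index % len(self.files)]).convert('RGB')
        img = self.transform(img)
        return img

    def __len__(self):
        return len(self.files)
```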
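Note that `glob('%s/*.*')` matches every non-hidden file with an extension, so a stray `notes.txt` or `Thumbs.db` in the folder would be passed to `Image.open` and fail. A minimal sketch of an extension filter, assuming the images are PNG/JPEG/BMP (the extension set and the `list_image_files` name are hypothetical, adjust them to your data):

```python
from pathlib import Path

# Assumed set of extensions worth loading; adjust for your data
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp"}


def list_image_files(folder_path, extensions=IMAGE_EXTENSIONS):
    """Return a sorted list of image paths in folder_path (non-recursive).

    Unlike glob('%s/*.*'), files such as notes.txt or Thumbs.db are skipped,
    so Image.open is only called on files that look like images.
    """
    folder = Path(folder_path)
    return sorted(
        str(p)
        for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in extensions
    )
```

Inside `__init__`, this would replace the glob call: `self.files = list_image_files(folder_path)`.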
And in `wgan_gp.py`:
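```python
from PIL import Image                      # if not already imported
from torch.utils.data import DataLoader    # if not already imported

# Dataset loader
transforms_ = [
    transforms.Resize(opt.img_size, Image.BICUBIC),  # shorter side -> img_size, bicubic resampling
    transforms.CenterCrop(opt.img_size),             # square img_size x img_size crop
    transforms.ToTensor(),                           # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
]
dataloader = DataLoader(
    ImageDataset(folder_path, transforms_=transforms_),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
```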
You will have to define `folder_path` to point to the directory where you placed the images.
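To see what the pipeline does to each image, here is a stdlib-only sketch of the arithmetic. It assumes that `Resize` with a single int scales the shorter side to that value and keeps the aspect ratio (truncating the longer side to an int), that `CenterCrop` then cuts out a square, and that `Normalize` with mean 0.5 and std 0.5 computes `(x - 0.5) / 0.5` per channel. The helper names are hypothetical:

```python
def resized_size(width, height, size):
    """Approximate output (width, height) of Resize(size) for an int size."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def normalize(x, mean=0.5, std=0.5):
    """Per-channel Normalize: maps ToTensor's [0, 1] range to [-1, 1]."""
    return (x - mean) / std


# A 640x480 photo with img_size=64: the shorter side (480) becomes 64
w, h = resized_size(640, 480, 64)
print(w, h)  # 85 64
# CenterCrop(64) then keeps the central 64x64 region

print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
```

So every image ends up as an `img_size` x `img_size` tensor with values in [-1, 1], the same range a `tanh` output layer produces.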
You will also have to change:
```python
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
```
to:
```python
parser.add_argument('--channels', type=int, default=3, help='number of image channels')
```
to produce RGB images instead of grayscale.
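Since `--channels` is an ordinary argparse option, you could also leave the default alone and pass the value on the command line (`python wgan_gp.py --channels 3`). The sketch below reproduces only that one argument to show the override; the rest of the script's parser is omitted:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--channels', type=int, default=1, help='number of image channels')

# Equivalent to running: python wgan_gp.py --channels 3
opt = parser.parse_args(['--channels', '3'])
print(opt.channels)  # 3

# Without the flag, the default (grayscale) is used
print(parser.parse_args([]).channels)  # 1
```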
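To iterate through the images, you will have to change the training loop to:
```python
for epoch in range(opt.n_epochs):
    # ImageDataset returns images only, so there is no label to unpack
    for i, imgs in enumerate(dataloader):
        ...  # rest of the loop body unchanged
```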
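Alternatively, if you would rather leave the loop's `(imgs, labels)`-style unpacking untouched, one option is to have each item carry a placeholder label. A stdlib-only sketch of that idea, with plain strings standing in for image tensors and a hypothetical `LabeledWrapper` class:

```python
class LabeledWrapper:
    """Wrap an image-only dataset so each item is (image, dummy_label)."""

    def __init__(self, dataset, dummy_label=0):
        self.dataset = dataset
        self.dummy_label = dummy_label

    def __getitem__(self, index):
        return self.dataset[index], self.dummy_label

    def __len__(self):
        return len(self.dataset)


images = ["img0", "img1", "img2"]  # stand-ins for image tensors
wrapped = LabeledWrapper(images)

for i in range(len(wrapped)):
    imgs, _ = wrapped[i]  # same unpacking pattern as `for i, (imgs, _) in ...`
    print(i, imgs)
```

With a real `ImageDataset`, a wrapper like this would subclass `torch.utils.data.Dataset`, and the DataLoader's default collate function would batch the integer labels into a tensor.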
@jgmeyerucsd did you manage to solve that issue?
@eriklindernoren you are so great! Thank you very much!
I tried to use this fix with a directory containing over 20,000 .png images according to:
```python
folder_path = "./molecules"
transforms_ = [transforms.Resize(300), Image.BICUBIC,
               transforms.CenterCrop(300),
               transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
dataloader = DataLoader(ImageDataset(folder_path, transforms_=transforms_),
                        batch_size=64, shuffle=True, num_workers=8)
```
and I get the following error:
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-19-c503003ae828> in <module>()
      1 batches_done = 0
      2 for epoch in range(200):
----> 3 for i, imgs in enumerate(dataloader):
      4
      5 # Configure input

~/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    334 self.reorder_dict[idx] = batch
    335 continue
--> 336 return self._process_next_batch(batch)
    337
    338 next = __next__ # Python 2 compatibility

~/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_next_batch(self, batch)
    355 self._put_indices()
    356 if isinstance(batch, ExceptionWrapper):
--> 357 raise batch.exc_type(batch.exc_msg)
    358 return batch
    359

TypeError: Traceback (most recent call last):
  File "/home/jgmeyer2/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/jgmeyer2/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "<ipython-input-5-3e5f53181d14>", line 17, in __getitem__
    img = self.transform(img)
  File "/home/jgmeyer2/anaconda3/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 49, in __call__
    img = t(img)
TypeError: 'int' object is not callable
```
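The error can be reproduced without torchvision. As the last frame of the traceback shows (`img = t(img)`), `Compose` calls each element of its list in turn, so a bare integer in the list fails as soon as it is reached. A minimal stand-in, where `compose` is a hypothetical simplification of `transforms.Compose` and `3` plays the role of `Image.BICUBIC` (its value in Pillow):

```python
def compose(transforms_):
    """Simplified stand-in for transforms.Compose: apply each t in order."""
    def apply(img):
        for t in transforms_:
            img = t(img)  # the same call that fails in the traceback
        return img
    return apply


BICUBIC = 3  # value of PIL.Image.BICUBIC

broken = compose([str.strip, BICUBIC, str.upper])
try:
    broken("  molecule ")
except TypeError as err:
    print(err)  # 'int' object is not callable

fixed = compose([str.strip, str.upper])
print(fixed("  molecule "))  # MOLECULE
```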
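> @jgmeyerucsd did you manage to solve that issue?

Try removing `Image.BICUBIC,` from:
```python
transforms_ = [transforms.Resize(300), Image.BICUBIC,
               transforms.CenterCrop(300),
               transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
```
`Image.BICUBIC` is an integer constant, not a transform, so `transforms.Compose` fails when it tries to call it (hence `'int' object is not callable`). If you want bicubic resampling, pass it to `Resize` instead: `transforms.Resize(300, Image.BICUBIC)` (newer torchvision versions prefer `transforms.InterpolationMode.BICUBIC`).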
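A mistake like this only surfaces when the first batch is loaded (inside a DataLoader worker here, which makes the traceback harder to read). A small check when building the list catches it earlier with a clearer message; `check_transforms` is a hypothetical helper, not part of torchvision:

```python
def check_transforms(transforms_):
    """Raise TypeError early if any element of a transform list is not callable."""
    for position, t in enumerate(transforms_):
        if not callable(t):
            raise TypeError(
                f"transforms_[{position}] is {t!r} ({type(t).__name__}), "
                "not a callable transform; did a Resize argument end up "
                "outside the parentheses?"
            )
    return transforms_


check_transforms([str.strip, str.upper])  # passes silently
try:
    check_transforms([str.strip, 3, str.upper])  # 3 plays the role of Image.BICUBIC
except TypeError as err:
    print(err)
```

Usage: `ImageDataset(folder_path, transforms_=check_transforms(transforms_))`.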