
Comments (5)

eriklindernoren commented on July 27, 2024

Hi,

You would have to save a lot of images of yourself to a folder and use a dataset loader like this one:

import glob

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    """Loads every file in folder_path as an RGB image."""
    def __init__(self, folder_path, transforms_=None):
        # Compose the given list; fall back to ToTensor() if none is passed
        self.transform = transforms.Compose(transforms_ or [transforms.ToTensor()])
        self.files = sorted(glob.glob('%s/*.*' % folder_path))

    def __getitem__(self, index):
        # Convert to RGB so grayscale or RGBA files (e.g. PNGs with alpha)
        # match the 3-channel Normalize below
        img = Image.open(self.files[index % len(self.files)]).convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.files)
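The pattern '%s/*.*' matches every file with a dot in its name, so a stray non-image file (a notes.txt, say) in the folder would make Image.open raise an error in the middle of training. A minimal sketch of a stricter file list, assuming you only want a few common image extensions; list_image_files and the extension set are hypothetical, not part of PyTorch-GAN:

```python
import glob
import os

# Hypothetical whitelist; adjust it to the formats you actually store.
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp"}

def list_image_files(folder_path, exts=IMAGE_EXTS):
    """Return a sorted list of files in folder_path whose extension is in exts."""
    files = glob.glob(os.path.join(folder_path, "*.*"))
    # Compare extensions case-insensitively so "IMG_01.JPG" is kept too.
    return sorted(f for f in files if os.path.splitext(f)[1].lower() in exts)
```

In ImageDataset.__init__, self.files = list_image_files(folder_path) would then replace the glob call.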

And in wgan_gp.py:

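# Dataset loader (some of these imports may already be at the top of wgan_gp.py;
# ImageDataset is the class defined above)
from PIL import Image
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

transforms_ = [ transforms.Resize(opt.img_size, Image.BICUBIC),  # interpolation is Resize's 2nd argument
                transforms.CenterCrop(opt.img_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader = DataLoader(ImageDataset(folder_path, transforms_=transforms_),
                        batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)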
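To see what this pipeline does to each picture: Resize with a single int scales the shorter side to that size and keeps the aspect ratio, CenterCrop then cuts out a square, and Normalize with mean 0.5 and std 0.5 maps ToTensor's [0, 1] pixel range to [-1, 1], which matches a tanh output layer on the generator (worth checking in your copy of the script). A pure-Python sketch of that arithmetic, written to approximately mirror torchvision's rules; the function names are hypothetical:

```python
def resized_size(width, height, size):
    """(width, height) after Resize(size) with an int size: the shorter side
    becomes `size`, the longer side keeps the aspect ratio (truncated)."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size

def center_crop_box(width, height, crop):
    """(left, top, right, bottom) box kept by CenterCrop(crop).
    Assumes both sides are >= crop, which holds right after Resize(crop)."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop

def normalize(x, mean=0.5, std=0.5):
    """Per-channel Normalize applied to one pixel value in [0, 1]."""
    return (x - mean) / std

# Example: a 640x480 photo with opt.img_size = 64
# resized_size(640, 480, 64)   -> (85, 64)
# center_crop_box(85, 64, 64)  -> (10, 0, 74, 64)
# normalize(0.0), normalize(1.0) -> -1.0, 1.0
```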

You will have to define folder_path to point to the directory where you placed the images.
You will also have to change:

parser.add_argument('--channels', type=int, default=1, help='number of image channels')

to:

parser.add_argument('--channels', type=int, default=3, help='number of image channels')

so that the models work with RGB (3-channel) images instead of grayscale ones.
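Instead of editing the default, argparse also lets you override it when launching the script, e.g. python wgan_gp.py --channels 3. A minimal sketch of that behaviour with a stripped-down parser (only two options; the img_size default of 28 is assumed here), plus the (C, H, W) image shape the models would then be built for:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension')

# Simulate running `python wgan_gp.py --channels 3 --img_size 64`.
opt = parser.parse_args(['--channels', '3', '--img_size', '64'])

# Shape of one image as the generator and critic see it: (C, H, W).
img_shape = (opt.channels, opt.img_size, opt.img_size)
print(img_shape)  # (3, 64, 64)
```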

Because ImageDataset returns only images (the MNIST loader also returns labels), you will have to change the training loop to:

for epoch in range(opt.n_epochs):
    for i, imgs in enumerate(dataloader):
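The stock MNIST loop unpacks (imgs, _) because that loader yields (image, label) pairs, while ImageDataset yields bare images, so each batch is already the image tensor. If you want one loop that works with both kinds of dataset, a small hypothetical helper can pick the images out. It assumes PyTorch's default collate, which turns a batch of (image, label) tuples into a list [images, labels] and a batch of bare tensors into one stacked tensor:

```python
def unpack_batch(batch):
    """Return only the images from a DataLoader batch (hypothetical helper).

    (image, label) datasets collate to a list/tuple whose first element is
    the image batch; image-only datasets collate to the image batch itself.
    """
    if isinstance(batch, (list, tuple)):
        return batch[0]
    return batch
```

Inside the loop you would then write for i, batch in enumerate(dataloader): followed by imgs = unpack_batch(batch).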

from pytorch-gan.

michirup commented on July 27, 2024

@jgmeyerucsd did you manage to solve that issue?


fmbao commented on July 27, 2024

@eriklindernoren you are so great! Thank you very much!


jgmeyerucsd commented on July 27, 2024

I tried to use this fix with a directory containing over 20,000 .png images, as follows:

```
folder_path = "./molecules"
transforms_ = [ transforms.Resize(300), Image.BICUBIC,
                transforms.CenterCrop(300),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader = DataLoader(ImageDataset(folder_path, transforms_=transforms_),
                        batch_size=64, shuffle=True, num_workers=8)
```

and I get the following error:


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-19-c503003ae828> in <module>()
      1 batches_done = 0
      2 for epoch in range(200):
----> 3     for i, imgs in enumerate(dataloader):
      4 
      5         # Configure input

~/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    334                 self.reorder_dict[idx] = batch
    335                 continue
--> 336             return self._process_next_batch(batch)
    337 
    338     next = __next__  # Python 2 compatibility

~/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_next_batch(self, batch)
    355         self._put_indices()
    356         if isinstance(batch, ExceptionWrapper):
--> 357             raise batch.exc_type(batch.exc_msg)
    358         return batch
    359 

TypeError: Traceback (most recent call last):
  File "/home/jgmeyer2/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/jgmeyer2/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "<ipython-input-5-3e5f53181d14>", line 17, in __getitem__
    img = self.transform(img)
  File "/home/jgmeyer2/anaconda3/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 49, in __call__
    img = t(img)
TypeError: 'int' object is not callable
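The last frame shows the mechanism: Compose calls each list element in turn (img = t(img)), and because Image.BICUBIC ended up as its own list element rather than an argument to Resize, it was called like a transform. A minimal pure-Python reproduction, using a hypothetical compose stand-in for transforms.Compose and the integer 3 in place of Image.BICUBIC (its value in the Pillow versions of that time):

```python
BICUBIC = 3  # stand-in for PIL's Image.BICUBIC constant

def compose(transforms_):
    """Simplified stand-in for torchvision's Compose: apply each element in order."""
    def apply(img):
        for t in transforms_:
            img = t(img)  # raises TypeError if t is not callable
        return img
    return apply

def resize(img):
    # hypothetical no-op transform standing in for transforms.Resize(300)
    return img

pipeline = compose([resize, BICUBIC])
try:
    pipeline("image")
except TypeError as e:
    print(e)  # 'int' object is not callable
```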



FT115 commented on July 27, 2024

@jgmeyerucsd did you manage to solve that issue?
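Try removing the "Image.BICUBIC," element from your transforms_ list. Image.BICUBIC is an integer constant, not a transform, so Compose fails when it tries to call it, which is the "'int' object is not callable" error. If you want bicubic resampling, pass it as the second argument of Resize instead:

transforms_ = [ transforms.Resize(300, Image.BICUBIC),
                transforms.CenterCrop(300),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]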
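A mistake like this only surfaces once the DataLoader workers start pulling batches. A small hypothetical check, run on the list before building the Compose, would fail earlier with a clearer message:

```python
def check_transforms(transforms_):
    """Raise early if any element of the transform list is not callable."""
    for i, t in enumerate(transforms_):
        if not callable(t):
            raise TypeError(
                "transforms_[%d] is %r (%s), not a transform; did you mean "
                "to pass it as an argument to the previous transform?"
                % (i, t, type(t).__name__)
            )
    return transforms_
```

Usage: transforms.Compose(check_transforms(transforms_)).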
