One useful trick for data cleaning is to take a trained Discriminator and use it to find the 'worst' samples, which can then be reviewed by hand or deleted automatically.
I modified the training script to support this by adding a mode that ranks a folder dataset with a trained D model. The default dataloader doesn't preserve image paths, so an additional dataset class has to be added:
diff --git a/data_loader.py b/data_loader.py
index 736362c..c4528a8 100755
--- a/data_loader.py
+++ b/data_loader.py
@@ -42,9 +42,15 @@ class Data_Loader():
def load_off(self):
transforms = self.transform(True, True, True, False)
dataset = dsets.ImageFolder(self.path, transform=transforms)
return dataset
+    def load_rank(self):
+        transforms = self.transform(True, True, True, False)
+        dataset = ImageFolderWithPaths(self.path, transform=transforms)
+        return dataset
+
def loader(self):
if self.dataset == 'lsun':
dataset = self.load_lsun()
@@ -54,6 +60,8 @@ class Data_Loader():
dataset = self.load_celeb()
elif self.dataset == 'off':
dataset = self.load_off()
+        elif self.dataset == 'rank':
+            dataset = self.load_rank()
print('dataset',len(dataset))
loader = torch.utils.data.DataLoader(dataset=dataset,
@@ -63,3 +71,18 @@ class Data_Loader():
drop_last=True)
return loader
+
+class ImageFolderWithPaths(dsets.ImageFolder):
+    """Custom dataset that includes image file paths. Extends
+    torchvision.datasets.ImageFolder.
+    """
+
+    # override the __getitem__ method; this is the method the DataLoader calls
+    def __getitem__(self, index):
+        # this is what ImageFolder normally returns: (image, label)
+        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
+        # the image file path
+        path = self.imgs[index][0]
+        # make a new tuple that includes the original tuple and the path
+        tuple_with_path = original_tuple + (path,)
+        return tuple_with_path
diff --git a/parameter.py b/parameter.py
index 0b59c5a..1191208 100755
--- a/parameter.py
+++ b/parameter.py
@@ -37,7 +37,7 @@ def get_parameters():
parser.add_argument('--train', type=str2bool, default=True)
parser.add_argument('--parallel', type=str2bool, default=False)
parser.add_argument('--gpus', type=str, default='0', help='gpuids eg: 0,1,2,3 --parallel True ')
- parser.add_argument('--dataset', type=str, default='lsun', choices=['lsun', 'celeb','off'])
+    parser.add_argument('--dataset', type=str, default='lsun', choices=['lsun', 'celeb', 'off', 'rank'])
parser.add_argument('--use_tensorboard', type=str2bool, default=False)
# Path
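The ranking itself is done with a stripped-down copy of the training script: it builds the Discriminator, loads a pretrained D checkpoint, runs every batch through it, and prints the D output next to each image path: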
import torch
import torch.nn as nn
from model_resnet import Discriminator
# from utils import *
from parameter import *
from data_loader import Data_Loader
from torch.backends import cudnn
import os


class Trainer(object):
    def __init__(self, data_loader, config):
        # Data loaders
        self.data_loader = data_loader

        # exact model and loss
        self.model = config.model

        # Model hyper-parameters
        self.imsize = config.imsize
        self.parallel = config.parallel
        self.gpus = config.gpus
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.pretrained_model = config.pretrained_model
        self.dataset = config.dataset
        self.image_path = config.image_path
        self.version = config.version
        self.n_class = 1000  # config.n_class TODO
        self.chn = config.chn

        # Path
        self.model_save_path = os.path.join(config.model_save_path, self.version)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.build_model()
        # Start with trained model
        self.load_pretrained_model()
        self.train()

    def train(self):
        self.D.train()
        # Data iterator
        data_iter = iter(self.data_loader)
        total_steps = len(self.data_loader)
        for step in range(0, total_steps):
            real_images, real_labels, real_paths = next(data_iter)
            real_labels = real_labels.to(self.device)
            real_images = real_images.to(self.device)
            d_out_real = self.D(real_images, real_labels)
            rankings = d_out_real.data.tolist()
            for i in range(0, len(real_paths)):
                print(real_paths[i], rankings[i])

    def load_pretrained_model(self):
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}_D.pth'.format(self.pretrained_model))))

    def build_model(self):
        # code_dim=100, n_class=1000
        self.D = Discriminator(self.n_class, chn=self.chn).to(self.device)
        if self.parallel:
            gpus = [int(i) for i in self.gpus.split(',')]
            self.D = nn.DataParallel(self.D, device_ids=gpus)


def main(config):
    # For fast training
    cudnn.benchmark = True

    # Data loader
    data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                              config.batch_size, shuf=False)

    Trainer(data_loader.loader(), config)


if __name__ == '__main__':
    config = get_parameters()
    # print(config)
    main(config)
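The printed path/score pairs can then be sorted to pull out the most suspicious images for manual review or deletion. A minimal sketch, assuming the script's stdout was redirected to a hypothetical ranks.txt and that lower D outputs correspond to less realistic-looking samples:

import os
import shutil

# Parse the "path score" lines printed by the ranking script
# ('ranks.txt' is a hypothetical filename for the redirected stdout).
scored = []
with open('ranks.txt') as f:
    for line in f:
        if not line.strip():
            continue
        path, score = line.rsplit(None, 1)
        scored.append((float(score.strip('[]')), path))

# Sort ascending, on the assumption that a lower D output means a
# less realistic-looking (i.e. worse) sample.
scored.sort()

# Move the 500 lowest-ranked images aside for review rather than deleting
# them outright (assumes unique basenames within the folder).
os.makedirs('review', exist_ok=True)
for score, path in scored[:500]:
    shutil.move(path, 'review/')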
Some cleaner built-in support for this kind of ranking would be good.
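In the meantime, the pattern is small enough to wrap in a standalone helper instead of abusing the trainer. A hedged sketch only; the rank_folder name, the 128px transform, and the discriminator(images, labels) -> per-image-score signature are assumptions, not part of the repo:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def rank_folder(discriminator, folder, device, imsize=128, batch_size=64):
    """Hypothetical helper: score every image in an ImageFolder with a trained D
    and return (score, path) pairs sorted ascending (presumed-worst first).
    Assumes discriminator(images, labels) returns one score per image."""
    tf = transforms.Compose([
        transforms.Resize(imsize),
        transforms.CenterCrop(imsize),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = datasets.ImageFolder(folder, transform=tf)
    # shuffle=False keeps batch order aligned with dataset.imgs
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    paths = [path for path, _ in dataset.imgs]

    discriminator.eval()
    scores = []
    with torch.no_grad():
        for images, labels in loader:
            out = discriminator(images.to(device), labels.to(device))
            scores.extend(out.view(-1).tolist())
    return sorted(zip(scores, paths))

The worst few hundred entries of the returned list can then be reviewed by hand, without needing the path-preserving dataset subclass or the fake training loop.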