triplet-network-pytorch's Introduction

A PyTorch Implementation for Triplet Networks

This repository contains a PyTorch implementation for triplet networks.

The code provides two different ways to load triplets for the network. First, it contains a simple MNIST Loader that generates triplets from the MNIST class labels. Second, it provides a Triplet Loader that loads images from folders, given a list of triplets.
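As an illustration of the first approach, here is a minimal sketch of sampling triplets from class labels. It is not the repository's loader: the function name make_triplets, its parameters, and the (anchor, close, far) ordering are assumptions made for this example.

```python
import random
from collections import defaultdict


def make_triplets(labels, n_triplets, seed=0):
    """Sample (anchor, close, far) index triplets from a list of class labels.

    Hypothetical helper: 'close' shares the anchor's class, 'far' does not.
    """
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # An anchor class needs at least two samples to supply a distinct 'close'.
    anchor_classes = [c for c, idxs in by_class.items() if len(idxs) >= 2]
    all_classes = list(by_class)
    if not anchor_classes or len(all_classes) < 2:
        raise ValueError("need two classes, one of them with two or more samples")

    triplets = []
    for _ in range(n_triplets):
        pos_class = rng.choice(anchor_classes)
        neg_class = rng.choice([c for c in all_classes if c != pos_class])
        anchor, close = rng.sample(by_class[pos_class], 2)
        far = rng.choice(by_class[neg_class])
        triplets.append((anchor, close, far))
    return triplets


labels = [0, 0, 1, 1, 2, 2]
for anchor, close, far in make_triplets(labels, n_triplets=3):
    print(anchor, close, far)
```

The order of the three indices has to match what the network and the loss expect, so check it against the loader you actually use.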

Example usage:

$ python train.py

Tracking experiments with Visdom

This repository lets you track experiments with Visdom. You can use the VisdomLinePlotter to plot training progress.

If this implementation is useful to you and your project, please consider citing or acknowledging this code repository.

triplet-network-pytorch's People

Contributors

andreasveit

triplet-network-pytorch's Issues

BrokenPipeError: [Errno 32] Broken pipe

Hi, I am using PyTorch to run a triplet network on a GPU, but whenever I load data I get this error: BrokenPipeError: [Errno 32] Broken pipe.

I think the problem may be in the following code:

for batch_idx, (data1, data2, data3) in enumerate(test_loader):
    if args.cuda:
        data1, data2, data3 = data1.cuda(), data2.cuda(), data3.cuda()
    data1, data2, data3 = Variable(data1), Variable(data2), Variable(data3)

Can you give me some suggestions? Thank you so much.

A bug when trying to modify your code

I'm new to PyTorch. When I add transforms.CenterCrop(size) to the transforms for my data loader, this error occurs:

  File "~/triplet-network-pytorch/triplet_data_loader.py", line 73, in __getitem__
    img1 = self.transform(img1)
  File "build/bdist.linux-x86_64/egg/torchvision/transforms.py", line 29, in __call__
    img = t(img)
  File "build/bdist.linux-x86_64/egg/torchvision/transforms.py", line 156, in __call__
    w, h = img.size
TypeError: 'builtin_function_or_method' object is not iterable

If I add transforms.CenterCrop(28) to MNIST_t(*),

train_loader = torch.utils.data.DataLoader(
        MNIST_t('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                           # transforms.CenterCrop(28)
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)

the same error occurs. Could you please help me solve it?

Data Load

Hi!

Thank you so much for the code.

I have a dataset, for example Dataset A, with 50 classes (50 folders). Each folder contains images of a single class. How can I load this data?
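One possible way to prepare such a folder-per-class dataset for a triplet list loader is sketched below. It assumes the loader wants two text files: one image path per line, and one triplet of integer indices per line (the parsing line quoted in the "list index out of range" issue suggests the latter). The function names, file layout, and extension set are hypothetical, not taken from the repository.

```python
import random
from pathlib import Path

# Assumed set of image extensions; adjust to the dataset at hand.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def index_class_folders(root):
    """Return (filenames, labels) for a root with one sub-folder per class."""
    filenames, labels = [], []
    class_dirs = sorted(p for p in Path(root).iterdir() if p.is_dir())
    for label, class_dir in enumerate(class_dirs):
        for path in sorted(class_dir.iterdir()):
            if path.suffix.lower() in IMAGE_EXTENSIONS:
                filenames.append(str(path.relative_to(root)))
                labels.append(label)
    return filenames, labels


def write_triplet_files(root, filenames_out, triplets_out, n_triplets, seed=0):
    """Write a filename list and a list of 'anchor close far' index triplets."""
    filenames, labels = index_class_folders(root)
    rng = random.Random(seed)

    # Group image indices by class.
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)
    anchor_classes = [c for c, idxs in by_class.items() if len(idxs) >= 2]
    if not anchor_classes or len(by_class) < 2:
        raise ValueError("need two classes, one of them with two or more images")

    with open(filenames_out, "w") as f:
        f.write("\n".join(filenames) + "\n")

    with open(triplets_out, "w") as f:
        for _ in range(n_triplets):
            pos_class = rng.choice(anchor_classes)
            neg_class = rng.choice([c for c in by_class if c != pos_class])
            anchor, close = rng.sample(by_class[pos_class], 2)
            far = rng.choice(by_class[neg_class])
            f.write(f"{anchor} {close} {far}\n")
    return len(filenames)


# Example call (paths are placeholders):
# write_triplet_files("DatasetA", "filenames.txt", "triplets.txt", n_triplets=10000)
```

The column order written here (anchor, close, far) follows the comment in the quoted parsing line; confirm it against the loader before training.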

list index out of range

Hi, I am getting a new error. How can I handle it?

  line 47, in __init__
    triplets.append((int(line.split()[0]), int(line.split()[1]), int(line.split()[2]))) # anchor, close, far
IndexError: list index out of range
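The quoted line raises IndexError whenever line.split() yields fewer than three fields, which can happen with a blank line or with a file that is not whitespace-separated (both are guesses about this particular file). A more defensive parser, sketched here with the hypothetical name read_triplets, skips blank lines and reports the offending line instead:

```python
def read_triplets(path):
    """Read whitespace-separated integer triplets, one per line."""
    triplets = []
    with open(path) as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Skip blank lines, such as a trailing empty line.
                continue
            if len(fields) < 3:
                raise ValueError(
                    f"{path}:{line_no}: expected 3 indices, "
                    f"got {len(fields)}: {line!r}"
                )
            anchor, close, far = (int(x) for x in fields[:3])
            triplets.append((anchor, close, far))
    return triplets
```

If the ValueError points at a line such as '1,2,3', the file uses a different delimiter than the parser expects.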

Why should dista be larger than distb?

Excuse me, I have a question: why should dista be larger than distb in the train() function of train.py?
I would expect the distance between ref and pos to be smaller than the distance between ref and neg.
Thanks very much!
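Assuming the training loop uses a margin ranking loss of the form documented for torch.nn.MarginRankingLoss, max(0, -y * (x1 - x2) + margin), a target of y = 1 rewards x1 being larger than x2. Whether that is the intended direction depends on which distance is passed first: it is consistent if dista is the anchor-to-far distance and distb the anchor-to-close distance, which in turn depends on the order of the triplets. The plain-Python function below only illustrates the arithmetic and is not the repository's code.

```python
def margin_ranking_loss(x1, x2, y, margin=0.0):
    """Per-sample margin ranking loss: max(0, -y * (x1 - x2) + margin)."""
    return max(0.0, -y * (x1 - x2) + margin)


# y = 1: zero loss once x1 exceeds x2 by at least the margin.
satisfied = margin_ranking_loss(x1=1.5, x2=0.5, y=1, margin=0.2)
# y = 1: positive loss when x1 is smaller than x2.
violated = margin_ranking_loss(x1=0.5, x2=1.5, y=1, margin=0.2)

print(satisfied)      # 0.0
print(violated > 0)   # True
```

So with y = 1 the first argument is the distance that training pushes up and the second the one it pushes down.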

Invalid Index of 0-dim tensor

losses.update(loss_triplet.data[0], data1.size(0))

IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
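As the error message suggests, a 0-dim loss tensor is read with .item() rather than by indexing, so the failing call would become losses.update(loss_triplet.item(), data1.size(0)). The small helper below is a hypothetical convenience, and FakeScalarTensor is only a stand-in so the sketch runs without PyTorch installed.

```python
def scalar_value(loss):
    """Return a Python number from a 0-dim tensor; pass plain numbers through."""
    return loss.item() if hasattr(loss, "item") else loss


class FakeScalarTensor:
    """Stand-in for a 0-dim tensor, used only to make this sketch runnable."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


print(scalar_value(FakeScalarTensor(0.25)))  # 0.25
print(scalar_value(0.5))                     # 0.5
```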

No Online Mining

Hi! Great effort on the implementation.
One shortcoming is that it currently uses offline mining.
We could make the training process more robust by adding online mining to the DataLoader.

I would be happy to help with that.

Cheers! :)

Accuracy is always zero

[Screenshot, 2018-06-21 1:55:46 PM]

Is there anything I am missing? I just cloned the repository and ran train.py without any changes. The accuracy is zero, unlike the one shown in the README.

TypeError: a bytes-like object is required, not 'str'

Processing Triplet Generation ...
Traceback (most recent call last):
  File "train.py", line 257, in <module>
    main()
  File "train.py", line 61, in main
    transforms.Normalize((0.1307,), (0.3081,))
  File "/home/qwe/triplet-network-pytorch/triplet_mnist_loader.py", line 44, in __init__
    self.make_triplet_list(n_train_triplets)
  File "/home/qwe/triplet-network-pytorch/triplet_mnist_loader.py", line 164, in make_triplet_list
    writer.writerows(triplets)
TypeError: a bytes-like object is required, not 'str'
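This TypeError from writer.writerows is what Python 3 raises when a csv writer is given a file opened in binary mode, so the file was most likely opened with "wb" (an inference from the traceback, not something the report shows). In Python 3 the csv module expects a text-mode file opened with newline="". A minimal sketch, with a hypothetical function name and a space delimiter assumed to match the whitespace-split triplet format:

```python
import csv


def write_triplets_csv(path, triplets, delimiter=" "):
    """Write one triplet per row using Python 3's csv module."""
    # Text mode ("w"), not binary ("wb"): csv writes str objects in Python 3.
    # newline="" lets the csv module control line endings itself.
    with open(path, "w", newline="") as f:
        writer = csv.writer(f, delimiter=delimiter)
        writer.writerows(triplets)


# Example call (the path is a placeholder):
# write_triplets_csv("triplets.txt", [(0, 1, 2), (3, 4, 5)])
```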

Acc = 0

I finally found what was wrong: the definition of accuracy does not seem to be compatible with my Python version.
