
Comments (4)

mehditlili commented on May 21, 2024

I now have two classes plus background. My mask images have 3 values: 0 for background, 1 for the first class and 2 for the second class. However, when I set n_classes to 2, this happens:

Starting epoch 1/1000.
/home/vision/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/upsampling.py:129: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
/home/vision/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/functional.py:1332: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
/home/vision/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/functional.py:2016: UserWarning: Using a target size (torch.Size([921600])) that is different to the input size (torch.Size([1843200])) is deprecated. Please ensure they have the same size.
  "Please ensure they have the same size.".format(target.size(), input.size()))
Traceback (most recent call last):
  File "train.py", line 139, in <module>
    img_scale=args.scale)
  File "train.py", line 80, in train_net
    loss = criterion(masks_probs_flat, true_masks_flat)
  File "/home/vision/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/vision/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 504, in forward
    return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  File "/home/vision/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/functional.py", line 2019, in binary_cross_entropy
    "!= input nelement ({})".format(target.numel(), input.numel()))
ValueError: Target and input must have the same number of elements. target nelement (921600) != input nelement (1843200)

Not sure what I should do.

from pytorch-unet.

milesial commented on May 21, 2024

You have to separate your target mask into 2 layers with values 0 or 1 (black and white), so that you end up with a target tensor with 2 channels:

  • mask == 1 (first class) -> to int values (0 or 1)
  • mask == 2 (second class) -> to int values (0 or 1)

After this processing your mask tensor should be HxWxC, with C=2 in your case.
In the error you can see that PyTorch expects twice as many values in the target mask (the input has 1843200 elements, the target only 921600); the missing values should come from these two channels.
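A minimal sketch of the split described above, using NumPy on a toy mask. The helper name `mask_to_channels`, the toy size, and the constant 0.5 "model output" are hypothetical stand-ins, not code from pytorch-unet. It stacks the channels along axis 0, i.e. (C, H, W), because PyTorch conv layers take channel-first tensors; use `np.moveaxis(target, 0, -1)` if you want the HxWxC layout mentioned above.

```python
import numpy as np

def mask_to_channels(mask, class_ids=(1, 2)):
    # One binary channel per class id: channel k is 1.0 where mask == class_ids[k].
    # Pixels with any other value (e.g. background 0) are 0.0 in every channel.
    return np.stack([(mask == cid).astype(np.float32) for cid in class_ids], axis=0)

H, W = 4, 5  # hypothetical toy size
mask = np.zeros((H, W), dtype=np.int64)
mask[0, :2] = 1   # a few pixels of the first class
mask[3, 3:] = 2   # a few pixels of the second class

# Hypothetical stand-in for a 2-channel sigmoid output of shape (2, H, W).
probs = np.full((2, H, W), 0.5, dtype=np.float32)
target = mask_to_channels(mask)

print(mask.size, probs.size)    # 20 40 -> raw mask has half the elements, as in the error
print(target.size, probs.size)  # 40 40 -> sizes agree after the split
```

`torch.from_numpy(target)` turns the result into a tensor; once flattened, it has the same number of elements as the flattened 2-channel prediction, which is what `binary_cross_entropy` checks.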


mehditlili commented on May 21, 2024

Ah, I see, so I did understand it correctly at the beginning.

It seems that I need a multichannel image with one channel per class. Can you confirm that?

Thank you


milesial commented on May 21, 2024

As I see it, with only 0 and 1 in the masks, you only have 1 class: the class with the ones. For the Kaggle dataset the class was 'car'. The zeros do not correspond to any class, but you could also see the zeros as the class 'not a car', in which case there would be 2 'classes'.

It just depends on how you define your classes and your problem. If you have two types of objects to segment, you might also have an additional id that corresponds to 'nothing'; here that id is 0.
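If you choose to treat id 0 ('nothing') as a class of its own, the same idea extends to one channel per id, background included. A minimal NumPy sketch, assuming ids run from 0 to `num_classes - 1`; the helper name `mask_to_one_hot` is hypothetical, not part of pytorch-unet:

```python
import numpy as np

def mask_to_one_hot(mask, num_classes):
    # Channel k is 1.0 where mask == k, so id 0 ('nothing') gets channel 0.
    return np.stack([(mask == k).astype(np.float32) for k in range(num_classes)], axis=0)

mask = np.array([[0, 1, 2],
                 [2, 1, 0]])
one_hot = mask_to_one_hot(mask, num_classes=3)
print(one_hot.shape)        # (3, 2, 3)
print(one_hot.sum(axis=0))  # all 1.0: each pixel falls in exactly one channel
```

With background as its own channel, every pixel has exactly one active channel. In the two-channel version, background pixels are simply 0 everywhere. Which one fits depends on how you define the problem, as described above.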
