
pytorch-dann's Introduction


FixBi CoVi Dual-teacher

pytorch-dann's People

Contributors

najaemin92


pytorch-dann's Issues

Why doesn't this model's loss use the target dataset's labels?

1. Classification loss

class_pred = classifier(source_feature)
class_loss = classifier_criterion(class_pred, source_label)

2. Domain loss

domain_pred = discriminator(combined_feature, alpha)

# Domain labels: 0 for every source sample, 1 for every target sample.
# target_label is used only for its batch size; its values never enter the loss.
domain_source_labels = torch.zeros(source_label.shape[0]).type(torch.LongTensor)
domain_target_labels = torch.ones(target_label.shape[0]).type(torch.LongTensor)
domain_combined_label = torch.cat((domain_source_labels, domain_target_labels), 0).cuda()
domain_loss = discriminator_criterion(domain_pred, domain_combined_label)

total_loss = class_loss + domain_loss
total_loss.backward()
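In DANN, the alpha passed to discriminator(combined_feature, alpha) is typically the weight of a gradient reversal layer placed in front of the domain classifier: features pass through unchanged, but the gradient flowing back into the encoder is multiplied by -alpha, so the encoder learns features the discriminator cannot separate. The sketch below assumes that is what this repository's discriminator does; GradientReversal and grad_reverse are hypothetical names, not the repository's code.

```python
import torch
from torch.autograd import Function


class GradientReversal(Function):
    """Hypothetical gradient reversal layer: identity forward, -alpha * grad backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)  # features are passed through unchanged

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: x gets the reversed, scaled gradient,
        # alpha is a plain float and gets none.
        return grad_output.neg() * ctx.alpha, None


def grad_reverse(x, alpha=1.0):
    return GradientReversal.apply(x, alpha)


features = torch.ones(4, 3, requires_grad=True)
grad_reverse(features, alpha=0.5).sum().backward()
print(features.grad)  # every entry is -0.5
```

A discriminator built this way would call grad_reverse on its input before its first layer, so a single backward pass on total_loss trains the discriminator normally while pushing the encoder in the opposite direction.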

I've seen many implementations of DANN, but every one of them computes the loss from just three parts.
Could you tell me why this model doesn't take the loss on the target dataset's labels into account?
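Assuming the "three parts" are the source classification loss plus separate source and target domain losses, that form differs from the two-part version above only in how the domain term is averaged. In neither form do target class labels appear: DANN addresses unsupervised domain adaptation, where the target set is treated as unlabeled and target samples contribute only through the domain loss (domain label 1). A minimal check with made-up discriminator logits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch = 8
criterion = nn.CrossEntropyLoss()  # default reduction="mean"

# Made-up discriminator logits for one source batch and one target batch.
source_domain_logits = torch.randn(batch, 2)
target_domain_logits = torch.randn(batch, 2)
source_domain_labels = torch.zeros(batch, dtype=torch.long)  # domain 0 = source
target_domain_labels = torch.ones(batch, dtype=torch.long)   # domain 1 = target

# Two-part style (as in the issue): one domain loss over the concatenated batch.
combined = criterion(
    torch.cat((source_domain_logits, target_domain_logits), 0),
    torch.cat((source_domain_labels, target_domain_labels), 0),
)

# Three-part style: separate source and target domain losses, summed.
separate = criterion(source_domain_logits, source_domain_labels) + criterion(
    target_domain_logits, target_domain_labels
)

# With equal batch sizes and mean reduction, they differ only by a factor of 2.
print(torch.allclose(separate, 2 * combined))  # True
```

So the two styles optimize the same objective up to the weight on the domain term, which can be absorbed into the learning rate or a loss coefficient.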

Regarding softmax layer

The final softmax layer makes the models hard to train. If you face the same problem, try removing the last softmax layer from the Classifier and the Discriminator.
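A likely reason, assuming the criteria are nn.CrossEntropyLoss (the names classifier_criterion and discriminator_criterion suggest a classification loss, but the posts do not say which): that loss applies log-softmax to its input itself, so a softmax layer at the end of the model means softmax is applied twice. Because probabilities lie in [0, 1], the second softmax can never become sharp, and the loss stays far from zero even for a perfect prediction. The sketch below uses made-up logits for 10 classes.

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # applies log_softmax internally, expects raw logits
num_classes = 10
target = torch.tensor([3])

# A very confident, correct prediction expressed as raw logits.
logits = torch.full((1, num_classes), -20.0)
logits[0, 3] = 20.0

loss_on_logits = criterion(logits, target)                       # essentially 0
loss_on_probs = criterion(torch.softmax(logits, dim=1), target)  # softmax applied twice

# With a one-hot input, the second softmax gives the correct class only
# e / (e + 9), so the loss cannot drop below log(e + 9) - 1.
floor = math.log(math.e + num_classes - 1) - 1
print(loss_on_logits.item(), loss_on_probs.item(), floor)
```

Removing the softmax layers and keeping nn.CrossEntropyLoss restores the usual behavior; if probabilities are needed at inference time, torch.softmax can be applied to the logits outside the training loss.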

RuntimeError: Output shape doesn't match broadcast shape.

Setup:

$ conda list
# packages in environment at C:\Users\$username\.conda\envs\pytorch1:
#
# Name                    Version                   Build  Channel
blas                      1.0                         mkl
brotli                    1.0.9                ha925a31_2
brotlipy                  0.7.0           py39h2bbff1b_1003
ca-certificates           2022.3.29            haa95532_0
certifi                   2021.10.8        py39haa95532_2
cffi                      1.15.0           py39h2bbff1b_1
charset-normalizer        2.0.4              pyhd3eb1b0_0
cryptography              36.0.0           py39h21b164f_0
cudatoolkit               11.3.1               h59b6b97_2
cycler                    0.11.0             pyhd3eb1b0_0
fonttools                 4.25.0             pyhd3eb1b0_0
freetype                  2.10.4               hd328e21_0
icc_rt                    2019.0.0             h0cc432a_1
icu                       58.2                 ha925a31_3
idna                      3.3                pyhd3eb1b0_0
intel-openmp              2021.4.0          haa95532_3556
joblib                    1.1.0              pyhd3eb1b0_0
jpeg                      9d                   h2bbff1b_0
kiwisolver                1.3.2            py39hd77b12b_0
libpng                    1.6.37               h2a8f88b_0
libtiff                   4.2.0                hd0e1b90_0
libuv                     1.40.0               he774522_0
libwebp                   1.2.2                h2bbff1b_0
lz4-c                     1.9.3                h2bbff1b_1
matplotlib                3.5.1            py39haa95532_1
matplotlib-base           3.5.1            py39hd77b12b_1
mkl                       2021.4.0           haa95532_640
mkl-service               2.4.0            py39h2bbff1b_0
mkl_fft                   1.3.1            py39h277e83a_0
mkl_random                1.2.2            py39hf11a4ad_0
munkres                   1.1.4                      py_0
numpy                     1.21.5           py39ha4e8547_0
numpy-base                1.21.5           py39hc2deb75_0
openssl                   1.1.1n               h2bbff1b_0
packaging                 21.3               pyhd3eb1b0_0
pillow                    9.0.1            py39hdc2b20a_0
pip                       21.2.4           py39haa95532_0
pycparser                 2.21               pyhd3eb1b0_0
pyopenssl                 22.0.0             pyhd3eb1b0_0
pyparsing                 3.0.4              pyhd3eb1b0_0
pyqt                      5.9.2            py39hd77b12b_6
pysocks                   1.7.1            py39haa95532_0
python                    3.9.11               h6244533_1
python-dateutil           2.8.2              pyhd3eb1b0_0
pytorch                   1.11.0          py3.9_cuda11.3_cudnn8_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
qt                        5.9.7            vc14h73c81de_0
requests                  2.27.1             pyhd3eb1b0_0
scikit-learn              1.0.2            py39hf11a4ad_1
scipy                     1.7.3            py39h0a974cb_0
setuptools                58.0.4           py39haa95532_0
sip                       4.19.13          py39hd77b12b_0
six                       1.16.0             pyhd3eb1b0_1
sqlite                    3.38.0               h2bbff1b_0
threadpoolctl             2.2.0              pyh0d69192_0
tk                        8.6.11               h2bbff1b_0
torchaudio                0.11.0               py39_cu113    pytorch
torchvision               0.12.0               py39_cu113    pytorch
tornado                   6.1              py39h2bbff1b_0
typing_extensions         4.1.1              pyh06a4308_0
tzdata                    2021e                hda174b7_0
urllib3                   1.26.8             pyhd3eb1b0_0
vc                        14.2                 h21ff451_1
vs2015_runtime            14.27.29016          h5e58377_2
wheel                     0.37.1             pyhd3eb1b0_0
win_inet_pton             1.1.0            py39haa95532_0
wincertstore              0.2              py39haa95532_2
xz                        5.2.5                h62dcd97_0
zlib                      1.2.11               hbd8134f_5
zstd                      1.4.9                h19a0ad4_0

Error received:

Running GPU : 0
Source-only training
Epoch : 0
Traceback (most recent call last):
  File "E:\GitProjects\PyTorchTests\MNIST-m_DANN\pytorch_DANN\main.py", line 30, in <module>
    main()
  File "E:\GitProjects\PyTorchTests\MNIST-m_DANN\pytorch_DANN\main.py", line 22, in main
    train.source_only(encoder, classifier, source_train_loader, target_train_loader, save_name)
  File "E:\GitProjects\PyTorchTests\MNIST-m_DANN\pytorch_DANN\train.py", line 34, in source_only
    for batch_idx, (source_data, target_data) in enumerate(zip(source_train_loader, target_train_loader)):
  File "C:\Users\$username\.conda\envs\pytorch1\lib\site-packages\torch\utils\data\dataloader.py", line 530, in __next__
    data = self._next_data()
  File "C:\Users\$username\.conda\envs\pytorch1\lib\site-packages\torch\utils\data\dataloader.py", line 570, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\$username\.conda\envs\pytorch1\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\$username\.conda\envs\pytorch1\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\$username\.conda\envs\pytorch1\lib\site-packages\torchvision\datasets\mnist.py", line 145, in __getitem__
    img = self.transform(img)
  File "C:\Users\$username\.conda\envs\pytorch1\lib\site-packages\torchvision\transforms\transforms.py", line 95, in __call__
    img = t(img)
  File "C:\Users\$username\.conda\envs\pytorch1\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\$username\.conda\envs\pytorch1\lib\site-packages\torchvision\transforms\transforms.py", line 270, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "C:\Users\$username\.conda\envs\pytorch1\lib\site-packages\torchvision\transforms\functional.py", line 363, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
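The last frame of the traceback shows the cause: the MNIST digit comes out of ToTensor() as a one-channel tensor of shape [1, 28, 28], while Normalize was given three-channel mean and std, and the in-place sub_ cannot broadcast its own output up to [3, 28, 28]. The sketch below reproduces this with plain tensors; the 0.5 statistics and the helper names normalize_ and fails_without_fix are placeholders, not the repository's values or code.

```python
import torch

# Placeholder per-channel statistics for a 3-channel pipeline, shaped [3, 1, 1]
# so they broadcast over height and width.
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)


def normalize_(img: torch.Tensor) -> torch.Tensor:
    """In-place normalization, the same operation as the last frame of the traceback."""
    return img.sub_(mean).div_(std)


def fails_without_fix() -> bool:
    gray = torch.rand(1, 28, 28)  # an MNIST digit after ToTensor(): one channel
    try:
        normalize_(gray)
    except RuntimeError:
        return True  # in-place ops cannot grow [1, 28, 28] to [3, 28, 28]
    return False


# Fix: give the image three channels before normalizing.
gray = torch.rand(1, 28, 28)
rgb = gray.repeat(3, 1, 1)  # copy the single channel three times
out = normalize_(rgb)
print(fails_without_fix(), tuple(out.shape))  # True (3, 28, 28)
```

In the torchvision transform used for MNIST, the equivalent change is usually to add transforms.Grayscale(num_output_channels=3) before transforms.ToTensor(), or else to pass one-channel statistics to transforms.Normalize. Since MNIST-M images are RGB, the three-channel route keeps the source and target inputs the same shape for a shared encoder.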
