
csd's Introduction

This folder contains code accompanying the submission: Efficient Domain Generalization via Common-Specific Low-Rank Decomposition (ICML 2020).

Our proposed technique, Common-Specific Decomposition (CSD), is extremely simple and can be easily incorporated into your codebase by replacing the final layer with CSD. You may find the isolated CSD code snippets for TensorFlow and PyTorch handy.
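As a rough illustration of the "replace the final layer" idea, here is a minimal PyTorch sketch of a CSD-style classification head. This is not the authors' exact implementation: the class name, parameter names, and initialization choices below are ours. The per-domain softmax weight is a mixture of K low-rank components, with component 0 shared (common) across all domains.

```python
import torch
import torch.nn as nn

class CSDLayer(nn.Module):
    """Sketch of a CSD-style final layer: the softmax weight for each domain
    is a mixture of K components, with component 0 shared (common)."""

    def __init__(self, feat_dim, num_classes, num_domains, k=2):
        super().__init__()
        # K weight/bias components of the softmax layer
        self.sms = nn.Parameter(0.01 * torch.randn(k, feat_dim, num_classes))
        self.sm_biases = nn.Parameter(torch.zeros(k, num_classes))
        # per-domain mixing weights for the k-1 specific components
        self.embs = nn.Parameter(0.01 * torch.randn(num_domains, k - 1))
        # learnable weight on the common component
        self.common_wt = nn.Parameter(torch.tensor(1.0))

    def forward(self, feats, domains):
        """feats: [B, F] features; domains: [B] long tensor of domain ids."""
        # common logits use only component 0 (what you deploy at test time)
        logits_common = feats @ self.sms[0] + self.sm_biases[0]
        # per-example mixing coefficients over the K components
        c_wts = torch.cat(
            [self.common_wt.expand(feats.size(0), 1), self.embs[domains]],
            dim=1)                                             # [B, K]
        w = torch.einsum("bk,kfc->bfc", c_wts, self.sms)       # [B, F, C]
        b = c_wts @ self.sm_biases                             # [B, C]
        logits_specific = torch.einsum("bf,bfc->bc", feats, w) + b
        return logits_common, logits_specific
```

The training loss would then combine cross-entropy on both heads, roughly `F.cross_entropy(logits_common, y) + lmbda * F.cross_entropy(logits_specific, y)`, plus an orthogonality penalty between the components; only the common head is used at test time, where the domain is unknown.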

However, if you wish to reproduce the numbers from our paper, we strongly encourage you to use this codebase.

The folders are organized as follows:

  1. rotation: Rotation tasks on MNIST and Fashion-MNIST
  2. hw: Hand-written tasks on LipitK and Nepali Character recognition datasets
  3. pacs: PACS evaluation with ResNet18 and AlexNet
  4. speech: Speech utterance classification evaluation

To be able to run all the experiments, you need the following packages for Python 3.6:

  1. TensorFlow >1.13 and <=1.16 for the rotation, speech, and hand-written experiments
  2. PyTorch >=1.0 for PACS
  3. Keras for loading datasets in rotation tasks.

The data for the hand-written and rotation tasks is either provided or downloaded automatically by the scripts; however, the PACS and speech datasets must be downloaded and configured before running the provided code.

The following sections give more specific instructions on how to run the code.

Rotation tasks

CSD:

python main.py --dataset mnist --classifier mos

ERM:

python main.py --dataset mnist --classifier simple

Hand-written tasks

LipitK Task

python hw_train_and_test.py --dataset lipitk --num_train <number of train domains>

NHCD Task

python hw_train_and_test.py --dataset nhcd --lr 1e-4

Use the flags --simple or --cg for the ERM or CG baselines.

PACS

CSD is coded into the implementation of JigenDG, a previous domain generalization approach.

Use the following steps to run the experiments.

  1. Download and configure the PACS dataset such that all the paths in pacs/data/txt_lists are appropriate.
  2. Use pacs/run.sh to run the PACS evaluation with either AlexNet or ResNet-18.

Speech task

You need to download the Speech Commands dataset and extract it into the speech/speech_dataset/ folder. It can be downloaded from: http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz

You can then run CSD using the following command.

$ cd speech;
$ python train.py --training_percentage <num train domains> --train_dir=<checkpoints folder> --model mos2 --seed 0 --learning_rate 2e-3 --how_many_epochs 500 --lmbda 0.5 --num_uids=<2(K)>

Feel free to open an issue or write to me if you find something missing or confusing.


csd's Issues

One question

Hi, when I run

python hw_train_and_test.py --dataset lipitk --num_train 50

an error occurs:

FileNotFoundError: [Errno 2] No such file or directory: 'data/lipitk.pkl'

It seems the data for the hand-written tasks is not automatically downloaded. The code sets DATA_FLDR = "data/hpl-devnagari-iso-char-offline", but I do not see data/hpl-devnagari-iso-char-offline anywhere in the hw folder.

Error when running pacs resnet18 scripts

When I run the ResNet script using your run.sh, an IndexError occurs. Could you please rerun the script and check it?

python train_csd.py --train_all --min_scale 0.8 --max_scale 1.0 --random_horiz_flip 0.5 --jitter 0.4 --tile_random_grayscale 0.1 --source photo cartoon sketch --target art_painting --bias_whole_image 0.9 --image_size 222 --seed 0

File "train_csd.py", line 164, in <module>
  main()
File "train_csd.py", line 159, in main
  trainer.do_training()
File "train_csd.py", line 146, in do_training
  self._do_epoch()
File "train_csd.py", line 79, in _do_epoch
  for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
File "/root/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
  data = self._next_data()
File "/root/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 385, in _next_data
  data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
File "/root/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
  data = [self.dataset[idx] for idx in possibly_batched_index]
File "/root/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
  data = [self.dataset[idx] for idx in possibly_batched_index]
File "/userhome/code/DG/CSD/pacs/data/concat_dataset.py", line 48, in __getitem__
  return self.datasets[dataset_idx][sample_idx], dataset_idx
File "/userhome/code/DG/CSD/pacs/data/JigsawLoader.py", line 98, in __getitem__
  data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)]
File "/userhome/code/DG/CSD/pacs/data/JigsawLoader.py", line 98, in <listcomp>
  data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)]
IndexError: list index out of range

One help

Hi!
When I run CG on the nhcd dataset, after downloading the dataset and putting it in the directory, an error occurs:
FileNotFoundError: [Errno 2] No such file or directory: 'data/nhcd/nhcd/consonants/all_domains.txt'
Where is all_domains.txt?

About domains in code

    domains (tensor): tf tensor with domain index of dim 1 -- set to all zeros when testing
    domains = torch.nn.functional.one_hot(domains, num_domains)

How should I set this 'domains' parameter? I have num_domains = 3 in my code. If I want to deploy CSD in my code, what size should domains be? Thanks a lot.
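For illustration (the batch size and values below are made up, going only by the docstring quoted above): with num_domains = 3, domains is a 1-D long tensor of per-example domain indices, the same length as the batch, which one_hot expands to [batch_size, num_domains].

```python
import torch

# hypothetical batch of 4 training examples drawn from domains 0..2
domains = torch.tensor([0, 2, 1, 0])                           # shape [4]
one_hot = torch.nn.functional.one_hot(domains, num_classes=3)  # shape [4, 3]

# at test time the domain is unknown, so the docstring says to pass all zeros
test_domains = torch.zeros(4, dtype=torch.long)
```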

When run with ResNet-18, it reports an error

IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/jintao/anaconda3/envs/py3/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
  data = fetcher.fetch(index)
File "/home/jintao/anaconda3/envs/py3/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
  data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/jintao/anaconda3/envs/py3/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
  data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/jintao/CSD-master/pacs/data/concat_dataset.py", line 48, in __getitem__
  return self.datasets[dataset_idx][sample_idx], dataset_idx
File "/home/jintao/CSD-master/pacs/data/JigsawLoader.py", line 102, in __getitem__
  data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)]
File "/home/jintao/CSD-master/pacs/data/JigsawLoader.py", line 102, in <listcomp>
  data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)]
IndexError: list index out of range
