d4m's Introduction

πŸ’Ύ D4M: Dataset Distillation via Disentangled Diffusion Model

πŸ’₯ Stellar Features

🎯 Distills datasets in an optimization-free manner.
🎯 The distillation process is architecture-free, sidestepping the cross-architecture generalization problem.
🎯 Distills large-scale datasets (ImageNet-1K) efficiently.
🎯 The distilled datasets are high-quality and versatile.

πŸ“š Introduction

Dataset distillation offers a lightweight synthetic dataset for fast network training with promising test accuracy. We advocate for designing an economical dataset distillation framework that is independent of the matching architectures. Based on empirical observations, we argue that constraining the consistency of the real and synthetic image spaces will enhance cross-architecture generalization. Motivated by this, we introduce Dataset Distillation via Disentangled Diffusion Model (D4M), an efficient framework for dataset distillation. In contrast to architecture-dependent methods, D4M employs a latent diffusion model to guarantee consistency and incorporates label information into category prototypes. The distilled datasets are versatile, eliminating the need to repeatedly generate distinct datasets for different architectures. Through comprehensive experiments, D4M demonstrates superior performance and robust generalization, surpassing the SOTA methods across most aspects.


Overview of D4M. For more details, please see our paper.
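At a high level, the disentangled flow can be sketched in three stages: encode real images of each class into the latent space, cluster those latents into a few category prototypes (one per image-per-class, IPC), and decode each prototype back into a synthetic image. The sketch below is purely illustrative: `encode`, `decode`, and the 1-D k-means are toy stand-ins, not the latent diffusion model's actual components.

```python
# Illustrative sketch of the D4M flow with toy stand-ins:
#   1) encode real images of a class into the latent space,
#   2) cluster the latents into a few category prototypes (one per IPC),
#   3) decode each prototype back into a synthetic image.
# `encode`, `decode`, and the 1-D k-means are placeholders, not the
# real VAE/diffusion components.

def encode(image):
    # stand-in encoder: one scalar latent per image
    return sum(image) / len(image)

def decode(latent):
    # stand-in decoder: a tiny "image" recovered from a latent
    return [latent] * 4

def prototypes(latents, k, iters=20):
    # plain k-means in 1-D as the prototype-learning step
    # (assumes k <= len(latents))
    centers = latents[:k]
    for _ in range(iters):
        buckets = [[] for _ in centers]
        for z in latents:
            j = min(range(len(centers)), key=lambda i: abs(z - centers[i]))
            buckets[j].append(z)
        centers = [sum(b) / len(b) if b else centers[i]
                   for i, b in enumerate(buckets)]
    return centers

def distill(dataset, ipc):
    # dataset: {class_name: [image, ...]} -> {class_name: [synthetic, ...]}
    return {cls: [decode(z)
                  for z in prototypes([encode(im) for im in images], ipc)]
            for cls, images in dataset.items()}

data = {"cat": [[0.1, 0.2], [0.2, 0.3], [0.8, 0.9]],
        "dog": [[0.5, 0.5], [0.6, 0.7]]}
syn = distill(data, ipc=2)
print({c: len(v) for c, v in syn.items()})  # {'cat': 2, 'dog': 2}
```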

πŸ”§ Quick Start

Create environment

  • Python >= 3.9
  • PyTorch >= 1.12.1
  • Torchvision >= 0.13.1

Install Diffusers Library

You can install or upgrade to the latest version of the Diffusers library according to this page.

Modify Diffusers Library

Step 1: Copy the pipeline scripts (the generate-latents pipeline and the synthesize-images pipeline) into the Diffusers library path: diffusers/src/diffusers/pipelines/stable_diffusion.

Step 2: Modify Diffusers source code according to scripts/README.md.

Generate Prototypes

cd distillation
sh gen_prototype_imgnt.sh
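This step writes the per-class prototype latents to JSON files, which the synthesis step then reads back. A minimal sketch of such a save/load round-trip is shown below; the schema (the `"class"` and `"latents"` keys) is an assumption for illustration, not the repo's actual file format.

```python
import json
import os
import tempfile

# Sketch of the intermediate "prototype" artifact: the generation step
# stores per-class latents as JSON, and the synthesis step reads them back.
# The schema ("class" / "latents" keys) is an assumption, not the repo's
# actual file format.

def save_prototypes(path, protos):
    # protos: {class_name: [latent vector, ...]}
    records = [{"class": c, "latents": ls} for c, ls in protos.items()]
    with open(path, "w") as f:
        json.dump(records, f)

def load_prototypes(path):
    with open(path) as f:
        return {rec["class"]: rec["latents"] for rec in json.load(f)}

protos = {"goldfish": [[0.1, 0.2], [0.3, 0.4]], "tench": [[0.5, 0.6]]}
path = os.path.join(tempfile.mkdtemp(), "prototypes.json")
save_prototypes(path, protos)
assert load_prototypes(path) == protos  # round-trip preserves the latents
```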

Synthesize Images

cd distillation
sh gen_syn_image_imgnt.sh

In fact, if you don't need the JSON prototype files for further exploration, you can combine the generation and synthesis processes into one, skipping the intermediate I/O.
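In code, fusing the two stages simply means handing the prototype latents straight to the synthesis call instead of round-tripping them through JSON on disk. A hedged sketch: `gen_latents` and `latents_to_images` below are placeholders standing in for the two custom pipelines, not real diffusers APIs.

```python
# Sketch of fusing the generate and synthesis stages in memory.
# `gen_latents` and `latents_to_images` are placeholders for the two
# custom pipelines; they are not real diffusers APIs.

def gen_latents(class_name, n):
    # placeholder: n prototype latents for one class
    return [[0.1 * i, 0.1 * i] for i in range(n)]

def latents_to_images(latents):
    # placeholder: decode each latent into an "image"
    return [[2.0 * v for v in z] for z in latents]

def distill_one_class(class_name, ipc):
    # fused path: nothing is written to or read from disk in between
    return latents_to_images(gen_latents(class_name, ipc))

images = distill_one_class("tench", ipc=3)
print(len(images))  # 3
```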

Training-Time Matching (TTM)

cd matching
sh matching.sh
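One common way such matching objectives are implemented is a temperature-scaled KL divergence between teacher and student outputs. The sketch below illustrates that generic mechanism only; it is not the repo's exact TTM loss, and the shapes and temperature are illustrative.

```python
import math

# Generic sketch of a matching objective of the kind used in distillation:
# temperature-scaled KL divergence between teacher and student predictions.
# Illustrative only; not the repo's exact TTM loss.

def softmax(logits, t=1.0):
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp((x - m) / t) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def kl_match_loss(student_logits, teacher_logits, t=4.0):
    p = softmax(teacher_logits, t)  # soft targets from the teacher
    q = softmax(student_logits, t)  # student predictions
    # KL(p || q), scaled by t^2 as is conventional in distillation
    return t * t * sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

matched = kl_match_loss([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
mismatched = kl_match_loss([3.0, 2.0, 1.0], [1.0, 2.0, 3.0])
assert matched < 1e-9 < mismatched  # identical outputs give (near-)zero loss
```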

Validate

cd validate
sh train_FKD.sh
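The validation run trains a network on the distilled set and reports Top-1/Top-5 error. As a reference for reading those logs, here is a small sketch of how Top-k error is computed; the logits below are toy values, not output of the repo's code.

```python
# Sketch of the Top-k error metric reported in the validation logs
# (e.g. "Top-1 err" / "Top-5 err"). Toy logits only; the real script
# evaluates a trained network on the validation set.

def topk_error(logits_batch, labels, k):
    wrong = 0
    for logits, y in zip(logits_batch, labels):
        # indices of the k largest logits
        topk = sorted(range(len(logits)), key=lambda i: -logits[i])[:k]
        wrong += y not in topk
    return 100.0 * wrong / len(labels)

logits = [[0.1, 0.9, 0.0],   # predicts class 1
          [0.3, 0.6, 0.1],   # predicts class 1 (misses label 0 at top-1)
          [0.2, 0.3, 0.5]]   # predicts class 2
labels = [1, 0, 2]
assert abs(topk_error(logits, labels, k=1) - 100.0 / 3) < 1e-9
assert topk_error(logits, labels, k=2) == 0.0
```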

✨ Qualitative results

Comparison to other methods

ImageNet-1K Results (Top: D4M, Bottom: SRe2L)
Tiny-ImageNet Results (Top: D4M, Bottom: SRe2L)
CIFAR-10 Results (Top: D4M, Bottom: MTT)
CIFAR-100 Results (Top: D4M, Bottom: MTT)

Semantic Information

Distilled data within one class (Top: D4M, Bottom: SRe2L)

For more qualitative results, please see the supplementary in our paper.

πŸ“Š Quantitative results

Results on Large-Scale datasets

πŸ‘πŸ» Acknowledgments

Our code is developed based on the following codebases; thanks for sharing!

πŸ“– Citation

If you find this work helpful, please cite:

@InProceedings{Su_2024_CVPR,
    author    = {Su, Duo and Hou, Junjie and Gao, Weizhi and Tian, Yingjie and Tang, Bowen},
    title     = {D{\textasciicircum}4M: Dataset Distillation via Disentangled Diffusion Model},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2024},
    pages     = {5809-5818}
}

d4m's Issues

Unable to reproduce the results?

Hi, I tried to reproduce the results following the steps in the README. The generated images look good and realistic to me, but I only got a Top-1 error near 90 under the ResNet-18, 10 IPC setting (which should be 100 - 27.9 = 72.1 per the paper?). The only difference is that I'm training on a single A100 GPU instead of 8 (same batch size, same learning rate), but I guess that should make no difference?

I'm wondering if it is possible to also provide an env/Docker file, or maybe the generated dataset and the features for the KL loss? I'm using the latest PyTorch/Diffusers libraries. Since your project requires us to directly modify the source code of those libraries, I'm not sure if I have done it correctly.

My training log:

Epoch: 196
TRAIN Iter 196: lr = 0.000001,  loss = 0.001393,        Top-1 err = 99.650000,  Top-5 err = 98.910000,  train_time = 36.656602

Epoch: 197
TRAIN Iter 197: lr = 0.000001,  loss = 0.001418,        Top-1 err = 99.740000,  Top-5 err = 98.930000,  train_time = 35.769536

Epoch: 198
TRAIN Iter 198: lr = 0.000000,  loss = 0.001397,        Top-1 err = 99.690000,  Top-5 err = 99.030000,  train_time = 36.054903

Epoch: 199
TRAIN Iter 199: lr = 0.000000,  loss = 0.001414,        Top-1 err = 99.700000,  Top-5 err = 99.040000,  train_time = 36.111678
TEST Iter 199: loss = 3.881970, Top-1 err = 81.090000,  Top-5 err = 55.930000,  val_time = 70.990983

About the StableDiffusionLatents2ImgPipeline

Hello, Duo. This work is very interesting.
I'm trying to follow it, but I've run into some issues.
When I run gen_syn_image_hjj.py, I found that the diffusers module does not provide a StableDiffusionLatents2ImgPipeline pipeline.
I read the linked page (https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview) but can't find any information about StableDiffusionLatents2ImgPipeline.
I hope you can provide a solution for this part.
Thank you very much.

Regarding the StableDiffusionGenLatentsPipeline

Hi team,

Good work but I am coming across few issues while running the code.

I am trying to run the gen_prototype.py script, and upon running it I get an error:

ImportError: cannot import name 'StableDiffusionGenLatentsPipeline' from 'diffusers'

Can you please explain and help me with the error?

Regarding the StableDiffusionGenLatentsPipeline

Hi, I tried to modify the Diffusers source code according to scripts/README.md. However, it produces "ImportError: cannot import name 'StableDiffusionGenLatentsPipeline' from 'diffusers'". Could you please help me?

distilled dataset?

Hi, can you share the distilled ImageNet dataset? This would be much more convenient for anyone who only wants to try the distilled version with other networks.

Many thanks.
