joytag's Issues

On validation metrics and thresholds

First of all, nice job!

I noticed in the validation arena you're using my suggested thresholds for my models, and a "default" one for yours.
That's doing your work a disservice.
I think a fairer way to compare the models would be to try and find some fixed performance point, and see how the other metrics fare.

For my models, for example, I used to choose (by bisection) a threshold where micro averaged recall and precision matched: if both were higher than the last model then I had a better model.
You could do the same, or bisect towards a threshold that gives a desired precision and evaluate recall for example.
This also has the side effect of being fairer to augmentations like mixup, which skew prediction confidence towards lower values.
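The bisection described above can be sketched like this (a minimal sketch assuming dense probability/label matrices; `find_balanced_threshold` is a hypothetical helper name, not code from either repo):

```python
import numpy as np

def find_balanced_threshold(probs: np.ndarray, labels: np.ndarray, iters: int = 30) -> float:
    """Bisect toward the threshold where micro-averaged precision
    and recall cross. probs and labels are (n_samples, n_tags) arrays;
    labels is binary ground truth."""
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        mid = (lo + hi) / 2.0
        preds = probs >= mid
        tp = float((preds & labels.astype(bool)).sum())
        precision = tp / max(preds.sum(), 1)
        recall = tp / max(labels.sum(), 1)
        # Raising the threshold trades recall for precision,
        # so step toward the crossing point.
        if precision < recall:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2.0
```

If both metrics at the resulting threshold beat the previous model's, the new model is a genuine improvement rather than just a different operating point.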

If I may go on a slight tangent about the discrepancy between my stated scores and the ones in the Arena: I used to use micro averaging, while you're calculating macro averages. Definitely keep using macro averaging for the metrics; I started using it too in my newer codebase over at https://github.com/SmilingWolf/JAX-CV (posting the repo in case you want to use it if you decide to apply to TRC).
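For anyone unsure of the distinction being discussed: micro averaging pools every tag decision into one global count (so frequent tags dominate), while macro averaging computes the metric per tag and then averages (so rare tags weigh equally). A minimal illustration with hypothetical names, not code from either repo:

```python
import numpy as np

def precision_micro_macro(preds: np.ndarray, labels: np.ndarray):
    """preds, labels: (n_samples, n_tags) boolean arrays.
    Returns (micro precision, macro precision)."""
    tp = (preds & labels).sum(axis=0).astype(float)        # true positives per tag
    pred_pos = preds.sum(axis=0).astype(float)             # predictions per tag
    # Micro: one pooled ratio over all decisions.
    micro = tp.sum() / max(pred_pos.sum(), 1.0)
    # Macro: per-tag precision (0 where a tag is never predicted), then mean.
    per_tag = np.divide(tp, pred_pos, out=np.zeros_like(tp), where=pred_pos > 0)
    macro = per_tag.mean()
    return micro, macro
```

With one frequent tag at 0.9 precision and one rare tag at 0.5, micro lands near the frequent tag's score while macro reports the plain mean, which is why the two schemes disagree on the same predictions.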

Questions about choice of tags

I see you have commission and skeb_commission in your tag list. It feels odd to me that you could tell whether a drawing is a commission just by looking at it.

Versions of the dependencies?

Please add the Python version plus a requirements.txt (or the spec file for whatever virtual environment management tool you are using).

This will help ensure the code still runs half a year from now, when you may not be committing as often.

(I don't see it listed anywhere, feel free to correct me if I'm wrong.)

EDIT:

Also, by the way, the example from the README currently crashes with

RuntimeError: Input type (float) and bias type (c10::Half) should be the same

which can be fixed by adding

def predict(image: Image.Image):
    image_tensor = prepare_image(image, model.image_size)
    batch = {
        'image': image_tensor.unsqueeze(0),
    }

    batch['image'] = batch['image'].to('cuda')  # <-- THIS is the added line

Inspired by

batch['image'] = batch['image'].to('cuda')
# Forward
with torch.amp.autocast_mode.autocast('cuda', enabled=True):
    preds = model(batch)
predictions = torch.sigmoid(preds['tags'].to(torch.float32)).detach().cpu()
assert len(predictions.shape) == 2 and predictions.shape[0] == len(batch['image']) and predictions.shape[1] == len(model_tags) and predictions.dtype == torch.float32

EDIT 2:

My micromamba.yaml, if anyone is curious:
name: joytag
channels:
- conda-forge
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=4.5
- aiohttp=3.9.1
- aiosignal=1.3.1
- attrs=23.1.0
- aws-c-auth=0.7.8
- aws-c-cal=0.6.9
- aws-c-common=0.9.10
- aws-c-compression=0.2.17
- aws-c-event-stream=0.3.2
- aws-c-http=0.7.15
- aws-c-io=0.13.36
- aws-c-mqtt=0.10.0
- aws-c-s3=0.4.6
- aws-c-sdkutils=0.1.13
- aws-checksums=0.1.17
- aws-crt-cpp=0.25.0
- aws-sdk-cpp=1.11.210
- brotli-python=1.1.0
- bzip2=1.0.8
- c-ares=1.24.0
- ca-certificates=2023.11.17
- certifi=2023.11.17
- charset-normalizer=3.3.2
- colorama=0.4.6
- cuda-cudart=12.2.140
- cuda-cudart_linux-64=12.2.140
- cuda-nvrtc=12.2.140
- cuda-nvtx=12.2.140
- cuda-version=12.2
- cudnn=8.8.0.121
- datasets=2.14.4
- dill=0.3.7
- einops=0.7.0
- filelock=3.13.1
- freetype=2.12.1
- frozenlist=1.4.1
- fsspec=2023.12.2
- gflags=2.2.2
- glog=0.6.0
- gmp=6.3.0
- gmpy2=2.1.2
- huggingface_hub=0.20.0
- icu=73.2
- idna=3.6
- importlib-metadata=7.0.1
- jinja2=3.1.2
- keyutils=1.6.1
- krb5=1.21.2
- lcms2=2.16
- ld_impl_linux-64=2.40
- lerc=4.0.0
- libabseil=20230802.1
- libarrow=14.0.2
- libarrow-acero=14.0.2
- libarrow-dataset=14.0.2
- libarrow-flight=14.0.2
- libarrow-flight-sql=14.0.2
- libarrow-gandiva=14.0.2
- libarrow-substrait=14.0.2
- libblas=3.9.0
- libbrotlicommon=1.1.0
- libbrotlidec=1.1.0
- libbrotlienc=1.1.0
- libcblas=3.9.0
- libcrc32c=1.1.2
- libcublas=12.2.5.6
- libcufft=11.0.8.103
- libcurand=10.3.3.141
- libcurl=8.5.0
- libcusolver=11.5.2.141
- libcusparse=12.1.2.141
- libdeflate=1.19
- libedit=3.1.20191231
- libev=4.33
- libevent=2.1.12
- libexpat=2.5.0
- libffi=3.4.2
- libgcc-ng=13.2.0
- libgfortran-ng=13.2.0
- libgfortran5=13.2.0
- libgoogle-cloud=2.12.0
- libgrpc=1.59.3
- libhwloc=2.9.3
- libiconv=1.17
- libjpeg-turbo=3.0.0
- liblapack=3.9.0
- libllvm15=15.0.7
- libmagma=2.7.2
- libmagma_sparse=2.7.2
- libnghttp2=1.58.0
- libnl=3.9.0
- libnsl=2.0.1
- libnuma=2.0.16
- libnvjitlink=12.2.140
- libopenblas=0.3.25
- libparquet=14.0.2
- libpng=1.6.39
- libprotobuf=4.24.4
- libre2-11=2023.06.02
- libsqlite=3.44.2
- libssh2=1.11.0
- libstdcxx-ng=13.2.0
- libthrift=0.19.0
- libtiff=4.6.0
- libutf8proc=2.8.0
- libuuid=2.38.1
- libuv=1.46.0
- libwebp-base=1.3.2
- libxcb=1.15
- libxcrypt=4.4.36
- libxml2=2.11.6
- libzlib=1.2.13
- llvm-openmp=17.0.6
- lz4-c=1.9.4
- magma=2.7.2
- markupsafe=2.1.3
- mkl=2022.2.1
- mpc=1.3.1
- mpfr=4.2.1
- mpmath=1.3.0
- multidict=6.0.4
- multiprocess=0.70.15
- nccl=2.19.4.1
- ncurses=6.4
- networkx=3.2.1
- numpy=1.26.2
- openjpeg=2.5.0
- openssl=3.2.0
- orc=1.9.2
- packaging=23.2
- pandas=2.1.4
- pillow=10.1.0
- pip=23.3.2
- pthread-stubs=0.4
- pyarrow=14.0.2
- pysocks=1.7.1
- python=3.11.7
- python-dateutil=2.8.2
- python-tzdata=2023.3
- python-xxhash=3.4.1
- python_abi=3.11
- pytorch=2.1.0
- pytorch-gpu=2.1.0
- pytz=2023.3.post1
- pyyaml=6.0.1
- rdma-core=49.0
- re2=2023.06.02
- readline=8.2
- regex=2023.10.3
- requests=2.31.0
- s2n=1.4.1
- safetensors=0.3.3
- setuptools=68.2.2
- six=1.16.0
- sleef=3.5.1
- snappy=1.1.10
- sympy=1.12
- tbb=2021.11.0
- tk=8.6.13
- tokenizers=0.15.0
- torchvision=0.16.1
- tqdm=4.66.1
- transformers=4.36.2
- typing-extensions=4.9.0
- typing_extensions=4.9.0
- tzdata=2023d
- ucx=1.15.0
- urllib3=2.1.0
- wheel=0.42.0
- xorg-libxau=1.0.11
- xorg-libxdmcp=1.1.3
- xxhash=0.8.2
- xz=5.2.6
- yaml=0.2.5
- yarl=1.9.3
- zipp=3.17.0
- zstd=1.5.5

This is probably excessive, as the only thing that didn't work out of the box was Python 3.12 (I had to use 3.11); everything else could use the latest available versions.

Cool project, btw.

[Discussion] Comparison with Danbooru interrogator in SD Automatic1111

Hello, thank you for sharing this model.

I did a quick and naive check between this and the Danbooru Interrogator in Automatic1111's webui, comparing both against the actual tags. The test took the 100 most recent posts from Danbooru; 12 images didn't have a proper URL to download, and all 88 remaining images were processed successfully (88/88).

These are my current observations:

  • JoyTag's model has a much higher similarity rate (true positive) to the actual tags.
  • JoyTag's model has a much lower incorrect tag prediction rate (false positive) compared to the SD Interrogator.
  • Both JoyTag and the Interrogator miss tags (false negatives), with JoyTag missing fewer.

I'm looking forward to seeing if this can be integrated into the webui or retrained for even more tags!

Note:

  • I didn't do any in-depth study on the Interrogator before.
  • True negatives are 0 because there's nothing sensible to count them against.
  • This uses the actual tags as the ground truth, although there could always be mistakes/missing tags/uneven tag distribution.

Settings: Threshold 0.5

[Chart: 100 images, 0.5 threshold, v2]


I just noticed that the threshold in the docs was 0.4. Here, I ran the same code with that threshold. The images pulled may not be the same. Again 88/88 processed successfully (12 failed to download).

Observation:

  • It has a more accurate prediction (true positive) score; however, it hallucinates more (false positives).

[Chart: 100 images, 0.4 threshold, v2]

Edit: Updated charts to reflect fixed Interrogator code. Cleaning tags was necessary.
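For anyone reproducing this kind of comparison: once the tags are cleaned/normalized, the per-image counting reduces to plain set operations (`tag_confusion` is a hypothetical helper, not code from the test above; true negatives are omitted for the reason noted earlier):

```python
def tag_confusion(predicted: set, actual: set) -> dict:
    """Compare a predicted tag set against the ground-truth tags for one image."""
    return {
        'true_positive': len(predicted & actual),   # tags correctly predicted
        'false_positive': len(predicted - actual),  # hallucinated tags
        'false_negative': len(actual - predicted),  # missed tags
    }

# Example: two correct tags, one hallucinated, one missed.
counts = tag_confusion({'1girl', 'solo', 'hat'}, {'1girl', 'solo', 'smile'})
```

Summing these dicts over all images gives the totals that the charts report.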

ONNX format model?

Thank you for this model! May I ask if you could possibly provide an ONNX version? I want to use it in Stable Diffusion extensions (e.g. ComfyUI) and have been unable to figure out how to convert it.
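In case it helps anyone attempting the conversion: since the model takes and returns dicts (as in the README snippet quoted above), `torch.onnx.export` needs a small wrapper that adapts it to plain tensors. A sketch under that assumption; `ExportWrapper` is a made-up name and this is untested against the actual checkpoint:

```python
import torch
import torch.nn as nn

class ExportWrapper(nn.Module):
    """Adapts JoyTag's dict-based interface ({'image': tensor} in,
    {'tags': logits} out) to the plain-tensor signature ONNX export expects."""
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        preds = self.model({'image': image})
        # Bake the sigmoid into the graph so the ONNX output is probabilities.
        return torch.sigmoid(preds['tags'])

# Hypothetical export call, assuming `model` is the loaded VisionModel:
# wrapper = ExportWrapper(model).eval()
# dummy = torch.randn(1, 3, model.image_size, model.image_size)
# torch.onnx.export(wrapper, dummy, 'joytag.onnx',
#                   input_names=['image'], output_names=['tags'],
#                   dynamic_axes={'image': {0: 'batch'}, 'tags': {0: 'batch'}})
```

Exporting in float32 (without the autocast context) sidesteps the Half/float mismatch discussed in the other issue.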
