TorchVahadane

Vahadane stain normalization is used extensively in Digital Pathology workflows to help Deep Learning models generalize better between cohorts.

The StainTools package has been one of the most widely used and clearest implementations of the Vahadane stain normalization. Unfortunately, StainTools can be slow when used on large images or on a large number of images.

This repository implements a GPU accelerated version of the Vahadane stain normalization using torch.

This repository provides two workflows: a fully GPU-based stain normalization workflow, useful in combination with cuCIM, and a faster workflow that combines CPU-based stain matrix estimation with accelerated stain concentration estimation.

For WSI workflows, a fixed target stain matrix can be set, eliminating the need for recalculating the stain matrix for every new image patch and making the transformation fully GPU based.

Benchmarks using LineProfiler show the speed increase of TorchVahadane compared to StainTools.

Method                fit [s]   transform [s]   total [s]
StainTools Vahadane   17.4      17.1            34.5
TorchVahadane         9.4       8.5 [1.5]       17.9
TorchVahadane ST      3.2       3.1 [1.5]       6.3

Values in brackets give the transform time when using a fixed stain matrix (see "Robust stain estimation of Whole Slide Images"). Measured using Python 3.11.3 and spams 2.6.5.4.

The new TorchVahadane version fixes instabilities in the concentration estimation by replacing the ISTA algorithm with the iterative positive thresholding algorithm (IPTA), which applies the correct positivity constraint for the sparse regularization problem.
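The role of the positivity constraint can be illustrated with a minimal NumPy sketch of a proximal-gradient iteration for the non-negative lasso, min over x >= 0 of 0.5*||Ax - y||^2 + lam*||x||_1. This is not TorchVahadane's implementation: the function name `nonneg_ista`, the step-size rule, and the toy data are assumptions chosen for illustration. The only difference from plain ISTA is that the two-sided soft threshold is replaced by a one-sided threshold that clamps at zero.

```python
import numpy as np

def nonneg_ista(A, y, lam=0.1, n_iter=500):
    """Approximately solve min_{x >= 0} 0.5*||A x - y||^2 + lam*||x||_1."""
    # Step size 1/L, where L is the largest eigenvalue of A^T A.
    step = 1.0 / np.linalg.norm(A, ord=2) ** 2
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        grad = A.T @ (A @ x - y)
        # One-sided (positive) thresholding: shift by step*lam, then clamp at zero.
        x = np.maximum(x - step * (grad + lam), 0.0)
    return x

# Toy problem (hypothetical values): columns of A play the role of two stain vectors.
A = np.array([[0.65, 0.07],
              [0.70, 0.99],
              [0.29, 0.11]])
x_true = np.array([1.2, 0.4])   # known non-negative concentrations
y = A @ x_true
x_hat = nonneg_ista(A, y, lam=1e-3, n_iter=5000)
```

With a small `lam`, `x_hat` should stay non-negative and land close to `x_true`; a plain soft threshold would instead allow negative concentrations, which have no physical meaning for stain densities.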

Usage

TorchVahadane can be employed as a drop-in replacement for StainTools. By default, the TorchVahadaneNormalizer uses the cuda device and staintools-based stain_matrix estimation (the fastest and most robust approach). Because StainTools is now a read-only repository, its code is integrated directly rather than used as a dependency. The transform function also accepts an optional return_mask parameter that passes through the generated tissue mask for downstream tasks.

from torchvahadane import TorchVahadaneNormalizer
normalizer = TorchVahadaneNormalizer(device='cuda', staintools_estimate=True)
normalizer.fit(target)
img_normed = normalizer.transform(img)

# return tissue mask
img_normed, img_mask = normalizer.transform(img, return_mask=True)

Histogram matching

In practice, Vahadane normalization does not always transfer the saturation and contrast of the reference image to the source image, thus retaining a domain shift between the target and source image.

This can be mitigated by using histogram matching (see skimage.exposure.match_histograms).

TorchVahadane implements masked histogram matching using torch, matching the cumulative distribution functions of the histograms only on tissue pixels. This makes histogram matching suitable for histology tiles that contain non-tissue regions.
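The idea of restricting the match to tissue pixels can be sketched in NumPy for a single uint8 channel. This is an illustration under assumptions, not the torch code in this repository: the function name `match_histograms_masked` and the random test data are hypothetical, and an RGB image would need the function applied per channel.

```python
import numpy as np

def match_histograms_masked(src, ref, src_mask, ref_mask):
    """Match the histogram of src[src_mask] to that of ref[ref_mask] (one channel)."""
    out = src.copy()
    src_vals = src[src_mask]
    ref_vals = ref[ref_mask]
    if src_vals.size == 0 or ref_vals.size == 0:
        return out  # no tissue pixels, nothing to match

    # Empirical CDFs computed on masked (tissue) pixels only.
    s_values, s_idx, s_counts = np.unique(
        src_vals, return_inverse=True, return_counts=True)
    r_values, r_counts = np.unique(ref_vals, return_counts=True)
    s_cdf = np.cumsum(s_counts) / src_vals.size
    r_cdf = np.cumsum(r_counts) / ref_vals.size

    # Map each source quantile to the reference intensity at the same quantile.
    mapped = np.interp(s_cdf, r_cdf, r_values)
    out[src_mask] = np.round(mapped[s_idx]).astype(src.dtype)
    return out

rng = np.random.default_rng(0)
src = rng.integers(0, 128, size=(32, 32), dtype=np.uint8)    # dark source channel
ref = rng.integers(100, 256, size=(32, 32), dtype=np.uint8)  # brighter reference channel
src_mask = np.ones_like(src, dtype=bool)
src_mask[:8] = False                                         # pretend the top rows are background
ref_mask = np.ones_like(ref, dtype=bool)
matched = match_histograms_masked(src, ref, src_mask, ref_mask)
```

Pixels outside `src_mask` are left untouched, so background regions neither distort the estimated distributions nor get recolored.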

Histogram matching is integrated into the standard fit/transform pipeline and is enabled by setting correct_exposure to True.

from torchvahadane import TorchVahadaneNormalizer
normalizer = TorchVahadaneNormalizer(correct_exposure=True)
normalizer.fit(target)
normalizer.transform(img)

Masked histogram matching can also be used directly from the histogram_matching module.

Robust stain estimation of Whole Slide Images

TorchVahadane also supports the estimation of median stain intensities, as proposed by Vahadane et al. TorchVahadane samples the WSI over a grid of tiles and returns the median stain intensities. OpenSlide is used as an optional dependency to extract the WSI tiles.

Setting a fixed stain_matrix skips the stain_matrix estimation of the source image in the transform, speeding up subsequent transformations.

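from torchvahadane.wsi_util import estimate_median_matrix

# osh: handle of the opened WSI (assumed here to be an OpenSlide object)
# normalizer: a TorchVahadaneNormalizer instance
stain_matrix = estimate_median_matrix(osh, normalizer, osh_level=0, tile_size=4096, num_workers=12)
# fix the stain matrix so that transform skips the per-image estimation
normalizer.set_stain_matrix(stain_matrix)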
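How a median over tiles suppresses outlier estimates can be shown with a small NumPy sketch. The internals of estimate_median_matrix are not described here, so everything below is an assumption made for illustration: the helper name `median_stain_matrix`, the 2x3 layout (rows are stains, columns are RGB optical density), the per-tile values, and the final rescaling of each row to unit length.

```python
import numpy as np

def median_stain_matrix(tile_matrices):
    """Element-wise median of per-tile stain matrices, rows rescaled to unit length."""
    stacked = np.stack(tile_matrices)        # shape: (n_tiles, n_stains, 3)
    median = np.median(stacked, axis=0)      # robust to a minority of outlier tiles
    return median / np.linalg.norm(median, axis=1, keepdims=True)

# Hypothetical per-tile estimates; the third tile is a deliberate outlier.
tiles = [
    np.array([[0.55, 0.75, 0.36], [0.10, 0.95, 0.29]]),
    np.array([[0.57, 0.73, 0.37], [0.12, 0.94, 0.31]]),
    np.array([[0.90, 0.30, 0.30], [0.60, 0.60, 0.52]]),
]
stain_matrix = median_stain_matrix(tiles)
```

The result stays close to the two consistent tiles, whereas a mean would be pulled toward the outlier.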

Installation

TorchVahadane can be installed with pip using

git clone https://github.com/cwlkr/torchvahadane.git
cd torchvahadane
pip install .

or directly

pip install git+https://github.com/cwlkr/torchvahadane.git

Notes

Installing spams through pip fails more often than not. Using conda's pre-compiled binaries might work best. Spams is not listed in the package requirements.

Openslide is not listed in the package requirements as it is considered optional.

Acknowledgments

Several lines of code in this repository are directly adapted from the following repositories, which I would like to credit for their excellent work!

StainTools
pytorch-lasso
torchstain

Contributors

cielal, cwlkr

torchvahadane's Issues

Potential nan in dictionary update steps caused by 0/0

Attempt fix in #1
In the update_dict method in https://github.com/cwlkr/torchvahadane/blob/main/torchvahadane/dict_learning.py#L100, when the norm of an atom is too small, the corresponding atom is re-initialized with tensor.normal_.
However, if all numbers drawn are negative, the following in-place clamp yields a zero vector, and the in-place normalization dictionary[:, k] /= dictionary[:, k].norm() then divides 0 by 0 and produces a nan vector, which invalidates the whole computation.
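The failure mode can be reproduced outside the library with a few lines of NumPy. This is a sketch, not the fix proposed in #1: the helper name `renormalize_atom`, the `eps` threshold, and the uniform fallback atom are assumptions made for illustration.

```python
import numpy as np

def renormalize_atom(atom, eps=1e-12):
    """Clamp an atom to non-negative values and normalize it, guarding a zero norm."""
    atom = np.clip(atom, 0.0, None)
    norm = np.linalg.norm(atom)
    if norm < eps:
        # Fallback (an assumption, not the repository's fix): uniform positive atom.
        atom = np.ones_like(atom)
        norm = np.linalg.norm(atom)
    return atom / norm

all_negative = np.array([-0.3, -1.2, -0.7])   # an unlucky all-negative random draw

# Unguarded version: the clamp gives a zero vector, and 0/0 produces NaN.
clamped = np.clip(all_negative, 0.0, None)
with np.errstate(invalid="ignore"):
    unguarded = clamped / np.linalg.norm(clamped)

guarded = renormalize_atom(all_negative)
```

Here `unguarded` contains only nan values, while `guarded` is a finite unit-norm vector. Redrawing the random atom until its clamped norm is positive would be another way to avoid the division by zero.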

How to parallelize?

I'm sorry to bother you, but is it possible to use the graphics card to process multiple slices in parallel? I currently use a for loop and it takes too long. An unknown error occurred when I tried multiprocessing (from multiprocessing import Pool), so I don't have any other options. Thank you!
