
lsd's Introduction

Local Shape Descriptors (for Neuron Segmentation)

This repository contains code to compute Local Shape Descriptors (LSDs) from an instance segmentation. LSDs can then be used during training as an auxiliary target, which we found to improve boundary prediction and therefore segmentation quality. Read more about it in our paper and/or blog post.

Paper | Blog post

Quick 2D Examples

Notebooks

Example networks & pipelines

Parallel processing


Cite:

@article{sheridan_local_2022,
	title = {Local shape descriptors for neuron segmentation},
	issn = {1548-7091, 1548-7105},
	url = {https://www.nature.com/articles/s41592-022-01711-z},
	doi = {10.1038/s41592-022-01711-z},
	urldate = {2023-01-12},
	journal = {Nature Methods},
	author = {Sheridan, Arlo and Nguyen, Tri M. and Deb, Diptodip and Lee, Wei-Chung Allen and Saalfeld, Stephan and Turaga, Srinivas C. and Manor, Uri and Funke, Jan},
	month = dec,
	year = {2022},
}

Notes:

  • LSDs can be installed via conda/mamba through the conda-forge channel, e.g.:
conda create -n lsd_env python=3
conda activate lsd_env
conda install lsds -c conda-forge

or via pip:

mamba create -n lsd_env python=3
mamba activate lsd_env
pip install lsds
  • conda repo, PyPI page

  • Tested on Ubuntu 18.04 with Python 3.

  • This is not production-level software and was developed in a pure research environment, so some scripts may not work out of the box. For example, all paper networks were originally written using now-deprecated TensorFlow/cuDNN versions and rely on an outdated Singularity container. Because of this, the Singularity image will not build from the current recipe - if replicating with the current implementations, please reach out for the Singularity container (it is too large to upload here). Alternatively, consider reimplementing the networks in PyTorch (recommended - see Training).

  • Post-processing steps were designed for use with a specific cluster and will need to be tweaked for individual use cases. If the need/use increases, we will look into refactoring, packaging and distributing.

  • Currently, several post-processing scripts (e.g. watershed) are located inside this repo, which creates more dependencies than needed for using the LSDs. One foreseeable issue is that agglomeration requires networkx==2.2 for the MergeTree, and boost is required for funlib.segment. We have restructured the repo to use lsd.train and lsd.post submodules. For just calculating the LSDs, it is sufficient to use lsd.train, e.g.:

from lsd.train import local_shape_descriptor
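For instance, here is a minimal sketch of computing LSDs on an in-memory label array, assuming the get_local_shape_descriptors function with segmentation/sigma/voxel_size parameters (exact arguments may differ between versions):

import numpy as np
from lsd.train import local_shape_descriptor

# labels: 3d array of unique object ids (0 = background)
labels = np.zeros((64, 64, 64), dtype=np.uint64)
labels[8:32, 8:32, 8:32] = 1
labels[40:60, 40:60, 40:60] = 2

# returns a (10, z, y, x) float array of descriptors for a 3d volume;
# sigma controls the size of the gaussian neighborhood in world units
lsds = local_shape_descriptor.get_local_shape_descriptors(
    segmentation=labels,
    sigma=(5.0, 5.0, 5.0),
    voxel_size=(1, 1, 1))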

Quick 2D Examples

The following tutorial allows you to run in the browser using Google Colab. To replicate the tutorial locally, create a conda environment and install the relevant packages, e.g.:

  1. conda create -n lsd_test python=3
  2. conda activate lsd_test
  3. pip install lsds

Tutorial: Open In Colab


Notebooks

  • Example colab notebooks are located here. You can download them or run below (ctrl + click "Open in Colab"). When running a notebook, you will probably get the message: "Warning: This notebook was not authored by Google". This can be ignored; you can run it anyway.

  • We uploaded ~1.7 TB of data (raw/labels/masks/RAGs, etc.) to an S3 bucket. The following tutorial shows some examples for accessing and visualizing the data.

    • Data download: Open In Colab
  • If implementing the LSDs in your own training pipeline (i.e. pure PyTorch/TensorFlow), calculate the LSDs on a label array of unique objects and use them as the target for your network (see the Quick 2D Examples above for calculating; a minimal standalone sketch follows at the end of this list).

  • The following tutorials show how to set up 2D training/prediction pipelines using Gunpowder. It is recommended to follow them in order (skip the basic tutorial if you are already familiar with Gunpowder). Note: Google Colab can sometimes be slow, especially due to data I/O. These notebooks will run much faster in a Jupyter notebook with a local GPU, but the Colab versions should provide a starting point.

    • Basic Gunpowder tutorial: Open In Colab

    • Train Affinities: Open In Colab

    • Train LSDs: Open In Colab

    • Train MTLSD: Open In Colab

    • Inference (using pretrained MTLSD checkpoint): Open In Colab

    • Watershed, agglomeration, segmentation: Open In Colab

  • Bonus notebooks:

    • Training using sparse ground truth (useful if you only have a subset of training data but still want dense predictions): Open In Colab

    • Ignore regions during training (useful if you want the network to learn to predict zeros in certain regions, e.g. glia IDs): Open In Colab

    • Train LSDs on non-EM data with PyTorch: Open In Colab
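As a reference for the custom-pipeline bullet above, here is a minimal sketch of using LSDs as a regression target in a plain PyTorch loop. The single Conv3d stands in for a real network (e.g. a U-Net), and the random data is purely illustrative:

import numpy as np
import torch
from lsd.train import local_shape_descriptor

# compute the (10, z, y, x) lsd target from a label crop
def make_target(labels):
    lsds = local_shape_descriptor.get_local_shape_descriptors(
        segmentation=labels,
        sigma=(5.0, 5.0, 5.0),
        voxel_size=(1, 1, 1))
    return torch.from_numpy(lsds.astype(np.float32))

model = torch.nn.Conv3d(1, 10, kernel_size=3, padding=1)  # placeholder net
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

raw = torch.randn(1, 1, 64, 64, 64)                       # placeholder raw
labels = np.random.randint(0, 3, (64, 64, 64), dtype=np.uint64)
target = make_target(labels).unsqueeze(0)                 # add batch dim

optimizer.zero_grad()
loss = loss_fn(model(raw), target)
loss.backward()
optimizer.step()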


Example networks & pipelines

  • There are example scripts for the zebrafinch networks here, using the Singularity container and TensorFlow networks from the paper.

  • There are some example networks and training/prediction pipelines for the fib25 dataset here. These are just a guideline; several dependencies are now deprecated.

Training

  • Since the networks in this paper were implemented in TensorFlow, there was a two-step process for training. First the networks were created using the mknet.py files. This saved tensor placeholders and metadata in config files that were then used for both training and prediction. The mknet files used the now-deprecated mala repository to create the networks. If reimplementing in TensorFlow, consider migrating to funlib.learn.tensorflow.

  • If using PyTorch, the networks can just be created directly inside the train scripts, since placeholders aren't required. For example, the logic from this TensorFlow mknet script and this TensorFlow train script can be condensed to this PyTorch train script (a minimal sketch of such a multi-task model follows this list).

  • For training an autocontext network (e.g. acrlsd), the current implementation learns the LSDs in a first pass. A saved checkpoint is then used when creating the second pass in order to predict LSDs prior to learning the affinities. One could modify this to use a single setup and remove the need for writing the LSDs to disk.
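For illustration, here is a minimal sketch of a multi-task (MTLSD-style) PyTorch model created directly in a train script; the small convolutional trunk is a stand-in for the U-Net used in the paper, not the actual architecture:

import torch

class MtlsdToyModel(torch.nn.Module):
    # one shared trunk, two heads: 10 lsd channels and 3 affinity channels
    def __init__(self):
        super().__init__()
        self.trunk = torch.nn.Sequential(
            torch.nn.Conv3d(1, 12, 3, padding=1), torch.nn.ReLU(),
            torch.nn.Conv3d(12, 12, 3, padding=1), torch.nn.ReLU())
        self.lsd_head = torch.nn.Conv3d(12, 10, 1)
        self.aff_head = torch.nn.Conv3d(12, 3, 1)

    def forward(self, x):
        f = self.trunk(x)
        # sigmoid keeps both outputs in [0, 1], matching lsd/aff targets
        return torch.sigmoid(self.lsd_head(f)), torch.sigmoid(self.aff_head(f))

model = MtlsdToyModel()
lsds, affs = model(torch.randn(1, 1, 32, 32, 32))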

Inference

  • By default, the predict scripts (example) contain the worker logic to be distributed by the scheduler during parallel processing (see below).

  • If you just need to process a relatively small volume, it is sometimes not necessary to use blockwise processing. In this case, it is recommended to use a Scan node and specify input/output shapes + context (see the sketch after this list). An example can be found in the inference colab notebook above.

  • Similar to training, the current autocontext implementations assume the predicted LSDs are written to a zarr/n5 container and then used as input to the second pass to predict affinities. This can also be changed to predict on the fly if needed.
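Here is a rough sketch of the Scan approach mentioned above; the source and predict nodes are placeholders, and exact node arguments may differ between gunpowder versions:

import gunpowder as gp

raw = gp.ArrayKey('RAW')
pred_lsds = gp.ArrayKey('PRED_LSDS')

voxel_size = gp.Coordinate((8, 8, 8))
input_size = gp.Coordinate((84, 84, 84)) * voxel_size    # input incl. context
output_size = gp.Coordinate((44, 44, 44)) * voxel_size   # network output

# one reference request per block; Scan tiles it over the full provided roi
scan_request = gp.BatchRequest()
scan_request.add(raw, input_size)
scan_request.add(pred_lsds, output_size)

pipeline = (
    source +            # placeholder, e.g. gp.ZarrSource(...)
    gp.Normalize(raw) +
    predict +           # placeholder, e.g. gp.torch.Predict(model, ...)
    gp.Scan(scan_request))

with gp.build(pipeline):
    # an empty request makes Scan process everything the pipeline provides
    pipeline.request_batch(gp.BatchRequest())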

Visualizations of example training/prediction pipelines

Vanilla affinities training:



Autocontext LSD and affinities prediction:


Parallel processing

  • If you are running on small data then this section may be irrelevant. See the Watershed, agglomeration, segmentation notebook above if you just want to get a sense of obtaining a segmentation from affinities.

  • Example processing scripts can be found here.

  • We create segmentations following the approach in this paper. Generally speaking, after training a network there are five steps to obtain a segmentation (a compact in-memory sketch of these steps follows at the end of this list):

  1. Predict boundaries (this can involve the use of LSDs as an auxiliary task)
  2. Generate supervoxels (fragments) using seeded watershed. The fragment centers of mass are stored as region adjacency graph nodes.
  3. Generate edges between nodes using hierarchical agglomeration. The edges are weighted by the underlying affinities. Edges with lower scores are merged earlier.
  4. Cut the graph at a predefined threshold and relabel connected components. Store the node-to-component lookup tables.
  5. Use the lookup tables to relabel supervoxels and generate a segmentation.

  • Everything was done in parallel using daisy (github, docs), but one could use multiprocessing or dask instead.

  • For our experiments we used MongoDB for all storage (block checks, RAGs, scores, etc.) due to the size of the data. Depending on the use case, it might be better to read/write to file rather than mongo. See watershed for further info.

  • The following examples were written for use with the Janelia LSF cluster and are just meant to be used as a guide. Users will likely need to customize for their own specs (for example if using a SLURM cluster).

  • You will need to install funlib.segment and funlib.evaluate if using/adapting the segmentation/evaluation scripts.
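As a compact, in-memory version of the five steps above (no blockwise machinery), here is a hedged sketch using scipy/skimage for the watershed and waterz for agglomeration; the random affinities and the thresholds are purely illustrative, and waterz's fragments keyword is assumed:

import numpy as np
import waterz
from scipy import ndimage
from skimage.segmentation import watershed

# step 1 (stand-in): affinities as predicted by the network, (3, z, y, x)
affs = np.random.rand(3, 64, 64, 64).astype(np.float32)

# step 2: seeded watershed on the boundary map to get supervoxels
boundary = 1.0 - np.mean(affs, axis=0)
seeds, _ = ndimage.label(boundary < 0.3)          # example seed threshold
fragments = watershed(boundary, seeds).astype(np.uint64)

# steps 3-5: hierarchical agglomeration; waterz yields one segmentation per
# requested threshold (edges with lower scores are merged earlier)
for segmentation in waterz.agglomerate(
        affs, thresholds=[0.4], fragments=fragments):
    print(segmentation.shape, len(np.unique(segmentation)))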

Inference

The worker logic is located in individual predict.py scripts (example). The master script distributes them using daisy.run_blockwise. The only need for MongoDB here is the block check function (to check which blocks have successfully completed). To remove the need for mongo, one could remove the check function (remember to also remove block_done_callback in predict.py) or replace it with a custom function (e.g. check chunk completion directly in the output container).
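A rough sketch of the master-side call, assuming the classic daisy 0.x run_blockwise signature used by these scripts; the rois, worker launcher, and block check are placeholders:

import daisy

# block geometry: read_roi/write_roi are relative to the block origin,
# with the read roi grown by the required context
total_roi = daisy.Roi((140800, 205120, 198400), (3000, 3000, 3000))
read_roi = daisy.Roi((0, 0, 0), (1200, 1200, 1200))
write_roi = daisy.Roi((200, 200, 200), (800, 800, 800))

def start_worker():
    # placeholder: launch predict.py (bsub/sbatch/local subprocess)
    pass

def check_block(block):
    # placeholder: return True if this block is already done, e.g. by
    # querying mongo or checking the output container directly
    return False

daisy.run_blockwise(
    total_roi,
    read_roi,
    write_roi,
    process_function=start_worker,
    check_function=check_block,
    num_workers=5,
    read_write_conflict=False,
    fit='overhang')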

Example roi config
{
  "container": "hemi_roi_1.zarr",
  "offset": [140800, 205120, 198400],
  "size": [3000, 3000, 3000]
}
Example predict config
 {
  "base_dir": "/path/to/base/directory",
  "experiment": "hemi",
  "setup": "setup01",
  "iteration": 400000,
  "raw_file": "predict_roi.json",
  "raw_dataset" : "volumes/raw",
  "out_base" : "output",
  "file_name": "foo.zarr",
  "num_workers": 5,
  "db_host": "mongodb client",
  "db_name": "foo",
  "queue": "gpu_rtx",
  "singularity_image": "/path/to/singularity/image"
}

Watershed

The worker logic is located in a single script, which is then distributed by the master script. By default the nodes are stored in mongo using a MongoDbGraphProvider. To write to file instead (i.e. compressed numpy arrays), you can use the FileGraphProvider (inside the worker script).
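For illustration, a rough sketch of the provider swap inside the worker script, assuming the daisy 0.x persistence interfaces (constructor arguments may differ between versions):

from daisy.persistence import MongoDbGraphProvider
# from daisy.persistence import FileGraphProvider

# default: store fragment nodes in mongo (db_name/host as in the config)
rag_provider = MongoDbGraphProvider(
    'foo',
    host='mongodb://localhost:27017',
    mode='r+',
    position_attribute=['center_z', 'center_y', 'center_x'])

# file-based alternative (writes compressed numpy arrays to disk):
# rag_provider = FileGraphProvider(
#     'foo.zarr/rag',
#     chunk_size=(1000, 1000, 1000),
#     mode='r+',
#     position_attribute=['center_z', 'center_y', 'center_x'])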

Example watershed config
{
  "experiment": "hemi",
  "setup": "setup01",
  "iteration": 400000,
  "affs_file": "foo.zarr",
  "affs_dataset": "/volumes/affs",
  "fragments_file": "foo.zarr",
  "fragments_dataset": "/volumes/fragments",
  "block_size": [1000, 1000, 1000],
  "context": [248, 248, 248],
  "db_host": "mongodb client",
  "db_name": "foo",
  "num_workers": 6,
  "fragments_in_xy": false,
  "epsilon_agglomerate": 0,
  "queue": "local"
}

Agglomerate

Same as watershed: worker script, master script. Change to FileGraphProvider if needed.

Example agglomerate config
{
  "experiment": "hemi",
  "setup": "setup01",
  "iteration": 400000,
  "affs_file": "foo.zarr",
  "affs_dataset": "/volumes/affs",
  "fragments_file": "foo.zarr",
  "fragments_dataset": "/volumes/fragments",
  "block_size": [1000, 1000, 1000],
  "context": [248, 248, 248],
  "db_host": "mongodb client",
  "db_name": "foo",
  "num_workers": 4,
  "queue": "local",
  "merge_function": "hist_quant_75"
}

Find segments

In contrast to the above three methods, creating the LUTs just needs enough RAM to hold the RAG in memory. The only thing done in parallel is reading the graph (graph_provider.read_blockwise()). This could be adapted to use multiprocessing/dask to distribute the connected components for each threshold, but if the RAG is too large there will be pickling errors when passing the nodes/edges. Daisy doesn't need to be used for scheduling here, since nothing is written to containers (a small sketch of the thresholding logic follows the example config below).

Example find segments config
{
  "db_host": "mongodb client",
  "db_name": "foo",
  "fragments_file": "foo.zarr",
  "edges_collection": "edges_hist_quant_75",
  "thresholds_minmax": [0, 1],
  "thresholds_step": 0.02,
  "block_size": [1000, 1000, 1000],
  "num_workers": 5,
  "fragments_dataset": "/volumes/fragments",
  "run_type": "test"
}
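A small in-memory sketch of the thresholding logic described above, using networkx in place of the script's helpers; the node/edge values and the merge_score attribute are illustrative:

import networkx as nx
import numpy as np

# nodes: fragment ids; edges: (u, v, merge_score) read from the rag
nodes = np.array([1, 2, 3, 4], dtype=np.uint64)
edges = [(1, 2, 0.1), (2, 3, 0.5), (3, 4, 0.9)]

def find_segments(nodes, edges, threshold):
    # keep edges merged at or below the threshold, then treat connected
    # components as segments
    g = nx.Graph()
    g.add_nodes_from(nodes)
    g.add_edges_from((u, v) for u, v, score in edges if score <= threshold)
    lut = {}
    for seg_id, component in enumerate(nx.connected_components(g), start=1):
        for fragment in component:
            lut[fragment] = seg_id
    return lut  # fragment id -> segment id

print(find_segments(nodes, edges, threshold=0.4))  # {1: 1, 2: 1, 3: 2, 4: 3}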

Extract segmentation

This script does use daisy to write the segmentation to file, but doesn't necessarily require bsub/sbatch to distribute (you can run it locally).
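A minimal sketch of the relabeling step itself, using plain numpy indexing in place of funlib.segment's optimized helpers (fine for small volumes):

import numpy as np

# fragments: supervoxel array; lut: fragment id -> segment id at a threshold
fragments = np.array([[1, 1, 2], [2, 3, 4]], dtype=np.uint64)
lut = {1: 1, 2: 1, 3: 2, 4: 3}

# build a dense lookup array indexed by fragment id (0 stays background)
dense_lut = np.zeros(int(fragments.max()) + 1, dtype=np.uint64)
for fragment, segment in lut.items():
    dense_lut[fragment] = segment

segmentation = dense_lut[fragments]
print(segmentation)  # [[1 1 1] [1 2 3]]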

Example extract segmentation config
{
  "fragments_file": "foo.zarr",
  "fragments_dataset": "/volumes/fragments",
  "edges_collection": "edges_hist_quant_75",
  "threshold": 0.4,
  "block_size": [1000, 1000, 1000],
  "out_file": "foo.zarr",
  "out_dataset": "volumes/segmentation_40",
  "num_workers": 3,
  "run_type": "test"
}

Evaluate volumes

Evaluates VOI scores. Assumes dense voxel ground truth (not skeletons). This also assumes the ground truth (and segmentation) can fit into memory, which was fine for the hemi and fib25 volumes given ~750 GB of RAM. The script should probably be refactored to run blockwise.
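A small sketch of the in-memory scoring, assuming funlib.evaluate's rand_voi and its metrics-dict return format (key names may differ between versions); the random volumes are placeholders for loaded ground truth and segmentation:

import numpy as np
from funlib.evaluate import rand_voi

# gt and seg: integer label volumes of the same shape, loaded into memory
gt = np.random.randint(0, 5, (64, 64, 64), dtype=np.uint64)
seg = np.random.randint(0, 5, (64, 64, 64), dtype=np.uint64)

metrics = rand_voi(gt, seg)
# lower is better for both voi terms
print(metrics['voi_split'], metrics['voi_merge'])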

Example evaluate volumes config
{
  "experiment": "hemi",
  "setup": "setup01",
  "iteration": 400000,
  "gt_file": "hemi_roi_1.zarr",
  "gt_dataset": "volumes/labels/neuron_ids",
  "fragments_file": "foo.zarr",
  "fragments_dataset": "/volumes/fragments",
  "db_host": "mongodb client",
  "rag_db_name": "foo",
  "edges_collection": "edges_hist_quant_75",
  "scores_db_name": "scores",
  "thresholds_minmax": [0, 1],
  "thresholds_step": 0.02,
  "num_workers": 4,
  "method": "vanilla",
  "run_type": "test"
}

Evaluate annotations

For the zebrafinch, ground truth skeletons were used due to the size of the dataset. These skeletons were cropped, masked, and relabelled for the sub-ROIs that were tested in the paper. We evaluated VOI, ERL, and the min-cut metric on the consolidated skeletons. The current implementation could be refactored / made more modular. It also uses node_collections, which are now deprecated in daisy. To use it with the current implementation, you should check out daisy commit 39723ca.

Example evaluate annotations config
{
  "experiment": "zebrafinch",
  "setup": "setup01",
  "iteration": 400000,
  "config_slab": "mtlsd",
  "fragments_file": "foo.zarr",
  "fragments_dataset": "/volumes/fragments",
  "edges_db_host": "mongodb client",
  "edges_db_name": "foo",
  "edges_collection": "edges_hist_quant_75",
  "scores_db_name": "scores",
  "annotations_db_host": "mongo client",
  "annotations_db_name": "foo",
  "annotations_skeletons_collection_name": "zebrafinch",
  "node_components": "zebrafinch_components",
  "node_mask": "zebrafinch_mask",
  "roi_offset": [50800, 43200, 44100],
  "roi_shape": [10800, 10800, 10800],
  "thresholds_minmax": [0.5, 1],
  "thresholds_step": 1,
  "run_type": "11_micron_roi_masked"
}

lsd's People

Contributors

funkey, juliabuhmann, pattonw, sheridana, trivoldus28, yajivunev


lsd's Issues

MergeTree class has no attribute node

MergeTree defined in "lsd/merge_tree.pyx" seems to inherit from nx.DiGraph and tries to look up nodes via self.node[node_id].

I think all of these calls should be changed to self.nodes[node_id].
networkx also isn't in the requirements file, so maybe it's just a versioning issue.

Ran into this error when using agglomerate_in_block.

lsd.merge_tree

Hi,
While importing AddLocalShapeDescriptor from lsd.gp in PyCharm, I get the error that there is no module named 'lsd.merge_tree'.
How can I fix it?
Thanks

Singularity container

Hello, I'm trying to replicate this project using the current network implementations, and I wonder if you could share the Singularity container. Thank you!

Singularity image does not build

To reproduce:

cd singularity
make
...
+ conda install tensorflow-gpu==1.3
Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Collecting package metadata (repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.

PackagesNotFoundError: The following packages are not available from current channels:

  - tensorflow-gpu==1.3

Current channels:

  - https://repo.anaconda.com/pkgs/main/linux-64
  - https://repo.anaconda.com/pkgs/main/noarch
  - https://repo.anaconda.com/pkgs/r/linux-64
  - https://repo.anaconda.com/pkgs/r/noarch

Cannot load saved MTLSD model

Hello,

After training the MTLSD model using the provided notebook, I saved the checkpoints. I'm trying to load the model using saved checkpoints to predict affs and lsds using this code (from segment.ipynb notebook) :

raw, pred_lsds, pred_affs = predict(checkpoint, raw_file, raw_dataset)

but I get this error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/gunpowder/nodes/batch_provider.py", line 193, in request_batch
    batch = self.provide(upstream_request)
  File "/usr/local/lib/python3.10/dist-packages/gunpowder/nodes/batch_filter.py", line 148, in provide
    dependencies = self.prepare(request)
  File "/usr/local/lib/python3.10/dist-packages/gunpowder/nodes/generic_predict.py", line 116, in prepare
    self.start()
  File "/usr/local/lib/python3.10/dist-packages/gunpowder/torch/nodes/predict.py", line 107, in start
    self.model.load_state_dict(checkpoint["model_state_dict"])
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for MtlsdModel:
    Missing key(s) in state_dict: "unet.r_up.0.0.up.weight", "unet.r_up.0.0.up.bias", "unet.r_up.0.1.up.weight", "unet.r_up.0.1.up.bias".

I also tried training the model with the default dataset used in the notebooks but got the same error again. I appreciate any help or suggestions.

fetch checkpoint error

!wget https://www.dropbox.com/s/r1u8pvji5lbanyq/model_checkpoint_50000

--2023-08-04 05:46:22-- https://www.dropbox.com/s/r1u8pvji5lbanyq/model_checkpoint_50000
Resolving www.dropbox.com (www.dropbox.com)... 162.125.81.18, 2620:100:6031:18::a27d:5112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.81.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /s/raw/r1u8pvji5lbanyq/model_checkpoint_50000 [following]
--2023-08-04 05:46:22-- https://www.dropbox.com/s/raw/r1u8pvji5lbanyq/model_checkpoint_50000
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 400 Bad Request
2023-08-04 05:46:23 ERROR 400: Bad Request.

Is the URL invalid?

Is computing LSDs very slow compared to the training time?

Thank you for providing the good work!

I'm implementing LSDs using the code in this repository in my PyTorch training pipeline. I found that computing LSDs on the CPU takes much more time than training (~10 times longer when I train on a device with 1 Titan X and 8 dataloader workers).

I just want to make sure whether this is common, since I didn't find the training time in the paper. If it is, do you think there are proper ways to alleviate it, such as computing LSDs on a cropped label (smaller than the model input size but bigger than the output size), or computing and storing the LSDs before training?

Thanks!

Reproducing results on zebrafinch data

I am currently trying to reproduce the MTLSD results on the zebrafinch dataset.

The dataset itself can be downloaded successfully using the code in lsd_data_download.ipynb, but I could not find any JSON config file, model checkpoint, or zebrafinch-specific training or prediction script.
For the fib25 dataset there is some dataset-specific code included in the GitHub repository, which I have tried to modify for the zebrafinch data, but I have not been successful yet.
The large-scale prediction scripts seem to expect a certain directory structure as indicated here: https://github.com/funkelab/lsd/blob/master/lsd/tutorial/scripts/01_predict_blockwise.py#L47-L59 - which does not seem to be included in the public code and data repositories that I have found until now.
Would you kindly share these zebrafinch-related files if that is possible?

I would also like to ask if there is a PyTorch-based version of the whole training and prediction workflow available somewhere and if there have been any updates on the Singularity image.
I am asking because the new tutorials rely on PyTorch but the public 3D prediction-related code and the Singularity image only rely on TensorFlow (related: #6).

minor installation annoyances

It would be nice to have an installation guide. pip installing usually fails for me and then once installed it usually doesn't work immediately.
Here are the issues I usually run into:

  • pip install git+git://github.com/funkelab/lsd.git usually fails with a "ModuleNotFoundError: No module named 'Cython'"
    • solvable with pyproject.toml.
  • After installation trying to import lsd throws errors until some extra packages are installed.
    • packages:
      • mahotas
      • matplotlib
      • git+git://github.com/funkelab/funlib.segment.git (this one specifically is annoying since it requires boost to install. I don't need it if I just want the lsd gunpowder node)
      • daisy
      • git+git://github.com/funkey/waterz.git
    • This isn't as problematic as the cython issue, but some of these seem like they aren't strictly necessary, especially if you only want to quickly compute some lsds on an in-memory dataset.
  • Finally, pip install from the repo can be a bit slow, since there is a 19 MB model checkpoint in the notebook tutorials and 23 MB in .git/objects/pack (not sure what that is).
    • Distributing on PyPI would probably be the best solution here, so the model could stay in the notebooks. Unfortunately, it seems someone already took the name lsd.
    • Alternatively, moving the model to a separate link would probably make pip installs from the repo faster.
