

DilatedToothSegNet: Tooth Segmentation Network on 3D Dental Meshes Through Increasing Receptive Vision

This repository contains the code for the journal paper:
DilatedToothSegNet: Tooth Segmentation Network on 3D Dental Meshes Through Increasing Receptive Vision
Authors: Lucas Krenmayr, Reinhold von Schwerin, Daniel Schaudt, Pascal Riedel, Alexander Hafner

This paper is published in the Journal of Imaging Informatics in Medicine (JIIM) and can be accessed at https://doi.org/10.1007/s10278-024-01061-6.

Abstract

The utilization of advanced intraoral scanners to acquire 3D dental models has gained significant popularity in the fields of dentistry and orthodontics. Accurate segmentation and labeling of teeth on digitized 3D dental surface models are crucial for computer-aided treatment planning. At the same time, manual labeling of these models is a time-consuming task. Recent advances in geometric deep learning have demonstrated remarkable efficiency in surface segmentation when applied to raw 3D models. However, segmentation of the dental surface remains challenging due to the atypical and diverse appearance of the patients’ teeth. Numerous deep learning methods have been proposed to automate dental surface segmentation. Nevertheless, they still show limitations, particularly in cases where teeth are missing or severely misaligned. To overcome these challenges, we introduce a network operator called dilated edge convolution, which enhances the network’s ability to learn additional, more distant features by expanding its receptive field. This leads to improved segmentation results, particularly in complex and challenging cases. To validate the effectiveness of our proposed method, we performed extensive evaluations on the recently published benchmark data set for dental model segmentation Teeth3DS. We compared our approach with several other state-of-the-art methods using a quantitative and qualitative analysis. Through these evaluations, we demonstrate the superiority of our proposed method, showcasing its ability to outperform existing approaches in dental surface segmentation.
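To make the idea concrete, the sketch below illustrates the general principle of a dilated neighborhood in an EdgeConv-style operator: instead of aggregating over the k nearest neighbors only, the search is widened to k·d candidates and every d-th one is kept, so more distant points contribute without increasing the number of neighbors. This is a simplified toy illustration of the concept (brute-force kNN, generic edge features), not the exact operator from the paper:

import torch


def knn(pos: torch.Tensor, k: int) -> torch.Tensor:
    """Indices of the k nearest neighbors for each point. pos: (B, N, 3)."""
    dist = torch.cdist(pos, pos)  # (B, N, N) pairwise distances
    return dist.topk(k, dim=-1, largest=False).indices  # (B, N, k)


def dilated_edge_features(x: torch.Tensor, pos: torch.Tensor, k: int, d: int) -> torch.Tensor:
    """Edge features over a dilated neighborhood: search k*d candidates,
    keep every d-th one, so the receptive field grows at fixed neighbor count."""
    idx = knn(pos, k * d)[..., ::d]  # (B, N, k) dilated neighbor indices
    B, N, C = x.shape
    batch = torch.arange(B, device=x.device).view(B, 1, 1)
    neighbors = x[batch, idx]  # (B, N, k, C) gathered neighbor features
    center = x.unsqueeze(2).expand_as(neighbors)
    # EdgeConv-style pairing [x_i, x_j - x_i], max-pooled over the neighbors
    return torch.cat([center, neighbors - center], dim=-1).max(dim=2).values


pos = torch.rand(2, 1024, 3)
x = torch.rand(2, 1024, 24)
print(dilated_edge_features(x, pos, k=16, d=4).shape)  # torch.Size([2, 1024, 48])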

How to use this repository:

You can use this repository to train the model from scratch on the Teeth3DS dataset. If you are only interested in the model architecture, you can find it in the models directory. The model is implemented in PyTorch and can be used in your own projects. Example usage:

import torch
from models.dilated_tooth_seg_network import DilatedToothSegmentationNetwork

# Create the model
model = DilatedToothSegmentationNetwork(num_classes=17, feature_dim=24).cuda()
# dummy input
pos = torch.rand(2, 2048, 3).cuda() # xyz coordinates of the points. Shape: (batch_size, num_points, 3)
x = torch.rand(2, 2048, 24).cuda() # features of the points. Shape: (batch_size, num_points, feature_dim)

out = model(x, pos)
print(out.shape) # Shape: (batch_size, num_points, num_classes)
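
The output contains one score per class for each point; per-point labels can be obtained by taking the argmax over the last dimension:

labels = out.argmax(dim=-1) # predicted class per point. Shape: (batch_size, num_points)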

Requirements

  • Python 3.10
  • PyTorch >= 2.1
  • CUDA >= 12.0
  • see requirements.txt for additional dependencies

Docker (Recommended)

It is recommended to use the provided Dockerfile to run the code. The Dockerfile contains all necessary dependencies and spins up a Jupyter instance. It can be built using the following command:

docker build -t dilatedtoothsegnet .

The Docker container can be started using the following command:
On Windows:

docker run -p 8888:8888 -p 6006:6006 -e JUPYTER_TOKEN=12345 -d --gpus all -v "%cd%:/app" --name dilatedtoothsegnet dilatedtoothsegnet

On Linux:

docker run -p 8888:8888 -p 6006:6006 -e JUPYTER_TOKEN=12345 -d --gpus all -v "$(pwd):/app" --name dilatedtoothsegnet dilatedtoothsegnet

Afterwards, a Jupyter notebook can be opened by navigating to the following link in a browser: http://localhost:8888/?token=12345
To finish the installation, open a terminal in the Jupyter notebook and run the following commands. This will install the PointNet++ Ops Library, which is required for the code to run.

cd pointnet2_ops_lib
python setup.py install

Local Installation

To install the required dependencies, run the following command:

pip install -r requirements.txt

Install PointNet++ Ops Library:

pip install ninja
cd pointnet2_ops_lib
python setup.py install
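
In either setup (Docker or local), the installation can be verified with a quick import check. This assumes the library installs under the module name pointnet2_ops, as in the widely used PointNet++ ops implementation; if the import fails, the build step above did not complete:

python -c "import pointnet2_ops"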

Data

The data used in this project is the Teeth3DS dataset, which can be downloaded from here.
The data should be placed in the data directory in the following structure:

data
└── 3dteethseg
    └── raw
        ├── lower
        ├── upper
        ├── private-testing-set.txt
        ├── public-training-set-1.txt
        ├── public-training-set-2.txt
        ├── testing_lower.txt
        ├── testing_upper.txt
        ├── training_lower.txt
        └── training_upper.txt

When running the training script, it will automatically preprocess the data and save it in the processed directory (data/processed).
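
A quick sanity check (using only the paths listed above) can confirm the raw data is in the expected place before starting a training run:

from pathlib import Path

# All entries are expected under data/3dteethseg/raw, as listed above
raw = Path("data/3dteethseg/raw")
expected = ["lower", "upper", "private-testing-set.txt",
            "public-training-set-1.txt", "public-training-set-2.txt",
            "testing_lower.txt", "testing_upper.txt",
            "training_lower.txt", "training_upper.txt"]
missing = [name for name in expected if not (raw / name).exists()]
print("Missing entries:", missing if missing else "none")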

Training

To train the model, run the following command with the desired options:

python train_network.py --epochs 100 --tb_save_dir logs --experiment_name training --experiment_version 1 --train_batch_size 2 --n_bit_precision 16 --train_test_split 1 --devices 0

To see the training progress, you can use TensorBoard by running the following command in the terminal:

tensorboard --logdir <path_to_tensorboard_logs> --port 6006 --host 0.0.0.0

Then open a browser and navigate to http://localhost:6006

Options:

--epochs <number>: The number of epochs to train for. Default is 100.
--tb_save_dir <path>: The directory to save TensorBoard logs to. Default is tensorboard_logs.
--devices <devices>: The device ids to use for training. If multiple devices are used, they should be separated by spaces (e.g. 0 1 2). Default is 0.
--experiment_name <name>: The name of the experiment.
--experiment_version <version>: The version of the experiment.
--train_batch_size <size>: The batch size for training. Default is 2.
--n_bit_precision <number>: The precision for training. Default is 16.
--train_test_split <number>: The option for train/test split. Either 1 or 2. Default is 1.
--ckpt <path>: The path to a checkpoint to resume training from (see the example below). Default is None.
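
For example, to resume training from a saved checkpoint (the checkpoint path below is illustrative; use the path produced by your own run, typically under the TensorBoard log directory):

python train_network.py --epochs 100 --tb_save_dir logs --experiment_name training --experiment_version 1 --train_batch_size 2 --n_bit_precision 16 --train_test_split 1 --devices 0 --ckpt logs/training/version_1/checkpoints/last.ckpt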

Evaluation

To evaluate the model, run the following command with the desired options. The checkpoint path should be the path to the checkpoint to evaluate. This will save the results in the log directory as a .csv file.

python test_network.py --tb_save_dir logs --experiment_name testing --experiment_version 1 --devices 0 --n_bit_precision 16 --train_test_split 1 --ckpt <path_to_checkpoint>

Options:

--experiment_name <name>: The name of the experiment.
--experiment_version <version>: The version of the experiment.
--devices <devices>: The device ids to use for evaluation. If multiple devices are used, they should be separated by spaces (e.g. 0 1 2). Default is 0.
--n_bit_precision <number>: The precision for evaluation. Default is 16.
--train_test_split <number>: The option for train/test split. Either 1 or 2. Default is 1.
--ckpt <path>: The path to a checkpoint to evaluate.
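
The resulting .csv file can be inspected programmatically, for example with pandas (the file path below is illustrative; the exact location and columns depend on the run):

import pandas as pd

# Illustrative path; the file is written to the log directory of the run
results = pd.read_csv("logs/testing/version_1/results.csv")
print(results.describe())  # summary statistics of the logged metrics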

Inferencing

To use the model for inferencing on a single instance of the data set, run the following command with the desired options. The checkpoint path should be the path to the checkpoint to use for inferencing. The predicted and ground-truth color-coded meshes will be saved in the output directory as .ply files. You can visualize the results using MeshLab or another 3D mesh viewer such as https://3dviewer.net/.

python visualize_example.py --ckpt <path_to_checkpoint> --out_dir output --n_bit_precision 16 --use_gpu  --train_test_split 1 --data_idx 0

Options:

--ckpt <path>: The path to a checkpoint to use for inferencing.
--out_dir <path>: The directory to save the output to. Default is output.
--n_bit_precision <number>: The precision. Default is 16.
--use_gpu: Use this flag to use the GPU for inferencing.
--train_test_split <number>: The option for train/test split. Either 1 or 2. Default is 1.
--data_idx <number>: The index of the data to use for inferencing. Default is 0.
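
Besides MeshLab or a web viewer, the exported .ply meshes can also be loaded programmatically, for example with trimesh (the file name below is illustrative; use the files written to your output directory):

import trimesh

# Illustrative file name; see the files saved in the --out_dir directory
mesh = trimesh.load("output/prediction.ply")
mesh.show()  # opens an interactive viewer window (requires pyglet)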

Special Thanks

The code in this repository is based on the following repositories:

We would like to thank the authors of these repositories for providing their code.

Citation

If you find this work useful, please cite this paper: Krenmayr, L., von Schwerin, R., Schaudt, D. et al. DilatedToothSegNet: Tooth Segmentation Network on 3D Dental Meshes Through Increasing Receptive Vision. J Imaging Inform Med (2024). https://doi.org/10.1007/s10278-024-01061-6

@article{krenmayr2024dilatedtoothsegnet,
  title={DilatedToothSegNet: Tooth Segmentation Network on 3D Dental Meshes Through Increasing Receptive Vision},
  author={Krenmayr, Lucas and von Schwerin, Reinhold and Schaudt, Daniel and Riedel, Pascal and Hafner, Alexander},
  journal={Journal of Imaging Informatics in Medicine},
  pages={1--17},
  year={2024},
  publisher={Springer},
  doi={10.1007/s10278-024-01061-6}
}





dilated_tooth_seg_net's Issues

NameError: name 'fps' is not defined

Hello,
I encountered this error when running the code. It appears after a warning suggesting to increase the value of the `num_workers` argument to `num_workers=11` in the `DataLoader` to improve performance:
Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
File "/mnt/dilated_tooth_seg_net-main/train_network.py", line 102, in
trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, ckpt_path=args.ckpt)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 545, in fit
call._call_and_handle_interrupt(
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 581, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 990, in _run
results = self._run_stage()
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1034, in _run_stage
self._run_sanity_check()
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1063, in _run_sanity_check
val_loop.run()
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 181, in _decorator
return loop_run(self, *args, **kwargs)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 134, in run
self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 391, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_args)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 403, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)
File "/mnt/dilated_tooth_seg_net-main/models/dilated_tooth_seg_network.py", line 109, in validation_step
pred = self.model(x, pos)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/dilated_tooth_seg_net-main/models/dilated_tooth_seg_network.py", line 63, in forward
x1, _ = self.dilated_edge_graph_conv_block1(x, pos, cd=cd)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/dilated_tooth_seg_net-main/models/layer.py", line 229, in forward
idx_fps = fps(pos.reshape(B * N, -1)[idx_l], self.k).long()
NameError: name 'fps' is not defined

How can I solve it? It happens after the data has been processed.

Pointnet

Hello, what is the purpose of the point cloud operations library in the code, and why is an additional PointNet structure needed?

About the visualization

Could you consider providing the code to visualize the segmentation results, like the figures shown in the paper and the README? Thanks.
