
segmentation_metrics's Introduction

Segmentation Metrics Package


This is a simple package to compute different metrics for medical image segmentation (images with suffix .mhd, .mha, .nii, .nii.gz or .nrrd) and write them to a .csv file.

BTW, if you need support for more suffixes, just let me know by creating a new issue.

Summary

There are several different methods to assess segmentation performance. The two main families are volume-based metrics and distance-based metrics.

Metrics included

This library computes the following performance metrics for segmentation:

Voxel based metrics

  • Dice (F-1)
  • Jaccard
  • Precision
  • Recall
  • False positive rate
  • False negative rate
  • Volume similarity

The equations for these metrics can be found on Wikipedia.

Surface-distance-based metrics (computed with voxel spacing by default)

  • Hausdorff distance
  • Hausdorff distance 95% percentile
  • Mean (Average) surface distance
  • Median surface distance
  • Std surface distance

Note: These metrics are symmetric, which means the distance from A to B is the same as the distance from B to A. A more detailed explanation of these surface-distance-based metrics can be found here.
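To make the pooling behind that symmetry concrete, here is a brute-force sketch in plain Python. It takes two surfaces as lists of point coordinates already expressed in physical units (for example voxel index times spacing), measures each point's distance to the nearest point of the other surface in both directions, pools the two sets, and summarizes them. The function names are hypothetical, and this is not how the package computes the distances (the code quoted in one of the issues below uses SimpleITK distance maps); it only illustrates the definitions. The percentile uses linear interpolation, the same convention as NumPy's default.

```python
import math
import statistics


def nearest_distances(src, dst):
    """For each point in src, the Euclidean distance to the closest point in dst."""
    return [min(math.dist(p, q) for q in dst) for p in src]


def percentile(values, q):
    """Linear-interpolation percentile (NumPy's default convention)."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def surface_distance_metrics(surface_a, surface_b):
    # Pool both directed distance sets, so swapping A and B gives the same result.
    d = nearest_distances(surface_a, surface_b) + nearest_distances(surface_b, surface_a)
    return {
        "hd": max(d),
        "hd95": percentile(d, 95),
        "msd": statistics.fmean(d),
        "mdsd": statistics.median(d),
        "stdsd": statistics.pstdev(d),
    }


a = [(0, 0), (0, 1)]
b = [(0, 0), (0, 3)]
print(surface_distance_metrics(a, b))
```

The nested loop is O(n·m) and only suitable for tiny inputs; distance maps scale far better on real images.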

Installation

$ pip install seg-metrics

Getting started

A tutorial is available on Colab.

The API reference is available in the Documentation.

Examples can be found below.

Usage

First, import the package:

import seg_metrics.seg_metrics as sg

Evaluate two batches of images with the same filenames from two different folders

labels = [0, 4, 5, 6, 7, 8]
gdth_path = 'data/gdth'  # this folder contains a batch of ground truth images
pred_path = 'data/pred'  # this folder contains the same number of prediction images
csv_file = 'metrics.csv'  # results are saved to this file and printed on the terminal as well;
                          # if not set, results are only shown on the terminal.

metrics = sg.write_metrics(labels=labels[1:],  # exclude background
                           gdth_path=gdth_path,
                           pred_path=pred_path,
                           csv_file=csv_file)
print(metrics)  # a list of dictionaries with the metrics for each pair of images

After running the code above, metrics is a list of dictionaries containing all the metrics, and a .csv file with all metrics is written to csv_file (here metrics.csv in the current working directory). If csv_file is not given, the results are not saved to disk.
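For a batch, a common next step is averaging each metric per label across images. The sketch below assumes each dictionary maps 'label' to the list of evaluated labels and every metric name to a list with one value per label (the layout built by get_metrics_dict_all_labels in the code quoted in the "Adding TP TN FP FN" issue further down this page). The helper mean_per_label and the sample values are hypothetical.

```python
from collections import defaultdict
from statistics import fmean


def mean_per_label(results, metric="dice"):
    """Average one metric over all images, separately for each label.

    Assumes each entry of results looks like
    {'label': [4, 5], 'dice': [0.9, 0.8], ..., 'filename': '...'}.
    """
    per_label = defaultdict(list)
    for res in results:
        for label, value in zip(res["label"], res[metric]):
            per_label[label].append(value)
    return {label: fmean(values) for label, values in per_label.items()}


# Made-up results for two images and two labels.
results = [
    {"label": [4, 5], "dice": [0.9, 0.7], "filename": "data/pred/a.mhd"},
    {"label": [4, 5], "dice": [0.8, 0.5], "filename": "data/pred/b.mhd"},
]
print(mean_per_label(results))
```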

Evaluate two images

labels = [0, 4, 5, 6, 7, 8]
gdth_file = 'data/gdth.mhd'  # ground truth image full path
pred_file = 'data/pred.mhd'  # prediction image full path
csv_file = 'metrics.csv'

metrics = sg.write_metrics(labels=labels[1:],  # exclude background
                           gdth_path=gdth_file,
                           pred_path=pred_file,
                           csv_file=csv_file)

After running the code above, metrics is a dictionary containing all the metrics, and a .csv file with all metrics is written to csv_file.

Note:

  1. When evaluating one image, the returned metrics is a dictionary.
  2. When evaluating a batch of images, the returned metrics is a list of dictionaries.
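Because the return type differs between the two cases, downstream code can normalize it first. The following sketch uses a hypothetical helper, to_rows, and the same assumed dictionary layout ('label' plus per-label lists, with scalar entries such as 'filename'). It turns either a single dictionary or a list of dictionaries into one flat row per image and label, which is convenient for tables or plotting.

```python
def to_rows(metrics):
    """Flatten write_metrics-style output into one dict per (image, label) pair.

    Accepts either a single dict (one image) or a list of dicts (a batch).
    """
    results = [metrics] if isinstance(metrics, dict) else list(metrics)
    rows = []
    for res in results:
        per_label_keys = [k for k, v in res.items() if isinstance(v, list)]
        for i in range(len(res["label"])):
            row = {k: res[k][i] for k in per_label_keys}
            # Copy scalar entries such as 'filename' onto every row.
            row.update({k: v for k, v in res.items() if not isinstance(v, list)})
            rows.append(row)
    return rows


single = {"label": [4, 5], "dice": [0.9, 0.7], "filename": "data/pred.mhd"}  # made-up values
for row in to_rows(single):
    print(row)
```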

Evaluate two images with specific metrics

labels = [0, 4, 5, 6, 7, 8]
gdth_file = 'data/gdth.mhd'
pred_file = 'data/pred.mhd'
csv_file = 'metrics.csv'

metrics = sg.write_metrics(labels=labels[1:],  # exclude background if needed
                           gdth_path=gdth_file,
                           pred_path=pred_file,
                           csv_file=csv_file,
                           metrics=['dice', 'hd'])
# for only one metric
metrics = sg.write_metrics(labels=labels[1:],  # exclude background if needed
                           gdth_path=gdth_file,
                           pred_path=pred_file,
                           csv_file=csv_file,
                           metrics='msd')

Select specific metrics by passing the following names to the metrics argument:

- dice:         Dice (F-1)
- jaccard:      Jaccard
- precision:    Precision
- recall:       Recall
- fpr:          False positive rate
- fnr:          False negative rate
- vs:           Volume similarity

- hd:           Hausdorff distance
- hd95:         Hausdorff distance 95% percentile
- msd:          Mean (Average) surface distance
- mdsd:         Median surface distance
- stdsd:        Std surface distance

For example:

labels = [1]
gdth_file = 'data/gdth.mhd'
pred_file = 'data/pred.mhd'
csv_file = 'metrics.csv'

metrics = sg.write_metrics(labels, gdth_file, pred_file, csv_file, metrics=['dice', 'hd95'])
dice = metrics['dice']
hd95 = metrics['hd95']

Evaluate two images in memory instead of on disk

Note:

  1. Both images must be either numpy.ndarray or SimpleITK.Image.
  2. The input arguments are different: use gdth_img and pred_img instead of gdth_path and pred_path.
  3. When evaluating numpy.ndarray inputs, the default spacing for distance-based metrics is 1.0 in every dimension.
  4. To evaluate numpy.ndarray inputs with a specific spacing, pass a sequence with one value per image dimension as spacing.

import numpy as np

labels = [0, 1, 2]
gdth_img = np.array([[0, 0, 1],
                     [0, 1, 2]])
pred_img = np.array([[0, 0, 1],
                     [0, 2, 2]])
csv_file = 'metrics.csv'
spacing = [1, 2]
metrics = sg.write_metrics(labels=labels[1:],  # exclude background if needed
                           gdth_img=gdth_img,
                           pred_img=pred_img,
                           csv_file=csv_file,
                           spacing=spacing,
                           metrics=['dice', 'hd'])
# for only one metric
metrics = sg.write_metrics(labels=labels[1:],  # exclude background if needed
                           gdth_img=gdth_img,
                           pred_img=pred_img,
                           csv_file=csv_file,
                           spacing=spacing,
                           metrics='msd')
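Spacing matters because distances are measured in physical units: with spacing = [1, 2] as above, one step along the second array axis counts as 2 and a diagonal step as sqrt(5). The sketch below uses hypothetical helpers to compute that scaling for a pair of array indices, assuming spacing is listed in the array's own axis order. It also shows a detail worth checking when mixing libraries: SimpleITK's GetSpacing() reports (x, y, z), whereas arrays from sitk.GetArrayFromImage are indexed (z, y, x), which is why the code quoted in the issues below reverses the spacing.

```python
import math


def physical_distance(index_a, index_b, spacing):
    """Euclidean distance between two array indices, scaled by per-axis spacing.

    spacing must list one value per array axis, in the array's axis order.
    """
    if not (len(index_a) == len(index_b) == len(spacing)):
        raise ValueError("index_a, index_b and spacing must have the same length")
    return math.sqrt(sum(((a - b) * s) ** 2 for a, b, s in zip(index_a, index_b, spacing)))


def sitk_spacing_to_array_order(sitk_spacing):
    # SimpleITK reports spacing as (x, y, z); GetArrayFromImage arrays are indexed (z, y, x).
    return tuple(reversed(sitk_spacing))


spacing = [1, 2]
print(physical_distance((1, 1), (1, 2), spacing))  # one step along axis 1
print(physical_distance((0, 0), (1, 1), spacing))  # one diagonal step
print(sitk_spacing_to_array_order((0.5, 0.5, 2.0)))
```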

About the calculation of surface distance

The default surface distance is calculated based on a fully connected border. To change the connectivity, set the argument fully_connected to False as follows.

metrics = sg.write_metrics(labels=[1,2,3],
                        gdth_img=gdth_img,
                        pred_img=pred_img,
                        csv_file=csv_file,
                        fully_connected=False) 

In a 2D image, fully connected means 8 neighboring pixels, while face connected means 4. In a 3D image, fully connected means 26 neighboring voxels, while face connected means 6.
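These neighborhood sizes follow from counting offsets: a fully connected neighborhood includes every offset in {-1, 0, 1}^n except the pixel itself (3^n - 1 neighbors), while a face-connected one includes only the 2n offsets along a single axis. The hypothetical helper below enumerates both with itertools.product. It only illustrates the definitions; the border extraction itself is done by SimpleITK.

```python
from itertools import product


def neighbor_offsets(ndim, fully_connected=True):
    """Offsets of the neighbors of a pixel/voxel in an ndim-dimensional grid.

    fully_connected=True: every offset in {-1, 0, 1}^ndim except the origin.
    fully_connected=False: only offsets that move along exactly one axis.
    """
    offsets = []
    for off in product((-1, 0, 1), repeat=ndim):
        nonzero = sum(1 for o in off if o != 0)
        if nonzero == 0:
            continue  # skip the pixel itself
        if fully_connected or nonzero == 1:
            offsets.append(off)
    return offsets


for ndim in (2, 3):
    print(ndim, len(neighbor_offsets(ndim, True)), len(neighbor_offsets(ndim, False)))
# prints: 2 8 4, then 3 26 6
```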

How to obtain more metrics, like "False omission rate" or "Accuracy"?

Many other metrics, like "False omission rate" or "Accuracy", can be derived from the confusion matrix. To calculate more metrics or design custom metrics, use TPTNFPFN=True to return the number of voxels/pixels of true positive (TP), true negative (TN), false positive (FP) and false negative (FN) predictions. For example,

metrics = sg.write_metrics(
                        gdth_img=gdth_img,
                        pred_img=pred_img,
                        TPTNFPFN=True) 
tp, tn, fp, fn = metrics['TP'], metrics['TN'], metrics['FP'], metrics['FN']
false_omission_rate = fn/(fn+tn)
accuracy = (tp + tn)/(tp + tn + fp + fn)

Comparison with medpy

medpy also provides functions to calculate metrics for medical images, but seg-metrics has several advantages.

  1. Faster. seg-metrics is 10 times faster at calculating distance-based metrics. This Jupyter notebook can reproduce the results.
  2. More convenient. seg-metrics calculates all the different metrics in a single function call, while medpy needs a separate function call for each metric, which costs more time and code.
  3. More powerful. seg-metrics can calculate multi-label segmentation metrics and save the results to a well-organized .csv file, while medpy only provides binary segmentation metrics. A comparison can be found in this Jupyter notebook.

If this repository helps you in any way, show your love ❤️ by putting a ⭐ on this project. I would also appreciate it if you cite the package in your publication. (Note: This package is NOT approved for clinical use and is intended for research use only.)

Citation

If you use this software anywhere, we would appreciate it if you cite the following article:

Jia, Jingnan, Marius Staring, and Berend C. Stoel. "seg-metrics: a Python package to compute segmentation metrics." medRxiv (2024): 2024-02.

@article{jia2024seg,
  title={seg-metrics: a Python package to compute segmentation metrics},
  author={Jia, Jingnan and Staring, Marius and Stoel, Berend C},
  journal={medRxiv},
  pages={2024--02},
  year={2024},
  publisher={Cold Spring Harbor Laboratory Press}
}

segmentation_metrics's People

Contributors

jingnan-jia


segmentation_metrics's Issues

Import error

Hello, I have already installed the package, but when I import it I get an error:

No module named 'seg_metrics'
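A frequent cause of this error (not confirmed in this thread) is that pip installed the package into a different Python environment from the one running the script. Note also that the distribution is installed as seg-metrics but imported as seg_metrics. The sketch below uses only the standard library (importlib.util.find_spec and sys.executable) to report which interpreter is running and whether it can see a module; the helper name check_module is hypothetical.

```python
import importlib.util
import sys


def check_module(name):
    """Report whether name is importable from the interpreter running this code."""
    spec = importlib.util.find_spec(name)
    return {
        "interpreter": sys.executable,
        "module": name,
        "found": spec is not None,
        "location": spec.origin if spec is not None else None,
    }


print(check_module("seg_metrics"))
```

If found is False, installing with that exact interpreter, i.e. running it as "<interpreter> -m pip install seg-metrics", makes the install and the import use the same environment.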

"Scan files are None, please check the data directory"

Hello Jia,

Congratulations on the package. It is fantastic!

I want to analyze some metrics (like the Dice coefficient and Hausdorff distance), but it shows the error below:

Traceback (most recent call last):
  File "/home/brainsmclab/Desktop/Luan_Doctorate/Disconnection_MyDataset/lesionmasks_reliability/segmentation_metrics/script_metrics.py", line 24, in <module>
    metrics = sg.write_metrics(labels=labels[1:],  # exclude background
  File "/home/brainsmclab/anaconda3/lib/python3.9/site-packages/seg_metrics/seg_metrics.py", line 295, in write_metrics
    gdth_names, pred_names = get_gdth_pred_names(gdth_path, pred_path)
  File "/home/brainsmclab/anaconda3/lib/python3.9/site-packages/medutils/medutils.py", line 241, in get_gdth_pred_names
    gdth_files, pred_files = get_ct_pair_filenames(gdth_path, pred_path)
  File "/home/brainsmclab/anaconda3/lib/python3.9/site-packages/medutils/medutils.py", line 160, in get_ct_pair_filenames
    gdth_files = get_all_ct_names(gdth_path)
  File "/home/brainsmclab/anaconda3/lib/python3.9/site-packages/medutils/medutils.py", line 150, in get_all_ct_names
    raise Exception(f'Scan files are None, please check the data directory: {path}')
Exception: Scan files are None, please check the data directory: ~/Desktop/Luan_Doctorate/Disconnection_MyDataset/lesionmasks_reliability/segmentation_metrics/lesoes_milene/

My files are CT lesion masks from two examiners (".nii" format; n=20 patients per examiner).

My code:

# Import packages
import medpy
import medpy.metric
import numpy as np
import seg_metrics.seg_metrics as sg
import SimpleITK as sitk
import matplotlib.pyplot as plt
import copy
from skimage import measure, morphology
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import time
import gdown
import pandas as pd


labels = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]

gdth_path = '~/Desktop/Luan_Doctorate/Disconnection_MyDataset/lesionmasks_reliability/segmentation_metrics/lesoes_milene/'

pred_path = '~/Desktop/Luan_Doctorate/Disconnection_MyDataset/lesionmasks_reliability/segmentation_metrics/lesoes_luan/'

csv_file = 'metrics.csv'

metrics = sg.write_metrics(labels=labels[1:],  # exclude background
                  gdth_path=gdth_path,
                  pred_path=pred_path,
                  csv_file=csv_file)

print(metrics)  # a list of dictionaries which includes the metrics for each pair of image.

Could you help me with this problem?

Many thanks,
Luan.
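Going by the traceback, the directory passed in still starts with ~. Python's file functions do not expand ~ to the home directory (shells do that), so the path is looked up literally and no scans are found. A likely fix, not confirmed in this thread, is to expand the path first with os.path.expanduser; the sketch below wraps that in a hypothetical resolve_dir helper. Separately, labels is meant to list the label values inside each mask (presumably [1] for a binary lesion mask) rather than one entry per patient, and labels[1:] drops the first entry of that list.

```python
import os


def resolve_dir(path):
    """Expand '~' (and $VARS), then return an absolute, normalized path."""
    return os.path.abspath(os.path.expandvars(os.path.expanduser(path)))


gdth_path = resolve_dir('~/Desktop/Luan_Doctorate/Disconnection_MyDataset/'
                        'lesionmasks_reliability/segmentation_metrics/lesoes_milene/')
pred_path = resolve_dir('~/Desktop/Luan_Doctorate/Disconnection_MyDataset/'
                        'lesionmasks_reliability/segmentation_metrics/lesoes_luan/')
# isdir is True only if the folder actually exists on this machine.
print(gdth_path, os.path.isdir(gdth_path))
# The expanded paths can then be passed to sg.write_metrics(gdth_path=..., pred_path=...).
```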

TypeError: 'type' object is not subscriptable, line 144

Dear Jingnan-Jia,

When running the code:

import seg_metrics.seg_metrics as sg

def main():
    labels = [0, 4, 5, 6, 7, 8]
    gdth_file = 'data/MyImage.nii.gz'
    pred_file = 'data/MyImage2.nii.gz'
    csv_file = 'metrics.csv'

    metrics = sg.write_metrics(labels=labels[1:],  # exclude background
                               gdth_path=gdth_file,
                               pred_path=pred_file,
                               csv_file=csv_file)
    print(metrics)

if __name__ == '__main__':
    main()

The following error is reported:

Traceback (most recent call last):
  File ".../ImageMetricEval.py", line 1, in <module>
    import seg_metrics.seg_metrics as sg
  File "...\lib\site-packages\seg_metrics\seg_metrics.py", line 144, in <module>
    def get_metrics_dict_all_labels(labels, gdth, pred, spacing, metrics_type=None) -> dict[str, list]:
TypeError: 'type' object is not subscriptable

When possible, please let me know how to correct this issue. Thank you for your support.

Best regards,
Austin
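The failing line uses the built-in generic dict[str, list] in a return annotation. Subscripting built-in types like this only works at runtime from Python 3.9 onward (PEP 585), so on Python 3.8 or earlier the function definition itself raises "TypeError: 'type' object is not subscriptable" at import time. Upgrading to Python 3.9+ avoids it. On the library side, either typing.Dict or a string annotation keeps older interpreters working; "from __future__ import annotations" (PEP 563) turns every annotation in a module into a string. A minimal sketch with hypothetical function names:

```python
from typing import Dict


def metrics_typing(labels) -> Dict[str, list]:
    # typing.Dict can be subscripted on every Python 3 version.
    return {"label": list(labels)}


def metrics_quoted(labels) -> "dict[str, list]":
    # A string annotation is not evaluated when the function is defined,
    # which is what `from __future__ import annotations` does for a whole module.
    return {"label": list(labels)}


print(metrics_typing([4, 5]), metrics_quoted.__annotations__["return"])
```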

Unit of Volume Similarity

Could you please tell me what will be the unit of Volume Similarity? Is that mm3? Do you have any reference material?
Thanks!
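Volume similarity is a ratio of two volumes, so it has no unit: if both voxel counts are multiplied by the voxel volume (in mm³), the factor cancels. The package takes the value from SimpleITK's LabelOverlapMeasuresImageFilter.GetVolumeSimilarity() (per the code quoted in the "Adding TP TN FP FN" issue below). To our understanding, ITK defines it as 2(|S| - |T|)/(|S| + |T|) for source S and target T, so it can be negative; treat that exact form as an assumption and check the ITK documentation for your version. The sketch below uses that assumed definition to show the result is the same whether volumes are counted in voxels or in mm³; the function name, counts and spacing are made up.

```python
def volume_similarity(vol_gdth, vol_pred):
    """Assumed ITK-style definition: 2 * (|S| - |T|) / (|S| + |T|).

    Both arguments must use the same unit (voxel counts, mm^3, ...);
    the unit cancels, so the result is dimensionless.
    """
    total = vol_gdth + vol_pred
    if total == 0:
        raise ValueError("both segmentations are empty")
    return 2 * (vol_gdth - vol_pred) / total


n_gdth, n_pred = 1200, 1000   # made-up voxel counts
voxel_mm3 = 0.8 * 0.8 * 2.5   # made-up spacing of 0.8 x 0.8 x 2.5 mm
vs_voxels = volume_similarity(n_gdth, n_pred)
vs_mm3 = volume_similarity(n_gdth * voxel_mm3, n_pred * voxel_mm3)
print(vs_voxels, vs_mm3)  # equal up to floating-point rounding
```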

Type error issue when using this package with Pytest

https://github.com/Jingnan-Jia/segmentation_metrics/blob/5387ddb07c31f65c0617ecb6ae88f8170cc305fa/seg_metrics/seg_metrics.py#L213C9-L213C66

There is a Type error issue when using this package with Pytest:

venv/lib/python3.10/site-packages/seg_metrics/seg_metrics.py:355: in write_metrics
metrics_dict_all_labels = get_metrics_dict_all_labels(labels, gdth, pred, spacing=gdth_spacing[::-1],
venv/lib/python3.10/site-packages/seg_metrics/seg_metrics.py:213: in get_metrics_dict_all_labels
logging.info('\nstart to get metrics for label: ', label)
/usr/local/lib/python3.10/logging/__init__.py:2138: in info
root.info(msg, *args, **kwargs)
/usr/local/lib/python3.10/logging/__init__.py:1477: in info
self._log(INFO, msg, args, **kwargs)
/usr/local/lib/python3.10/logging/__init__.py:1624: in _log
self.handle(record)
/usr/local/lib/python3.10/logging/__init__.py:1634: in handle
self.callHandlers(record)
/usr/local/lib/python3.10/logging/__init__.py:1696: in callHandlers
hdlr.handle(record)
/usr/local/lib/python3.10/logging/__init__.py:968: in handle
self.emit(record)
/usr/local/lib/python3.10/site-packages/_pytest/logging.py:342: in emit
super().emit(record)
/usr/local/lib/python3.10/logging/__init__.py:1108: in emit
self.handleError(record)
/usr/local/lib/python3.10/logging/__init__.py:1100: in emit
msg = self.format(record)
/usr/local/lib/python3.10/logging/__init__.py:943: in format
return fmt.format(record)
/usr/local/lib/python3.10/site-packages/_pytest/logging.py:113: in format
return super().format(record)
/usr/local/lib/python3.10/logging/__init__.py:678: in format
record.message = record.getMessage()

def getMessage(self):
    """
    Return the message for this LogRecord.

    Return the message for this LogRecord after merging any user-supplied
    arguments with the message.
    """
    msg = str(self.msg)
    if self.args:
      msg = msg % self.args

E TypeError: not all arguments converted during string formatting

The fix is to change this line to:
logging.info(f'\nstart to get metrics for label: {label}')
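The failure comes from how logging builds messages: extra positional arguments are merged into the message with the % operator, and the message '\nstart to get metrics for label: ' has no %s placeholder for label. In a normal run the logging module reports such formatting errors to stderr through Handler.handleError and carries on, so the bug can go unnoticed, while pytest's log capture surfaces it as an error, as the traceback above shows. Either the f-string from this issue or a %s placeholder (which defers formatting until the record is actually emitted) fixes it. The sketch below reproduces just the formatting step with the standard library's logging.LogRecord.getMessage.

```python
import logging

label = 4

# The buggy call: one extra argument but no placeholder in the message.
bad = logging.LogRecord("demo", logging.INFO, "demo.py", 0,
                        "\nstart to get metrics for label: ", (label,), None)
try:
    bad.getMessage()
except TypeError as exc:
    print("TypeError:", exc)  # not all arguments converted during string formatting

# Fix 1: a %s placeholder; formatting is deferred until the record is emitted.
good = logging.LogRecord("demo", logging.INFO, "demo.py", 0,
                         "\nstart to get metrics for label: %s", (label,), None)
print(repr(good.getMessage()))

# Fix 2 (as proposed in the issue): build the message yourself.
logging.info(f"\nstart to get metrics for label: {label}")
```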

Open-source collaboration

Hi, nice to e-meet you!

My name is Adrian, and I'm an open-source maker. I just created Apollo, an open-source AI model aggregation library for Python.

I'm reaching out because I've come across your portfolio on github and seen that you've worked with APIs, and AI models before.

How easy was it working or interfacing with ai models via API? Have you deployed or used ai models in more than a single channel in other projects? Looking to add support for a full API to manage in-app integrations, user subscription management, preferences and etc... And looking for feedback from people who had experience with some of those topics previously.

BTW, we are looking for new contributors if you'll find it interesting :)

Have an awesome day!

Code is not working for different chronology in labels

Hi,
This algorithm works well if the manual and automated segmentation files use the same label for each segment. That means "segment_1" of the manual segmentation should match "segment_1" of the automated segmentation, and "segment_2" of the manual segmentation should match "segment_2" of the automated segmentation. The code does not work if the manual and automated segmentations have different label values. In my case, "segment_1" of the automated segmentation matches "segment_4" of the manual segmentation. Similarly, "segment_2" of the automated segmentation matches "segment_3" of the manual segmentation. Is there any way to fix this issue?
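The package compares the same label value in both images, so one workaround is to relabel the automated segmentation before evaluation so that its values follow the manual (ground-truth) convention. The sketch below assumes the correspondence is known in advance and that "segment_k" is stored as label value k, which is a guess: here automated 1 maps to manual 4 and automated 2 to manual 3, as described in the issue. remap_labels is a hypothetical helper built on a NumPy lookup table; the remapped array could then be passed as pred_img (with labels=[3, 4]) to write_metrics.

```python
import numpy as np


def remap_labels(label_map, mapping):
    """Return a copy of an integer label map with values replaced via mapping.

    All replacements are looked up from the original values at once, so
    swaps such as {1: 4, 4: 1} work. Unlisted values are kept unchanged.
    """
    arr = np.asarray(label_map)
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError("label maps must have an integer dtype")
    if arr.size and arr.min() < 0:
        raise ValueError("label values must be non-negative")
    largest = max([int(arr.max()) if arr.size else 0, *mapping.keys()])
    lut = np.arange(largest + 1, dtype=np.int64)  # identity lookup table
    for src, dst in mapping.items():
        lut[src] = dst
    return lut[arr]


pred_img = np.array([[0, 1, 1],
                     [0, 2, 2]])
pred_aligned = remap_labels(pred_img, {1: 4, 2: 3})  # hypothetical correspondence
print(pred_aligned)
```

A lookup table avoids the classic pitfall of sequential in-place replacement, where replacing 1 with 4 and then 4 with 1 would undo the first step.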

Adding TP TN FP FN

import copy
import os
from typing import Dict, Union, Optional, Sequence
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pathlib
from medutils.medutils import load_itk, get_gdth_pred_names, one_hot_encode_3d
import logging
from tqdm import tqdm

__all__ = ["write_metrics"]


def show_itk(img: sitk.SimpleITK.Image, idx: int) -> None:
    """Show a 2D slice of 3D ITK image.

    :param img: ITK image
    :param idx: index of 2D slice
    """
    ref_surface_array = sitk.GetArrayViewFromImage(img)
    plt.figure()
    plt.imshow(ref_surface_array[idx])
    plt.show()

    return None


def computeQualityMeasures(lP: np.ndarray,
                           lT: np.ndarray,
                           spacing: np.ndarray,
                           metrics_names: Union[Sequence, set, None] = None,
                           fullyConnected=True):
    """

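    :param lP: prediction, shape (x, y, z)
    :param lT: ground truth, shape (x, y, z)
    :param spacing: shape order (x, y, z)
    :param metrics_names: container of metric names to compute; all metrics if None
    :param fullyConnected: whether to use a fully connected border for surface distances
    :return: quality: a dict mapping metric names to values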
    """
    quality = {}
    labelPred = sitk.GetImageFromArray(lP, isVector=False)
    labelPred.SetSpacing(np.array(spacing).astype(np.float64))
    labelTrue = sitk.GetImageFromArray(lT, isVector=False)
    labelTrue.SetSpacing(np.array(spacing).astype(np.float64))  # spacing order (x, y, z)

    voxel_metrics = ['dice', 'jaccard', 'precision', 'recall', 'fpr', 'fnr', 'vs']
    distance_metrics = ['hd', 'hd95', 'msd', 'mdsd', 'stdsd']
    if metrics_names is None:
        metrics_names = {'dice', 'jaccard', 'precision', 'recall', 'fpr', 'fnr', 'vs', 'hd', 'hd95', 'msd', 'mdsd',
                         'stdsd'}
    else:
        metrics_names = set(metrics_names)

    # to save time, we need to determine which metrics we need to compute
    if set(voxel_metrics).intersection(metrics_names) or not metrics_names:
        pred = lP.astype(int)  # float data does not support bit_and and bit_or
        gdth = lT.astype(int)  # float data does not support bit_and and bit_or
        fp_array = copy.deepcopy(pred)  # keep pred unchanged
        fn_array = copy.deepcopy(gdth)
        gdth_sum = np.sum(gdth)
        pred_sum = np.sum(pred)
        intersection = gdth & pred
        union = gdth | pred
        intersection_sum = np.count_nonzero(intersection)
        union_sum = np.count_nonzero(union)

        tp_array = intersection

        tmp = pred - gdth
        fp_array[tmp < 1] = 0

        tmp2 = gdth - pred
        fn_array[tmp2 < 1] = 0

        tn_array = np.ones(gdth.shape) - union

        tp, fp, fn, tn = np.sum(tp_array), np.sum(fp_array), np.sum(fn_array), np.sum(tn_array)

        smooth = 0.001
        precision = tp / (pred_sum + smooth)
        recall = tp / (gdth_sum + smooth)

        fpr = fp / (fp + tn + smooth)
        fnr = fn / (fn + tp + smooth)

        jaccard = intersection_sum / (union_sum + smooth)
        dice = 2 * intersection_sum / (gdth_sum + pred_sum + smooth)

        dicecomputer = sitk.LabelOverlapMeasuresImageFilter()
        dicecomputer.Execute(labelTrue > 0.5, labelPred > 0.5)

        quality["dice"] = dice
        quality["jaccard"] = jaccard
        quality["precision"] = precision
        quality["recall"] = recall
        quality["fnr"] = fnr
        quality["fpr"] = fpr
        quality["vs"] = dicecomputer.GetVolumeSimilarity()
        quality["truepositive"] = tp
        quality["truenegative"] = tn
        quality["falsepositive"] = fp
        quality["falsenegative"] = fn
    if set(distance_metrics).intersection(metrics_names) or not metrics_names:
        slice_idx = 300
        # Surface distance measures
        signed_distance_map = sitk.SignedMaurerDistanceMap(labelTrue > 0.5, squaredDistance=False,
                                                           useImageSpacing=True)  # It need to be adapted.
        # show_itk(signed_distance_map, slice_idx)

        ref_distance_map = sitk.Abs(signed_distance_map)
        # show_itk(ref_distance_map, slice_idx)

        ref_surface = sitk.LabelContour(labelTrue > 0.5, fullyConnected=fullyConnected)
        # show_itk(ref_surface, slice_idx)
        ref_surface_array = sitk.GetArrayViewFromImage(ref_surface)

        statistics_image_filter = sitk.StatisticsImageFilter()
        statistics_image_filter.Execute(ref_surface > 0.5)

        num_ref_surface_pixels = int(statistics_image_filter.GetSum())

        signed_distance_map_pred = sitk.SignedMaurerDistanceMap(labelPred > 0.5, squaredDistance=False,
                                                                useImageSpacing=True)
        # show_itk(signed_distance_map_pred, slice_idx)

        seg_distance_map = sitk.Abs(signed_distance_map_pred)
        # show_itk(seg_distance_map, slice_idx)

        seg_surface = sitk.LabelContour(labelPred > 0.5, fullyConnected=fullyConnected)
        # show_itk(seg_surface, slice_idx)
        seg_surface_array = sitk.GetArrayViewFromImage(seg_surface)

        seg2ref_distance_map = ref_distance_map * sitk.Cast(seg_surface, sitk.sitkFloat32)
        # show_itk(seg2ref_distance_map, slice_idx)

        ref2seg_distance_map = seg_distance_map * sitk.Cast(ref_surface, sitk.sitkFloat32)
        # show_itk(ref2seg_distance_map, slice_idx)

        statistics_image_filter.Execute(seg_surface > 0.5)

        num_seg_surface_pixels = int(statistics_image_filter.GetSum())

        seg2ref_distance_map_arr = sitk.GetArrayViewFromImage(seg2ref_distance_map)
        seg2ref_distances = list(seg2ref_distance_map_arr[seg2ref_distance_map_arr != 0])
        seg2ref_distances = seg2ref_distances + list(np.zeros(num_seg_surface_pixels - len(seg2ref_distances)))
        ref2seg_distance_map_arr = sitk.GetArrayViewFromImage(ref2seg_distance_map)
        ref2seg_distances = list(ref2seg_distance_map_arr[ref2seg_distance_map_arr != 0])
        ref2seg_distances = ref2seg_distances + list(np.zeros(num_ref_surface_pixels - len(ref2seg_distances)))  #

        all_surface_distances = seg2ref_distances + ref2seg_distances
        quality["msd"] = np.mean(all_surface_distances)
        quality["mdsd"] = np.median(all_surface_distances)
        quality["stdsd"] = np.std(all_surface_distances)
        quality["hd95"] = np.percentile(all_surface_distances, 95)
        quality["hd"] = np.max(all_surface_distances)
    return quality


def get_metrics_dict_all_labels(labels: Sequence,
                                gdth: np.ndarray,
                                pred: np.ndarray,
                                spacing: np.ndarray,
                                metrics_names: Union[Sequence, set, None] = None,
                                fullyConnected: bool = True) -> Dict[str, list]:
    """

    :param labels: not include background, e.g. [4,5,6,7,8] or [1]
    :param gdth: shape: (x, y, z, channels), channels is equal to len(labels) or equal to len(labels)+1 (background)
    :param pred: the same as above
    :param spacing: spacing order should be (x, y, z) !!!
    :param metrics_names: a list of metrics
    :param fullyConnected: if apply fully connected border during the calculation of surface distance.
    :return: metrics_dict_all_labels a dict which contain all metrics
    """
    if metrics_names is None:
        metrics_names = {'dice', 'jaccard', 'precision', 'recall', 'fpr', 'fnr', 'vs', 'hd', 'hd95', 'msd', 'mdsd',
                         'stdsd', 'truepositive', 'truenegative', 'falsepositive', 'falsenegative'}
    if type(metrics_names) is str:
        metrics_names = [metrics_names]
    hd_list = []
    dice_list = []
    jaccard_list = []
    vs_list = []
    msd_list = []
    mdsd_list = []
    stdsd_list = []
    hd95_list = []
    precision_list = []
    recall_list = []
    fpr_list = []
    fnr_list = []
    truepositive = []
    truenegative = []
    falsepositive = []
    falsenegative = []

    label_list = [lb for lb in labels]

    metrics_dict_all_labels = {'label': label_list,
                               'dice': dice_list,
                               'jaccard': jaccard_list,
                               'precision': precision_list,
                               'recall': recall_list,
                               'fpr': fpr_list,
                               'fnr': fnr_list,
                               'vs': vs_list,
                               'hd': hd_list,
                               'msd': msd_list,
                               'mdsd': mdsd_list,
                               'stdsd': stdsd_list,
                               'hd95': hd95_list,
                               'truepositive': truepositive,
                               'truenegative': truenegative,
                               'falsepositive': falsepositive,
                               'falsenegative': falsenegative}

    for i, label in enumerate(labels):
        logging.info(f'\nstart to get metrics for label: {label}')
        pred_per = pred[..., i]  # select one label
        gdth_per = gdth[..., i]
        metrics = computeQualityMeasures(pred_per, gdth_per,
                                         spacing=spacing,
                                         metrics_names=metrics_names,
                                         fullyConnected=fullyConnected)

        for k, v in metrics_dict_all_labels.items():
            if k in metrics_names:
                v.append(metrics[k])

    metrics_dict = {k: v for k, v in metrics_dict_all_labels.items() if v}  # remove empty values

    return metrics_dict


def type_check(gdth_path: Union[str, pathlib.Path, Sequence, None],
               pred_path: Union[str, pathlib.Path, Sequence, None],
               gdth_img: Union[np.ndarray, sitk.SimpleITK.Image, Sequence, None],
               pred_img: Union[np.ndarray, sitk.SimpleITK.Image, Sequence, None]) -> None:
    if type(gdth_img) is not type(pred_img):  # gdth and pred should have the same type
        raise Exception(f"gdth_array is {type(gdth_img)} but pred_array is {type(pred_img)}. "
                        f"They should be the same type.")
    if type(gdth_path) is not type(pred_path):  # gdth_path and pred_path should have the same type
        raise Exception(f"gdth_path is {type(gdth_path)} but pred_path is {type(pred_path)}. "
                        f"They should be the same type.")
    if type(gdth_path) is type(gdth_img):
        raise Exception(f"gdth_path is {type(gdth_path)} and gdth_img is {type(gdth_img)}. "
                        f"Exactly one of them should be None, and the other should be assigned values.")

    assert any(isinstance(gdth_path, tp) for tp in [str, pathlib.Path, Sequence, type(None)])
    assert any(isinstance(gdth_img, tp) for tp in [np.ndarray, sitk.SimpleITK.Image, Sequence, type(None)])

    if isinstance(gdth_path, Sequence):
        assert any(isinstance(gdth_path, tp) for tp in [str, pathlib.Path])
    if isinstance(gdth_img, Sequence):
        if type(gdth_img[0]) not in [np.ndarray, sitk.SimpleITK.Image]:
            raise Exception(
                f"gdth_img[0]'s type should be ndarray or SimpleITK.SimpleITK.Image, but got {type(gdth_img[0])}")


def write_metrics(labels: Sequence,
                  gdth_path: Union[str, pathlib.Path, Sequence, None] = None,
                  pred_path: Union[str, pathlib.Path, Sequence, None] = None,
                  csv_file: Union[str, pathlib.Path, None] = None,
                  gdth_img: Union[np.ndarray, sitk.SimpleITK.Image, Sequence, None] = None,
                  pred_img: Union[np.ndarray, sitk.SimpleITK.Image, Sequence, None] = None,
                  metrics: Union[Sequence, set, None] = None,
                  verbose: bool = True,
                  spacing: Union[Sequence, np.ndarray, None] = None,
                  fully_connected=True):
    """

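    :param labels: labels to evaluate, excluding background
    :param gdth_path: an absolute directory path or file name
    :param pred_path: an absolute directory path or file name
    :param gdth_img: np.ndarray or SimpleITK.Image for ground truth
    :param pred_img: np.ndarray or SimpleITK.Image for prediction
    :param csv_file: filename to save the metrics
    :param metrics: a metric name or a list of metric names to compute; all metrics if None
    :param verbose: whether to show a progress bar
    :param spacing: spacing for np.ndarray inputs; 1.0 for every dimension if None
    :param fully_connected: whether to use a fully connected border for surface distances
    :return: metrics: a sequence which saves the metrics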
    """
    type_check(gdth_path, pred_path, gdth_img, pred_img)
    logging.info('start to calculate metrics (volume or distance) and write them to csv')
    output_list = []
    metrics_dict_all_labels = None
    if metrics is None:
        metrics = ['dice', 'jaccard', 'precision', 'recall', 'fpr', 'fnr', 'vs', 'hd', 'hd95', 'msd', 'mdsd', 'stdsd', 'truepositive', 'truenegative', 'falsepositive', 'falsenegative']

    if gdth_path is not None:
        if os.path.isfile(gdth_path):  # gdth is a file instead of a directory
            gdth_names, pred_names = [gdth_path], [pred_path]
        else:
            gdth_names, pred_names = get_gdth_pred_names(gdth_path, pred_path)
        with tqdm(zip(gdth_names, pred_names), disable=not verbose) as pbar:
            for gdth_name, pred_name in pbar:
                pbar.set_description(f'Process {os.path.basename(pred_name)} ...')
                gdth, gdth_origin, gdth_spacing = load_itk(gdth_name, require_ori_sp=True)
                pred, pred_origin, pred_spacing = load_itk(pred_name, require_ori_sp=True)

                gdth = one_hot_encode_3d(gdth, labels=labels)
                pred = one_hot_encode_3d(pred, labels=labels)
                metrics_dict_all_labels = get_metrics_dict_all_labels(labels, gdth, pred, spacing=gdth_spacing[::-1],
                                                                      metrics_names=metrics, fullyConnected=fully_connected)
                metrics_dict_all_labels['filename'] = pred_name  # add a new key to the metrics

                if csv_file:
                    data_frame = pd.DataFrame(metrics_dict_all_labels)
                    data_frame.to_csv(csv_file, mode='a', header=not os.path.exists(csv_file), index=False)
                output_list.append(metrics_dict_all_labels)

    if gdth_img is not None:
        if type(gdth_img) in [sitk.SimpleITK.Image, np.ndarray]:  # a single image instead of a list of images
            gdth_img, pred_img = [gdth_img], [pred_img]
        with tqdm(zip(gdth_img, pred_img), disable=not verbose) as pbar:
            img_id = 0
            for gdth, pred in pbar:
                img_id += 1
                if type(gdth) not in [sitk.SimpleITK.Image, np.ndarray]:
                    raise TypeError(f"image type should be sitk.SimpleITK.Image or np.ndarray, but is {type(gdth)}")
                if isinstance(gdth, sitk.SimpleITK.Image):
                    gdth_array = sitk.GetArrayFromImage(gdth)
                    pred_array = sitk.GetArrayFromImage(pred)

                    gdth_spacing = np.array(list(reversed(gdth.GetSpacing())))  # after reversing, spacing = (z, y, x)
                    pred_spacing = np.array(list(reversed(pred.GetSpacing())))  # after reversing, spacing = (z, y, x)
                    assert all(gdth_spacing == pred_spacing)
                    gdth_orientation = gdth.GetDirection()
                    if gdth_orientation[-1] == -1:
                        gdth_array = gdth_array[::-1]
                    pred_orientation = pred.GetDirection()
                    if pred_orientation[-1] == -1:
                        pred_array = pred_array[::-1]

                    gdth = gdth_array
                    pred = pred_array
                else:  # numpy.ndarray
                    if spacing is None:
                        if gdth.ndim == 2:
                            gdth_spacing = np.array([1., 1.])  # spacing should be double
                        elif gdth.ndim == 3:
                            gdth_spacing = np.array([1., 1., 1.])  # spacing should be double
                        else:
                            raise Exception(f"The dimension of gdth should be 2 or 3, but it is {gdth.ndim}")
                    else:
                        gdth_spacing = np.array(spacing).astype(np.float64)
                        if len(gdth_spacing) not in (2, 3):
                            raise Exception(f"The length of spacing should be 2 or 3, but the spacing is {gdth_spacing} "
                                            f"with length of {len(gdth_spacing)}")


                gdth = one_hot_encode_3d(gdth, labels=labels)
                pred = one_hot_encode_3d(pred, labels=labels)
                metrics_dict_all_labels = get_metrics_dict_all_labels(labels, gdth, pred, spacing=gdth_spacing[::-1],
                                                                      metrics_names=metrics, fullyConnected=fully_connected)
                # metrics_dict_all_labels['image_number'] = img_id  # add a new key to the metrics

                if csv_file:
                    data_frame = pd.DataFrame(metrics_dict_all_labels)
                    data_frame.to_csv(csv_file, mode='a', header=not os.path.exists(csv_file), index=False)
                output_list.append(metrics_dict_all_labels)
    if csv_file:
        # Use a %s placeholder; passing csv_file without one raises a logging formatting error.
        logging.info('Metrics were saved at: %s', csv_file)

    if metrics_dict_all_labels is None:
        if gdth_path is not None:
            raise Exception(f"The metrics are None, because no files were detected in folder: {gdth_path} or folder: {pred_path}")
        # if gdth_img is not None:
        #     raise Exception(f"The metrics are None,because give image is None")
    if len(output_list)==0:
        return metrics_dict_all_labels
    else:
        return output_list


def main():
    labels = [0, 4, 5, 6, 7, 8]
    gdth_path = 'data/gdth'
    pred_path = 'data/pred'
    csv_file = 'metrics.csv'

    write_metrics(labels=labels[1:],  # exclude background
                  gdth_path=gdth_path,
                  pred_path=pred_path,
                  csv_file=csv_file)


if __name__ == "__main__":
    main()
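The CSV output in write_metrics appends each image's results with mode='a' and writes the header only when the file does not exist yet. One consequence is that rerunning on an existing metrics.csv adds new rows after the old ones instead of overwriting them. The sketch below reproduces that append-with-header-once pattern with the standard-library csv module rather than pandas; append_rows is a hypothetical helper, and the row values are dummy placeholders, not real results.

```python
import csv
import os
import tempfile


def append_rows(csv_file, rows):
    """Append dict rows to csv_file, writing the header only when the file is new.

    Hypothetical stdlib analogue of DataFrame.to_csv(csv_file, mode='a',
    header=not os.path.exists(csv_file), index=False).
    """
    write_header = not os.path.exists(csv_file)
    with open(csv_file, "a", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=list(rows[0]))
        if write_header:
            writer.writeheader()
        writer.writerows(rows)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "metrics.csv")
    # Two "images" processed one after another; values are dummy placeholders.
    append_rows(path, [{"label": 1, "dice": 0.5, "filename": "img1.nii.gz"}])
    append_rows(path, [{"label": 1, "dice": 0.25, "filename": "img2.nii.gz"}])
    with open(path, newline="") as fh:
        lines = fh.read().splitlines()
    print(lines)
```

Deleting the CSV file before a new run is the simplest way to avoid mixing results from different runs.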

Logging issues

There is an issue with logging in "sg.write_metrics" (seg-metrics==1.1.3):

--- Logging error ---
Traceback (most recent call last):
  File "/usr/lib/python3.9/logging/__init__.py", line 1083, in emit
    msg = self.format(record)
  File "/usr/lib/python3.9/logging/__init__.py", line 927, in format
    return fmt.format(record)
  File "/usr/lib/python3.9/logging/__init__.py", line 663, in format
    record.message = record.getMessage()
  File "/usr/lib/python3.9/logging/__init__.py", line 367, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
...
    sg.write_metrics(labels=[1], gdth_img=gt, pred_img=pred, verbose=False,
  File "/home/nmakarov/sg/lib/python3.9/site-packages/seg_metrics/seg_metrics.py", line 355, in write_metrics
    metrics_dict_all_labels = get_metrics_dict_all_labels(labels, gdth, pred, spacing=gdth_spacing[::-1],
  File "/home/nmakarov/sg/lib/python3.9/site-packages/seg_metrics/seg_metrics.py", line 213, in get_metrics_dict_all_labels
    logging.info('\nstart to get metrics for label: ', label)
Message: '\nstart to get metrics for label: '
Arguments: (1,)
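The traceback points at the cause: logging.info('\nstart to get metrics for label: ', label) passes label as a %-format argument, but the message contains no placeholder, so LogRecord.getMessage() fails when a handler formats the record. The sketch below reproduces the failure and the usual fix by building a logging.LogRecord by hand, so no handler setup is needed; render is a hypothetical helper name used only for illustration.

```python
import logging


def render(msg, *args):
    """Format a message the same way a logging handler does (LogRecord.getMessage)."""
    record = logging.LogRecord("seg_metrics_demo", logging.INFO, "demo.py", 0, msg, args, None)
    return record.getMessage()


label = 1

# Pattern from the traceback: one extra argument but no %-placeholder.
try:
    render('\nstart to get metrics for label: ', label)
except TypeError as exc:
    print(f"TypeError: {exc}")

# Fix: one %s placeholder per argument, so formatting stays lazy inside logging.
print(render('start to get metrics for label: %s', label))
```

Because logging.info only builds and formats a record when INFO is enabled for the logger, the error appears only after INFO-level logging is turned on, and it is printed by the handler ("--- Logging error ---") rather than raised to the caller.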

Computing values on ground truth does not give perfect scores

Following one of the examples in the README file, I ran a quick test using the ground truth image as the predicted image:

import numpy as np
import seg_metrics.seg_metrics as sg

labels = [0, 1, 2]
gdth_img = np.array([[0, 0, 1], [0, 1, 2]])
metrics = sg.write_metrics(labels=labels[1:], gdth_img=gdth_img, pred_img=gdth_img)

This does not give perfect scores: the Dice score, Jaccard index, precision and recall are not 1.0. Although they are close (0.999), they should be exactly 1.0.
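One plausible cause, not confirmed against the seg-metrics source, is a small smoothing constant added to each denominator to avoid division by zero when a label is absent. The sketch below computes the four overlap scores directly from boolean masks with numpy; overlap_scores and its eps parameter are hypothetical names for illustration. With eps=0, identical inputs give exactly 1.0, while any positive eps pulls every score slightly below 1.

```python
import numpy as np


def overlap_scores(gdth, pred, label, eps=0.0):
    """Voxel-overlap scores for one label; eps is a hypothetical smoothing term."""
    g = gdth == label
    p = pred == label
    tp = np.sum(g & p)    # voxels labelled `label` in both images
    fp = np.sum(~g & p)   # labelled only in the prediction
    fn = np.sum(g & ~p)   # labelled only in the ground truth
    # With eps=0 these divide by zero if the label is absent from both images.
    return {
        "dice": 2 * tp / (2 * tp + fp + fn + eps),
        "jaccard": tp / (tp + fp + fn + eps),
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
    }


gdth_img = np.array([[0, 0, 1], [0, 1, 2]])
for label in (1, 2):
    exact = overlap_scores(gdth_img, gdth_img, label)
    smoothed = overlap_scores(gdth_img, gdth_img, label, eps=1e-3)
    print(label, exact["dice"], smoothed["dice"])
```

A library that needs exact scores for identical inputs while still handling empty labels could special-case masks with tp + fp + fn == 0 instead of adding a constant to every denominator.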

Add .png suffix

Hi, is it possible to add support for the .png suffix, please?
Thanks a lot
