ca-net's Introduction

CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation

This repository provides the code for "CA-Net: Comprehensive Attention Convolutional Neural Networks for Explainable Medical Image Segmentation". Our work is now available on arXiv.

Fig. 1. Structure of CA-Net.

Fig. 2. Skin lesion segmentation.

Fig. 3. Placenta and fetal brain segmentation.

Requirements

Some important required packages include:

  • PyTorch version >= 0.4.1.
  • Visdom
  • Python == 3.7
  • Some basic Python packages such as NumPy.

Follow the official guidance to install PyTorch.
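A quick way to check an environment against these requirements is sketched below. The helper names are hypothetical, and the version parser is deliberately simple: it drops build tags such as `+cu101` and ignores pre-release suffixes. `packaging.version` is the more robust choice if it is installed.

```python
import re
import sys


def parse_version(text):
    """Turn a version string such as '1.6.0+cu101' into a (major, minor, patch) tuple."""
    public = text.split("+")[0]                      # drop local build tags like '+cu101'
    nums = [int(n) for n in re.findall(r"\d+", public)[:3]]
    return tuple(nums + [0] * (3 - len(nums)))       # pad '1.6' to (1, 6, 0)


def meets_requirements(torch_version, python_version=sys.version_info[:2]):
    """Check the stated requirements: PyTorch >= 0.4.1 and Python 3.7."""
    torch_ok = parse_version(torch_version) >= (0, 4, 1)
    python_ok = tuple(python_version) == (3, 7)
    return torch_ok, python_ok
```

In a real environment you would pass `torch.__version__`, e.g. `meets_requirements(torch.__version__)`.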

Usage

For skin lesion segmentation

  1. First, download the ISIC 2018 dataset. We only used the ISIC 2018 Task 1 training set. To preprocess the dataset and save it as ".npy" files, run:
python isic_preprocess.py 
  2. To conduct 5-fold cross-validation, split the preprocessed data into 5 folds and save their filenames by running:
python create_folder.py
  3. To train CA-Net on ISIC 2018 (taking the 1st-fold validation as an example), run:
python main.py --data ISIC2018 --val_folder folder1 --id Comp_Atten_Unet
  4. To evaluate the trained model on ISIC 2018 (we added test data in folder0; here the 0th-fold validation is tested as an example), run:
python validation.py --data ISIC2018 --val_folder folder0 --id Comp_Atten_Unet
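The cross-validation steps above can be sketched as follows. This is not the repository's `create_folder.py`: the seeded shuffle, the per-fold list files (`folder1.list` to `folder5.list`), and all helper names are assumptions. Only the `main.py` flags (`--data`, `--val_folder`, `--id`) and the `folderK` naming come from this README.

```python
import os
import random
import subprocess


def split_into_folds(filenames, n_folds=5, seed=0):
    """Shuffle the filenames reproducibly and deal them into n_folds folds."""
    names = sorted(filenames)                # deterministic starting order
    random.Random(seed).shuffle(names)       # seeded shuffle: same split on every run
    return [names[k::n_folds] for k in range(n_folds)]


def write_fold_lists(folds, out_dir):
    """Save each fold's filenames to <out_dir>/folder<k>.list (hypothetical layout)."""
    os.makedirs(out_dir, exist_ok=True)
    for k, fold in enumerate(folds, start=1):
        with open(os.path.join(out_dir, f"folder{k}.list"), "w") as f:
            f.write("\n".join(fold) + "\n")


def fold_commands(n_folds=5, data="ISIC2018", model_id="Comp_Atten_Unet"):
    """Build one main.py training command per validation fold (folder1 ... folderN)."""
    return [
        ["python", "main.py", "--data", data,
         "--val_folder", f"folder{k}", "--id", model_id]
        for k in range(1, n_folds + 1)
    ]


def run_cross_validation(n_folds=5, dry_run=True):
    """Print every fold's command; actually run them only when dry_run is False."""
    for cmd in fold_commands(n_folds):
        print(" ".join(cmd))                 # echo the exact command line
        if not dry_run:
            subprocess.run(cmd, check=True)  # raise if training on a fold fails
```

Dealing the shuffled list with `names[k::n_folds]` keeps fold sizes within one file of each other, and the fixed seed makes the split reproducible across runs.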

Our experimental results are shown in the table.

  5. You can save the attention weight maps from the intermediate steps of the network to the '/result' folder. To visualize the attention weights overlaid on the original images, run:
python show_fused_heatmap.py
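The kind of overlay described above can be approximated as follows. This is a hedged sketch, not the repository's `show_fused_heatmap.py`. It assumes an RGB `uint8` image (OpenCV loads BGR, so swap channels first), a 2-D attention map whose height and width evenly divide the image's, nearest-neighbour upsampling, and a simple blue-to-red colour ramp in place of a Matplotlib or OpenCV colormap. The function name `fuse_heatmap` and the `alpha` blending weight are hypothetical.

```python
import numpy as np


def fuse_heatmap(image, attention, alpha=0.4):
    """Overlay a 2-D attention map on an RGB image and return a uint8 RGB array.

    Assumes image is (H, W, 3) uint8 and attention is (h, w) with H % h == 0 and W % w == 0.
    """
    img = np.asarray(image, dtype=np.float32) / 255.0
    att = np.asarray(attention, dtype=np.float32)
    H, W = img.shape[:2]
    h, w = att.shape
    if H % h or W % w:
        raise ValueError("attention size must evenly divide the image size")
    # Min-max normalise the weights to [0, 1] so the colours span the full range
    att = (att - att.min()) / (att.max() - att.min() + 1e-8)
    # Nearest-neighbour upsampling to the image resolution
    att = np.repeat(np.repeat(att, H // h, axis=0), W // w, axis=1)
    # Blue (low attention) to red (high attention) colour ramp, RGB order
    heat = np.stack([att, np.zeros_like(att), 1.0 - att], axis=-1)
    fused = (1.0 - alpha) * img + alpha * heat
    return np.round(np.clip(fused, 0.0, 1.0) * 255).astype(np.uint8)
```

A larger `alpha` makes the attention colours dominate; smoother results would need bilinear upsampling instead of `np.repeat`.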

Visualization of the spatial attention weight map.

Visualization of the scale attention weight map.

Acknowledgement

Part of the code is adapted from Attention-Gate-Networks.

ca-net's People

Contributors

joegue

ca-net's Issues

About the fused heatmap

Hi sir. I am so glad to have read your work.
I need your help with how to plot the fused heatmap.
Thank you.

How to compute the average symmetric surface distance (ASSD)

Hello Dear,

I have checked your code; I think it's written for PyTorch 0.4, but I am using PyTorch 1.6. I need a little of your help, and it will not take much time.

I want to compute the ASSD measure, and I think the medpy library will make this easy:

Url: medpy library

Here is my testing code. Based on this code, please let me know how I can pass result and reference so that I can compute ASSD. I would really appreciate your help!

    import os
    import cv2
    import time
    import torch
    import imageio
    import numpy as np
    from glob import glob
    from tqdm import tqdm
    from operator import add
    from sklearn.metrics import (
        jaccard_score, f1_score, recall_score, precision_score, accuracy_score, fbeta_score)
    from medpy.metric.binary import assd   # medpy's ASSD lives in medpy.metric.binary
    from parameters import *               # local module; presumably defines root_path and size
    from msu import U_Net
    from utils import create_dir, seeding, make_channel_last
    from data import load_data
    from crf import apply_crf

    def calculate_metrics(y_true, y_pred):
        # Move both tensors to the CPU as NumPy arrays
        y_true = y_true.cpu().numpy()
        y_pred = y_pred.cpu().numpy()

        # Binarize at 0.5 and flatten to 1-D label vectors for scikit-learn
        y_pred = y_pred > 0.5
        y_pred = y_pred.reshape(-1)
        y_pred = y_pred.astype(np.uint8)

        y_true = y_true > 0.5
        y_true = y_true.reshape(-1)
        y_true = y_true.astype(np.uint8)

        ## Score
        score_jaccard = jaccard_score(y_true, y_pred, average='binary')
        score_f1 = f1_score(y_true, y_pred, average='binary')
        score_recall = recall_score(y_true, y_pred, average='binary')
        score_precision = precision_score(y_true, y_pred, average='binary', zero_division=1)
        score_acc = accuracy_score(y_true, y_pred)
        # beta=2.0 so this value really is the F2 score printed below (beta=1.0 just repeats F1)
        score_fbeta = fbeta_score(y_true, y_pred, beta=2.0, average='binary', zero_division=1)

        return [score_jaccard, score_f1, score_recall, score_precision, score_acc, score_fbeta]

    if __name__ == "__main__":

        """ Load dataset """
        test_x = sorted(glob(os.path.join(root_path, "test/images", "*.png")))
        test_y = sorted(glob(os.path.join(root_path, "test/masks", "*.png")))

        """ Load the checkpoint """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = U_Net()
        model = model.to(device)
        model.load_state_dict(torch.load('checkpoint.pth', map_location=device))
        model.eval()

        """ Testing """
        metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

        for i, (x, y) in enumerate(zip(test_x, test_y)):
            name = y.split("/")[-1].split(".")[0]

            ## Image: (H, W) grayscale -> (1, 1, H, W) float tensor in [0, 1]
            image = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
            image1 = cv2.resize(image, size)
            ori_img1 = image1
            image1 = np.expand_dims(image1, axis=0)
            #image1 = np.transpose(image1, (2, 0, 1))
            image1 = image1/255.0
            image1 = np.expand_dims(image1, axis=0)
            image1 = image1.astype(np.float32)
            image1 = torch.from_numpy(image1)
            image1 = image1.to(device)

            ## Mask: same preprocessing as the image
            mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
            mask1 = cv2.resize(mask, size)
            ori_mask1 = mask1
            mask1 = np.expand_dims(mask1, axis=0)
            mask1 = mask1/255.0
            mask1 = np.expand_dims(mask1, axis=0)
            mask1 = mask1.astype(np.float32)
            mask1 = torch.from_numpy(mask1)
            mask1 = mask1.to(device)

            with torch.no_grad():
                pred_y1 = torch.sigmoid(model(image1))

                """ Evaluation metrics """
                score = calculate_metrics(mask1, pred_y1)
                metrics_score = list(map(add, metrics_score, score))

        # Average each metric over the test set
        jaccard = metrics_score[0]/len(test_x)
        f1 = metrics_score[1]/len(test_x)
        recall = metrics_score[2]/len(test_x)
        precision = metrics_score[3]/len(test_x)
        acc = metrics_score[4]/len(test_x)
        f2 = metrics_score[5]/len(test_x)

        print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f} - F2: {f2:1.4f}")
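medpy's `medpy.metric.binary.assd(result, reference, voxelspacing=None, connectivity=1)` takes two binary NumPy arrays, so in the loop above `result` would be the thresholded prediction and `reference` the thresholded mask, both moved to the CPU and squeezed to 2-D. The sketch below reimplements the same surface-distance definition with SciPy so it can be checked without medpy. It is a hedged stand-in rather than medpy's own code, and the names `surface_distances` and `assd` here are illustrative. Distances are in pixels unless a `spacing` is given.

```python
import numpy as np
from scipy import ndimage


def surface_distances(result, reference, spacing=None, connectivity=1):
    """Distances from each border pixel of `result` to the nearest border pixel of `reference`."""
    result = np.asarray(result, dtype=bool)
    reference = np.asarray(reference, dtype=bool)
    if not result.any() or not reference.any():
        raise ValueError("ASSD is undefined when either mask is empty")
    footprint = ndimage.generate_binary_structure(result.ndim, connectivity)
    # Border = object pixels removed by a single erosion step
    result_border = result ^ ndimage.binary_erosion(result, structure=footprint, iterations=1)
    reference_border = reference ^ ndimage.binary_erosion(reference, structure=footprint, iterations=1)
    # Euclidean distance from every pixel to the closest reference border pixel
    dt = ndimage.distance_transform_edt(~reference_border, sampling=spacing)
    return dt[result_border]


def assd(result, reference, spacing=None, connectivity=1):
    """Average symmetric surface distance: mean of the two directed average distances."""
    d_forward = surface_distances(result, reference, spacing, connectivity).mean()
    d_backward = surface_distances(reference, result, spacing, connectivity).mean()
    return float((d_forward + d_backward) / 2.0)
```

Inside the `with torch.no_grad():` block, one could compute `result = pred_y1.squeeze().cpu().numpy() > 0.5` and `reference = mask1.squeeze().cpu().numpy() > 0.5`, call `assd(result, reference)`, and accumulate that value separately from `metrics_score`. Images where either mask is empty need special handling, because ASSD is undefined there and medpy raises an error in that case.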
