vaakesan-SMH commented on June 15, 2024

Happy to submit a pull request. I will assign myself to this. Thanks, Franklin.

from cyclops.

fcogidi commented on June 15, 2024

@vaakesan-SMH Thanks for the feature request!

Is your request that a runway_plot function (or something similar) should be added to cyclops (maybe as a utility function in the evaluate module)?

For reference, I replaced the sklearn metrics with cyclops metrics in the function you provided:

import numpy as np
import plotly.graph_objects as go
import plotly.subplots as sp

from cyclops.evaluate.metrics.experimental.functional import binary_npv, binary_ppv
from cyclops.evaluate.metrics.functional import binary_roc_curve


def runway_plot(true_labels: np.ndarray, pred_probs: np.ndarray) -> go.Figure:
    """
    Plot diagnostic performance metrics with an additional histogram of predicted probabilities.

    The plot uses Plotly with a clean aesthetic: gridlines are kept, but the background color is removed.
    Y-axis ticks and labels are shown, and the legend is placed at the bottom.
    Tooltips show values with 3 decimal places. X-axis labels are only shown on the bottom metric subplot.
    The histogram uses small bins and has no borders.

    Args:
    - true_labels (np.ndarray): True binary class labels (0 or 1).
    - pred_probs (np.ndarray): Predicted probabilities for the positive class (1).

    Returns:
    - A Plotly figure containing the diagnostic performance plots and histogram.

    Example:
    ```
    # Generate synthetic data for demonstration
    true_labels = np.random.binomial(1, 0.5, 1000)
    pred_probs = np.random.uniform(0, 1, 1000)

    # Generate and show the faceted plot
    faceted_fig = runway_plot(true_labels, pred_probs)
    faceted_fig.show()
    ```
    """
    # Shared grid of decision thresholds for every metric curve
    thresholds = np.linspace(0, 1, 100)

    # ROC curve components evaluated on the same threshold grid, so each (x, y) pair matches.
    # Use the returned thresholds as x-values, since their order may differ from `thresholds`.
    fpr, tpr, roc_thresholds = binary_roc_curve(
        true_labels, pred_probs, thresholds=thresholds
    )

    # PPV and NPV at each threshold
    ppv = np.zeros_like(thresholds)
    npv = np.zeros_like(thresholds)
    for i, threshold in enumerate(thresholds):
        ppv[i] = binary_ppv(true_labels, pred_probs, threshold=threshold)
        npv[i] = binary_npv(true_labels, pred_probs, threshold=threshold)

    # Define hover template to show three decimal places
    hover_template = "Threshold: %{x:.3f}<br>Metric Value: %{y:.3f}<extra></extra>"

    # Create one subplot per metric, plus one for the histogram
    fig = sp.make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.02)

    # Sensitivity plot (True Positive Rate)
    fig.add_trace(
        go.Scatter(
            x=roc_thresholds,
            y=tpr,
            mode="lines",
            name="Sensitivity",
            hovertemplate=hover_template,
        ),
        row=1,
        col=1,
    )

    # Specificity plot (1 - False Positive Rate)
    fig.add_trace(
        go.Scatter(
            x=roc_thresholds,
            y=1 - fpr,
            mode="lines",
            name="Specificity",
            hovertemplate=hover_template,
        ),
        row=2,
        col=1,
    )

    # PPV plot (Positive Predictive Value)
    fig.add_trace(
        go.Scatter(
            x=thresholds, y=ppv, mode="lines", name="PPV", hovertemplate=hover_template
        ),
        row=3,
        col=1,
    )

    # NPV plot (Negative Predictive Value)
    fig.add_trace(
        go.Scatter(
            x=thresholds, y=npv, mode="lines", name="NPV", hovertemplate=hover_template
        ),
        row=4,
        col=1,
    )

    # Add histogram of predicted probabilities
    fig.add_trace(
        go.Histogram(x=pred_probs, nbinsx=80, name="Predicted Probabilities"),
        row=5,
        col=1,
    )

    # Update layout
    fig.update_layout(
        height=1000,
        width=700,
        title_text="Diagnostic Performance Metrics by Thresholds",
        legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
    )

    # Remove subplot titles
    for i in fig["layout"]["annotations"]:
        i["text"] = ""

    # Keep gridlines and show y-axis tick labels on every subplot
    fig.update_xaxes(showgrid=True)
    fig.update_yaxes(showgrid=True, showticklabels=True)

    # Frame each subplot and show x-axis tick labels on the bottom metric subplot (row 4)
    fig.update_xaxes(showline=True, linewidth=1, linecolor="black", mirror=True)
    fig.update_xaxes(showticklabels=True, row=4, col=1)
    fig.update_yaxes(showline=True, linewidth=1, linecolor="black", mirror=True)

    # No frame or x-axis tick labels on the histogram
    fig.update_xaxes(showline=False, row=5, col=1, showticklabels=False)
    fig.update_yaxes(showline=False, row=5, col=1)

    # Remove the plot background color
    fig.update_layout(plot_bgcolor="white")

    return fig


# Generate synthetic data for demonstration
true_labels = np.random.binomial(1, 0.5, 1000)
pred_probs = np.random.uniform(0, 1, 1000)

# Generate and show the modified faceted plot
faceted_fig_clean = runway_plot(true_labels, pred_probs)
faceted_fig_clean.show()
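To check the curves by hand, the four plotted quantities can be computed from confusion-matrix counts at each threshold. The sketch below is plain Python, not the cyclops API, and `runway_metrics` is a hypothetical helper. It assumes a prediction counts as positive when its probability is at least the threshold, and it returns 0.0 when a denominator is zero; both are assumptions and may differ from what cyclops does. Note that specificity is 1 - FPR.

```python
from typing import Dict, List, Sequence


def _safe_div(num: int, den: int) -> float:
    # Return 0.0 instead of raising when the denominator is empty
    return num / den if den else 0.0


def runway_metrics(
    labels: Sequence[int], probs: Sequence[float], thresholds: Sequence[float]
) -> Dict[str, List[float]]:
    """Sensitivity, specificity, PPV and NPV at each threshold (hypothetical helper)."""
    out: Dict[str, List[float]] = {
        "sensitivity": [],
        "specificity": [],
        "ppv": [],
        "npv": [],
    }
    for t in thresholds:
        tp = fp = tn = fn = 0
        for y, p in zip(labels, probs):
            pred = 1 if p >= t else 0  # assumption: >= threshold counts as positive
            if pred == 1 and y == 1:
                tp += 1
            elif pred == 1 and y == 0:
                fp += 1
            elif pred == 0 and y == 0:
                tn += 1
            else:
                fn += 1
        out["sensitivity"].append(_safe_div(tp, tp + fn))
        out["specificity"].append(_safe_div(tn, tn + fp))
        out["ppv"].append(_safe_div(tp, tp + fp))
        out["npv"].append(_safe_div(tn, tn + fn))
    return out


if __name__ == "__main__":
    metrics = runway_metrics([0, 0, 1, 1], [0.1, 0.6, 0.4, 0.9], [0.0, 0.5, 1.0])
    print(metrics)
```

Every list has exactly one entry per threshold, so each one lines up with the threshold grid on the x-axis of its subplot.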

vaakesan-SMH commented on June 15, 2024

Thanks @fcogidi. From my discussions with @amrit110, I believe this would be added as a method to the ClassificationPlotter class, although it may be better suited as a utility function in the evaluate module. I am happy to defer to you and Amrit on the best option.

fcogidi commented on June 15, 2024

I think it's better to add it to the ClassificationPlotter class.

Would you like to add it yourself? That way you'll show up as a contributor.
