Comments (4)
Happy to submit a merge request. I will assign myself to this. Thanks Franklin.
from cyclops.
@vaakesan-SMH Thanks for the feature request!
Is your request that a `runway_plot` function (or something similar) should be added to cyclops (maybe as a utility function in the `evaluate` module)?
For your information, I replaced the sklearn metrics with cyclops metrics in the method you provided:
```python
import numpy as np
import plotly.graph_objects as go
import plotly.subplots as sp
from cyclops.evaluate.metrics.experimental.functional import binary_npv, binary_ppv
from cyclops.evaluate.metrics.functional import binary_roc_curve


def runway_plot(true_labels: np.ndarray, pred_probs: np.ndarray) -> go.Figure:
    """Plot diagnostic performance metrics with a histogram of predicted probabilities.

    The plot uses Plotly with a clean aesthetic: gridlines are kept on a white
    background, y-axis ticks and labels are shown, and the legend is placed at the
    bottom. Tooltips show values with 3 decimal places. X-axis tick labels are shown
    only under the bottom metric subplot. The histogram uses 80 bins with no borders.

    Parameters
    ----------
    true_labels : np.ndarray
        True binary class labels (0 or 1).
    pred_probs : np.ndarray
        Predicted probabilities for the positive class (1).

    Returns
    -------
    go.Figure
        A Plotly figure containing the diagnostic performance plots and histogram.

    Examples
    --------
    >>> true_labels = np.random.binomial(1, 0.5, 1000)
    >>> pred_probs = np.random.uniform(0, 1, 1000)
    >>> runway_plot(true_labels, pred_probs).show()

    """
    # ROC curve components. Sensitivity and specificity are plotted against the
    # thresholds returned here so that their x and y arrays have equal length.
    fpr, tpr, roc_thresholds = binary_roc_curve(true_labels, pred_probs)

    # Evenly spaced thresholds for PPV and NPV
    thresholds = np.linspace(0, 1, 100)
    ppv = np.zeros_like(thresholds)
    npv = np.zeros_like(thresholds)
    for i, threshold in enumerate(thresholds):
        ppv[i] = binary_ppv(true_labels, pred_probs, threshold=threshold)
        npv[i] = binary_npv(true_labels, pred_probs, threshold=threshold)

    # Hover template showing three decimal places
    hover_template = "Threshold: %{x:.3f}<br>Metric Value: %{y:.3f}<extra></extra>"

    # One row per metric plus a histogram row, all sharing the threshold axis
    fig = sp.make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.02)

    # Sensitivity (true positive rate)
    fig.add_trace(
        go.Scatter(
            x=roc_thresholds,
            y=tpr,
            mode="lines",
            name="Sensitivity",
            hovertemplate=hover_template,
        ),
        row=1,
        col=1,
    )
    # Specificity (1 - false positive rate)
    fig.add_trace(
        go.Scatter(
            x=roc_thresholds,
            y=1 - fpr,
            mode="lines",
            name="Specificity",
            hovertemplate=hover_template,
        ),
        row=2,
        col=1,
    )
    # PPV (positive predictive value)
    fig.add_trace(
        go.Scatter(
            x=thresholds, y=ppv, mode="lines", name="PPV", hovertemplate=hover_template
        ),
        row=3,
        col=1,
    )
    # NPV (negative predictive value)
    fig.add_trace(
        go.Scatter(
            x=thresholds, y=npv, mode="lines", name="NPV", hovertemplate=hover_template
        ),
        row=4,
        col=1,
    )
    # Histogram of predicted probabilities
    fig.add_trace(
        go.Histogram(x=pred_probs, nbinsx=80, name="Predicted Probabilities"),
        row=5,
        col=1,
    )

    fig.update_layout(
        height=1000,
        width=700,
        title_text="Diagnostic Performance Metrics by Threshold",
        legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
        plot_bgcolor="white",  # remove the default background color
    )
    # Keep gridlines (light grey so they stay visible on white) and y-axis labels.
    # The x-range is fixed to [0, 1] because the first ROC threshold may exceed 1.
    fig.update_xaxes(showgrid=True, gridcolor="lightgrey", range=[0, 1])
    fig.update_yaxes(showgrid=True, gridcolor="lightgrey", showticklabels=True)
    # Frame every subplot with a black axis line
    fig.update_xaxes(showline=True, linewidth=1, linecolor="black", mirror=True)
    fig.update_yaxes(showline=True, linewidth=1, linecolor="black", mirror=True)
    # Show x-axis tick labels under the bottom metric subplot (NPV) ...
    fig.update_xaxes(showticklabels=True, row=4, col=1)
    # ... and hide the frame and x-axis tick labels on the histogram
    fig.update_xaxes(showline=False, showticklabels=False, row=5, col=1)
    fig.update_yaxes(showline=False, row=5, col=1)
    return fig


# Generate synthetic data for demonstration
true_labels = np.random.binomial(1, 0.5, 1000)
pred_probs = np.random.uniform(0, 1, 1000)

# Generate and show the faceted plot
faceted_fig_clean = runway_plot(true_labels, pred_probs)
faceted_fig_clean.show()
```
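To sanity-check the curves, or to compute all four metrics on one shared threshold grid, the runway metrics can also be derived directly from confusion-matrix counts. The following is a minimal pure-NumPy sketch, not cyclops' implementation: `confusion_rates` and `runway_metrics` are hypothetical helper names, it assumes a probability `>= threshold` counts as a positive prediction, and it returns 0.0 whenever a denominator is zero. Values may differ slightly from cyclops at thresholds that exactly equal a predicted probability, or where cyclops handles zero division differently.

```python
import numpy as np


def confusion_rates(
    true_labels: np.ndarray, pred_probs: np.ndarray, threshold: float
) -> dict[str, float]:
    """Sensitivity, specificity, PPV and NPV at one threshold (hypothetical helper)."""
    y_true = np.asarray(true_labels).astype(bool)
    # Assumption: a probability >= threshold is a positive prediction
    y_pred = np.asarray(pred_probs) >= threshold

    # Confusion-matrix counts
    tp = np.sum(y_true & y_pred)
    tn = np.sum(~y_true & ~y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)

    def safe_div(num, den) -> float:
        # Return 0.0 instead of dividing by zero (e.g. no predicted positives)
        return float(num / den) if den > 0 else 0.0

    return {
        "sensitivity": safe_div(tp, tp + fn),
        "specificity": safe_div(tn, tn + fp),
        "ppv": safe_div(tp, tp + fp),
        "npv": safe_div(tn, tn + fn),
    }


def runway_metrics(
    true_labels: np.ndarray, pred_probs: np.ndarray, num_thresholds: int = 100
) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """Evaluate all four metrics on the same evenly spaced threshold grid."""
    thresholds = np.linspace(0, 1, num_thresholds)
    rows = [confusion_rates(true_labels, pred_probs, t) for t in thresholds]
    # One array per metric, each aligned with `thresholds`
    metrics = {key: np.array([row[key] for row in rows]) for key in rows[0]}
    return thresholds, metrics


thresholds, metrics = runway_metrics(
    np.random.binomial(1, 0.5, 1000), np.random.uniform(0, 1, 1000)
)
```

Because every metric array is aligned with the same `thresholds`, each one could be passed directly as the `y` of a `go.Scatter` trace in `runway_plot` without any length mismatch between x and y.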
Thanks @fcogidi. From my discussions with @amrit110, I believe this would be added as a method to the `ClassificationPlotter` class, although it may be better suited as a utility function in the `evaluate` module. I am happy to defer to you and Amrit on the best option.
I think it's better to add it to the `ClassificationPlotter` class.
Would you like to add it yourself? That way you'll show up as a contributor.