Comments (5)

ericphanson commented on September 28, 2024

With the new Makie update (https://discourse.julialang.org/t/ann-makie-update-figures-and-integrated-layouts/52800), it's probably easier to do this now! Though maybe we should wait a bit first for the bugs to get worked out (maybe @SimonDanisch can help us when the time comes too 😄).

from lighthouse.jl.

hannahilea commented on September 28, 2024

Closed as of #29!

femtomc commented on September 28, 2024

Can you slightly raise the priority level for this? @SimonDanisch @hannahilea What needs to be converted? If it's a single, relatively isolated plotting function, I can try my hand at it.

ericphanson commented on September 28, 2024

I think it's these functions:

Lighthouse.jl/src/learn.jl

Lines 228 to 428 in 66cbc09

```julia
function plot_pr_curves(per_class_pr_curves, class_labels; legend=:bottomleft)
    return plot(per_class_pr_curves; labels=class_labels, title="PR curves",
                xlabel="True positive rate", xlims=(0, 1), ylabel="Precision", ylims=(0, 1),
                linewidth=1.5, framestyle=:box, legend=legend)
end

function plot_prg_curves(per_class_prg_curves, per_class_prg_aucs, class_labels;
                         legend=:bottomleft)
    auc_labels = [@sprintf("%s (AUC F1: %.3f)", class, per_class_prg_aucs[i])
                  for (i, class) in enumerate(class_labels)]
    return plot(per_class_prg_curves; labels=auc_labels, title="PR-Gain curves",
                xlabel="True positive rate gain", xlims=(0, 1), ylabel="Precision gain",
                ylims=(0, 1), linewidth=1.5, framestyle=:box, legend=legend)
end

function plot_roc_curves(per_class_roc_curves, per_class_roc_aucs, class_labels;
                         legend=:bottomright)
    auc_labels = [@sprintf("%s (AUC: %.3f)", class, per_class_roc_aucs[i])
                  for (i, class) in enumerate(class_labels)]
    return plot(per_class_roc_curves; labels=auc_labels, title="ROC curves",
                xlabel="False positive rate", xlims=(0, 1), ylabel="True positive rate",
                ylims=(0, 1), linewidth=1.5, framestyle=:box, legend=legend)
end

function plot_reliability_calibration_curves(per_class_reliability_calibration_curves,
                                             per_class_reliability_calibration_scores,
                                             class_labels; legend=:bottomright)
    calibration_score_labels = [@sprintf("%s (MSE: %.3f)", class,
                                         per_class_reliability_calibration_scores[i])
                                for (i, class) in enumerate(class_labels)]
    plot(per_class_reliability_calibration_curves; labels=calibration_score_labels,
         title="Prediction reliability calibration", xlabel="Predicted probability bin",
         xlims=(0, 1), ylabel="Fraction of positives", ylims=(0, 1), markershape=:circle,
         markersize=2, linewidth=1, markerstrokewidth=0, legendfontsize=1, legend=legend,
         framestyle=:box)
    #TODO: mean predicted value histogram underneath?? Maybe important...
    # https://scikit-learn.org/stable/modules/calibration.html
    plot!([0, 1], [0, 1]; linecolor=:black, linestyle=:dash, label="Ideal")
    return xticks!(0:0.2:1)
end

function plot_binary_discrimination_calibration_curves(calibration_curve, calibration_score,
                                                        per_expert_calibration_curves,
                                                        per_expert_calibration_scores,
                                                        optimal_threshold,
                                                        discrimination_class::AbstractString)
    p = plot(per_expert_calibration_curves; title="Detection calibration",
             xlabel="Expert agreement rate", xlims=(0, 1), xticks=(0:0.2:1),
             ylabel="Predicted positive probability", ylims=(0, 1), markershape=:rect,
             markersize=1.5, markerstrokewidth=0, color=:darkgrey, linewidth=1,
             framestyle=:box, legend=false)
    plot!(p, calibration_curve; markershape=:circle, markersize=2, markerstrokewidth=0,
          color=:navyblue, linewidth=1)
    plot!(p, [0, 1], [0, 1]; linecolor=:black, linestyle=:dash, label="Ideal")
    #TODO: expert agreement histogram underneath?? Maybe important...
    # https://scikit-learn.org/stable/modules/calibration.html
    return p
end

function plot_kappas(per_class_kappas, class_labels, per_class_IRA_kappas=nothing)
    # Note: both the data and the labels need to be reversed, so that it plots
    # with the first class at the top of plot.
    class_labels_reversed = reverse(class_labels; dims=2)
    shift = x -> begin
        x < 0.92 && return x + 0.04
        return x - 0.05
    end
    local p
    if isnothing(per_class_IRA_kappas)
        annos = [(shift(k), i, text(string(round(k; digits=3)), 7))
                 for (i, k) in enumerate(reverse(per_class_kappas))]
        p = groupedbar(hcat(reverse(per_class_kappas)); title="Algorithm-expert agreement",
                       xlabel="Cohen's kappa", xlims=(-0.001, 1), xticks=0:0.2:1,
                       yticks=(1:length(class_labels_reversed), class_labels_reversed),
                       annotations=vec(annos), legend=false, color=[1], linecolor=[1],
                       orientation=:horizontal, framestyle=:box)
    else
        # Note: The plotting libraries make legend customization very tricky.
        # The default horizontal groupedbar() inverts the order of the labels
        # in the legend relative to their vertical position in the plot, which
        # makes it tricky to read. To work around this, we add two additional
        # "empty" series to the plot, and then adjust the `color` and `label`
        # attributes for those series to match the real data. These have
        # been marked with "^plot hack" comments.
        annos = [(shift(k), float(i) - 0.31, text(string(round(k; digits=3)), 6))
                 for (i, k) in enumerate(reverse(per_class_kappas))]
        append!(annos,
                [(shift(k), float(i) - 0.095, text(string(round(k; digits=3)), 6))
                 for (i, k) in enumerate(reverse(per_class_IRA_kappas))])
        kappas = hcat(reverse(per_class_kappas), reverse(per_class_IRA_kappas),
                      fill(0, size(per_class_kappas)), # ^plot hack
                      fill(0, size(per_class_kappas))) # ^plot hack
        p = groupedbar(kappas; title="Inter-rater reliability", xlabel="Cohen's kappa",
                       xlims=(0, 1), xticks=0:0.2:1,
                       yticks=(1:length(class_labels_reversed), class_labels_reversed),
                       labels=["" "" "Expert-vs-expert IRA" "Algorithm-vs-expert"], # ^plot hack
                       annotations=vec(annos), legend=:outertop,
                       background_color_legend=nothing, foreground_color_legend=nothing,
                       color=[:lightblue :lightgrey :darkgrey :lightblue], # ^plot hack
                       linecolor=[:lightblue :darkgrey :black :black], # ^plot hack
                       orientation=:horizontal, framestyle=:box)
    end
    return p
end

function plot_confusion_matrix(confusion::AbstractMatrix, class_labels, normalize_by)
    normdim = normalize_by == :Row ? 2 :
              normalize_by == :Column ? 1 :
              error("normalize_by must be either :Row or :Column")
    confusion = round.(confusion ./ sum(confusion; dims=normdim); digits=3)
    class_indices = 1:length(class_labels)
    annos = [(j, i, text(string(confusion[i, j]), 4))
             for i in class_indices, j in class_indices]
    heatmap(class_indices, class_indices, confusion;
            title="$(string(normalize_by))-Normalized Confusion", annotations=vec(annos),
            xlabel="Elected Class", ylabel="Predicted Class", clims=(0, maximum(confusion)),
            fillcolor=:Blues, xticks=(class_indices, class_labels),
            yticks=(class_indices, class_labels), legend=false, xrotation=45,
            framestyle=:box)
    return yaxis!(:flip)
end

function plot_combined(plots; class_labels, binary_discrimination_class=nothing)
    labels = permutedims(string.(class_labels))
    color = transpose(1:length(class_labels))
    size = (600, 600)
    layout = nothing
    if !isnothing(binary_discrimination_class)
        labels = hcat(labels, "$binary_discrimination_class discrimination", "Human expert")
        color = hcat(color, :navyblue, :darkgrey)
        size = (600, 600)
        layout = @layout grid(3, 3)
    else
        @warn "Legend placement hasn't been optimized for this number of subplots"
        layout = length(plots) + 1
    end
    legend_plot = plot(zeros(1, length(labels)); color=color, label=labels, grid=false,
                       showaxis=false, legendtitle="Legend", legendtitlefontsize=4,
                       legendtitlefonthalign=:right, background_color_legend=nothing,
                       foreground_color_legend=nothing, legend=:right)
    return plot(plots..., legend_plot; dpi=500, legendfontsize=4, titlefontsize=8,
                layout=layout, title_location=:left, guidefontsize=6, tickfontsize=6,
                left_margin=3mm, size=size)
end

function evaluation_metrics_plot(plot_data::Dict)
    pr = plot_pr_curves(plot_data["per_class_pr_curves"], plot_data["class_labels"];
                        legend=false)
    prg = plot_prg_curves(plot_data["per_class_prg_curves"],
                          plot_data["per_class_prg_aucs"], plot_data["class_labels"];
                          legend=false)
    roc = plot_roc_curves(plot_data["per_class_roc_curves"],
                          plot_data["per_class_roc_aucs"], plot_data["class_labels"];
                          legend=false)
    IRA_kappa_data = nothing
    multiclass = length(plot_data["class_labels"]) > 2
    labels = multiclass ? hcat("Multiclass", plot_data["class_labels"]) :
             plot_data["class_labels"]
    kappa_data = multiclass ?
                 vcat(plot_data["multiclass_kappa"], plot_data["per_class_kappas"]) :
                 plot_data["per_class_kappas"]
    if issubset(["multiclass_IRA_kappas", "per_class_IRA_kappas"], keys(plot_data))
        IRA_kappa_data = multiclass ?
                         vcat(plot_data["multiclass_IRA_kappas"],
                              plot_data["per_class_IRA_kappas"]) :
                         plot_data["per_class_IRA_kappas"]
    end
    kappa = plot_kappas(kappa_data, labels, IRA_kappa_data)
    reliability_calibration = plot_reliability_calibration_curves(plot_data["per_class_reliability_calibration_curves"],
                                                                  plot_data["per_class_reliability_calibration_scores"],
                                                                  plot_data["class_labels"];
                                                                  legend=false)
    confusion_row = plot_confusion_matrix(plot_data["confusion_matrix"],
                                          plot_data["class_labels"], :Row)
    confusion_col = plot_confusion_matrix(plot_data["confusion_matrix"],
                                          plot_data["class_labels"], :Column)
    label_str = i -> begin
        class = plot_data["class_labels"][i]
        auc = round(plot_data["per_class_roc_aucs"][i]; digits=2)
        mse = round(plot_data["per_class_reliability_calibration_scores"][i]; digits=2)
        return "$class (ROC AUC $auc; Cal. MSE $mse)"
    end
    class_labels = map(label_str, 1:length(plot_data["class_labels"]))
    if haskey(plot_data, "discrimination_calibration_curve")
        binary_calibration = plot_binary_discrimination_calibration_curves(plot_data["discrimination_calibration_curve"],
                                                                           plot_data["discrimination_calibration_score"],
                                                                           plot_data["per_expert_discrimination_calibration_curves"],
                                                                           plot_data["per_expert_discrimination_calibration_scores"],
                                                                           plot_data["optimal_threshold"],
                                                                           plot_data["class_labels"][plot_data["optimal_threshold_class"]])
        return plot_combined((binary_calibration, reliability_calibration, kappa,
                              confusion_row, confusion_col, roc, pr, prg);
                             class_labels=class_labels,
                             binary_discrimination_class=plot_data["class_labels"][plot_data["optimal_threshold_class"]])
    else
        return plot_combined((confusion_col, confusion_row, kappa, roc, pr, prg,
                              reliability_calibration); class_labels=class_labels)
    end
end
```
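The least obvious data step in `plot_confusion_matrix` is the `normalize_by` handling: `:Row` divides by `sum(confusion; dims=2)`, so each row sums to 1, and `:Column` divides by `sum(confusion; dims=1)`, so each column sums to 1. A minimal, plotting-free sketch of just that step, which a Makie port could reuse unchanged; the name `normalize_confusion` is hypothetical and not part of Lighthouse.jl:

```julia
# Hypothetical helper (not in Lighthouse.jl): the normalization step of
# `plot_confusion_matrix`, separated from any plotting backend.
function normalize_confusion(confusion::AbstractMatrix, normalize_by::Symbol)
    # :Row sums across columns (dims=2), so each row ends up summing to 1;
    # :Column sums across rows (dims=1), so each column ends up summing to 1.
    normdim = normalize_by == :Row ? 2 :
              normalize_by == :Column ? 1 :
              error("normalize_by must be either :Row or :Column")
    # Broadcasting divides every entry by the sum of its row (or column),
    # then rounds to 3 digits as the original code does for its annotations.
    return round.(confusion ./ sum(confusion; dims=normdim); digits=3)
end

confusion = [8 2; 1 9]
normalize_confusion(confusion, :Row)     # [0.8 0.2; 0.1 0.9]
normalize_confusion(confusion, :Column)  # [0.889 0.182; 0.111 0.818]
```

Keeping this as a pure function means only the heatmap call itself needs rewriting when switching backends.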

The plotting is already fairly nicely separated from computing the data to plot, so I think it shouldn't be too hard. I also think we shouldn't try to always translate 1-1 between Plots and AbstractPlotting, but rather just make each subplot plot what it's supposed to on its own (which might be easier than trying to get it exactly the same).
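One way to act on this, sketched under assumptions: keep label formatting and axis settings in a plain, backend-agnostic function, so a Plots.jl version and a Makie version of a subplot differ only in the final drawing call. The `roc_plot_spec` name and the NamedTuple fields are hypothetical, not existing Lighthouse.jl API; the label format mirrors `plot_roc_curves` above, and each curve is assumed to be an `(xs, ys)` tuple of vectors.

```julia
using Printf

# Hypothetical helper (not part of Lighthouse.jl): a backend-agnostic
# description of the ROC subplot. Assumes each curve is an `(xs, ys)` tuple.
function roc_plot_spec(per_class_roc_curves, per_class_roc_aucs, class_labels)
    # Same label format as `plot_roc_curves`: "<class> (AUC: 0.123)".
    auc_labels = [@sprintf("%s (AUC: %.3f)", class, per_class_roc_aucs[i])
                  for (i, class) in enumerate(class_labels)]
    return (curves=per_class_roc_curves, labels=auc_labels, title="ROC curves",
            xlabel="False positive rate", ylabel="True positive rate",
            xlims=(0, 1), ylims=(0, 1))
end

# Example data: two classes, each curve an (xs, ys) tuple.
curves = [([0.0, 0.5, 1.0], [0.0, 0.8, 1.0]), ([0.0, 0.3, 1.0], [0.0, 0.6, 1.0])]
spec = roc_plot_spec(curves, [0.9, 0.75], ["class_1", "class_2"])
spec.labels  # ["class_1 (AUC: 0.900)", "class_2 (AUC: 0.750)"]

# Only the drawing step is backend-specific. With Makie's Figure/Axis API it
# might look like this (untested sketch, not executed here):
#   fig = Figure()
#   ax = Axis(fig[1, 1]; title=spec.title, xlabel=spec.xlabel, ylabel=spec.ylabel)
#   for ((xs, ys), label) in zip(spec.curves, spec.labels)
#       lines!(ax, xs, ys; label=label)
#   end
#   xlims!(ax, spec.xlims...); ylims!(ax, spec.ylims...)
#   axislegend(ax; position=:rb)
```

The design choice here is that tests can check labels and limits without loading any plotting package, and a port does not need to match Plots.jl output pixel for pixel.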

SimonDanisch commented on September 28, 2024

I started to port a few recipes over to BeaconPlots: https://github.com/beacon-biosignals/BeaconPlots.jl/pull/34/.
Not sure if we want to leave them there going forward, but it makes things easier for me right now.
Should be easy to move if we want them somewhere else :)
