tomasfryda commented on June 12, 2024

Hi @MoonCapture,

  1. Calculating SHAP for a stacked ensemble has a big memory requirement: it calculates baseline SHAP for every base model and for the metalearner. SHAP for non-tree models also requires you to specify the background_frame. A rough estimate of the memory requirement for the Stacked Ensemble SHAP calculation is #basemodels * (ncol(test_data) + 1) * nrow(test_data) * nrow(background_frame) (the method used to calculate the SHAP values is described in https://www.nature.com/articles/s41467-022-31384-3). So one thing you could try is to calculate the SHAP values for smaller subsets of test_data and concatenate the results. This won't give you the plot, but at least you will have the SHAP values, which are relatively easy to plot afterwards.

  2. Those methods return either a matplotlib Figure object or an object with a .figure() method that returns a matplotlib Figure object. To get the last plot, you can also try matplotlib.pyplot.gcf(). You can then modify the plot through that object, and if you only care about font size etc., you can also use matplotlib's rcParams.
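
A minimal sketch of the two ideas in point 1, using NumPy instead of H2O so it runs anywhere. estimate_shap_values just evaluates the rough formula above (treating the result as a count of values rather than bytes is an assumption), and shap_in_chunks is a hypothetical helper that scores consecutive row slices and stacks the results, so each call only holds chunk_size rows of test data. dummy_contributions is a made-up stand-in for the real SHAP call.

```python
import numpy as np


def estimate_shap_values(n_basemodels, n_cols, n_test_rows, n_background_rows):
    """Rough count from the thread: #basemodels * (ncol + 1) * nrow(test) * nrow(background)."""
    return n_basemodels * (n_cols + 1) * n_test_rows * n_background_rows


def shap_in_chunks(score_fn, X, chunk_size):
    """Apply score_fn to consecutive row slices of X and stack the results.

    score_fn must map an (n, p) array to an (n, k) array with one row per input row.
    """
    parts = []
    for start in range(0, X.shape[0], chunk_size):
        # Only chunk_size rows are scored per call, which bounds the per-call memory.
        parts.append(score_fn(X[start:start + chunk_size]))
    return np.vstack(parts)


def dummy_contributions(chunk):
    # Hypothetical stand-in for a real SHAP call: "contributions" w * x plus a bias column.
    w = np.array([0.5, -1.0, 2.0])
    return np.column_stack([chunk * w, np.ones(len(chunk))])


X = np.arange(30, dtype=float).reshape(10, 3)
shap_values = shap_in_chunks(dummy_contributions, X, chunk_size=4)
print(shap_values.shape)  # (10, 4)

# The estimate is linear in nrow(test_data): halving the rows per call halves it.
print(estimate_shap_values(2, 12, 1000, 500), estimate_shap_values(2, 12, 500, 500))
```

With H2O, score_fn would presumably wrap something like model.predict_contributions(chunk, background_frame=df_train) (assuming your h2o version accepts background_frame there), and the chunks would be H2OFrame row slices combined with rbind rather than NumPy arrays.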

from h2o-3.

MoonCapture commented on June 12, 2024

Thank you very much for your reply, I have solved problem 2!
When I run:
h2o_01.explain(df_test)
The SHAP summary appears, but it is for the "GBM" model above, i.e., for a model that SHAP can be computed for, and not for the best model. That doesn't feel consistent with the purpose of the summary. Will a future update support SHAP and variable importance for stacked ensembles and similar models?
In addition, would it be possible to support a two-variable partial dependence plot (PDP)? I am looking forward to it.

tomasfryda commented on June 12, 2024

@MoonCapture I believe you can specify the background_frame in order to get the SHAP plots for all the models that support it, e.g., h2o_01.explain(df_test, background_frame=df_train) should produce the SHAP plot with the best model if it supports SHAP (all models that are trained in AutoML support SHAP). Note that we don't require background_frame by default to keep explain() backward compatible.
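
To illustrate what the background frame is used for, here is a simplified sketch for a plain linear model f(x) = b + w·x (not H2O's implementation). Computing the SHAP value of feature j against each background row and averaging gives w_j * (x_j - mean of background column j), and these contributions plus the mean background prediction add up to f(x). The weights and data are made up.

```python
import numpy as np


def linear_baseline_shap(w, b, x, background):
    """Baseline SHAP of a linear model f(x) = b + w @ x, averaged over background rows.

    Against a single background row r, feature j contributes w_j * (x_j - r_j);
    averaging over rows gives w_j * (x_j - mean_j). Returns (contributions,
    expected_value), where expected_value is the mean background prediction.
    """
    w = np.asarray(w, dtype=float)
    x = np.asarray(x, dtype=float)
    background = np.asarray(background, dtype=float)
    mean_bg = background.mean(axis=0)
    contributions = w * (x - mean_bg)
    expected_value = b + w @ mean_bg
    return contributions, expected_value


w, b = [2.0, -1.0], 0.5
background = [[0.0, 1.0], [2.0, 3.0]]  # column means are [1.0, 2.0]
x = [3.0, 1.0]
phi, base = linear_baseline_shap(w, b, x, background)
print(phi, base)
```

The additivity check (phi.sum() + base equals the model's prediction for x) is a handy sanity test when you compute or concatenate SHAP values yourself.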

The 2D PDP should already be supported through the more flexible API, e.g.,

model.partial_plot(data = pros, col_pairs_2dpdp = [['A', 'B'],['A', 'C']])

See the documentation for more details.
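
For intuition about what a two-variable PDP contains, here is a brute-force sketch with NumPy (not how H2O computes it internally): for every pair of grid values, overwrite both columns in all rows, predict, and average. predict below is a made-up model with an interaction between columns 0 and 1.

```python
import numpy as np


def partial_dependence_2d(predict, X, col_a, col_b, grid_a, grid_b):
    """Brute-force 2D partial dependence.

    Returns pdp[i, k] = mean prediction after setting column col_a to grid_a[i]
    and column col_b to grid_b[k] in every row of X.
    """
    X = np.asarray(X, dtype=float)
    pdp = np.empty((len(grid_a), len(grid_b)))
    for i, a in enumerate(grid_a):
        for k, c in enumerate(grid_b):
            X_mod = X.copy()
            X_mod[:, col_a] = a
            X_mod[:, col_b] = c
            pdp[i, k] = predict(X_mod).mean()
    return pdp


def predict(X):
    # Made-up model: interaction between columns 0 and 1, plus column 2.
    return X[:, 0] * X[:, 1] + X[:, 2]


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
grid = np.linspace(-1.0, 1.0, 5)
pdp = partial_dependence_2d(predict, X, 0, 1, grid, grid)
print(pdp.shape)  # (5, 5)
```

The resulting grid can be drawn with any 2D plotting routine (for example a heatmap or contour plot), which is one way to get a style other than the built-in one.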

MoonCapture commented on June 12, 2024

Again, thank you very much for your help; I have solved problem 2 (the 2D PDP). Another question: how can I modify the plotting source code locally on my computer? I want to use other styles for presentation, like in my diagram (the diagrams generated now are not intuitive enough for me).
[Image: pdp2d]

I'm doing a regression task; df_train is (17713, 12) and df_test is (4298, 12).
For the SHAP of the stacked ensemble model, it still doesn't produce results after setting background_frame=df_train; it still reports a memory problem. My computer has 32GB of RAM, and the process doesn't use all of that memory at runtime.

shap_all_plot = best_model.shap_summary_plot(df_test, background_frame=df_train)   (best_model is my leader model)
or
shap_all = h2o_01.explain(df_test, background_frame=df_train)

OSError: Job with key $03017f00000132d4ffffffff$_a9e68bb11ada5dd69bb7bb2c0d1b42c3 failed with an exception: java.lang.RuntimeException: Not enough memory. Estimated minimal total memory is 1.983856E9B. Estimated minimal per node memory (assuming perfectly balanced datasets) is 1.983856E9B. Node with minimum memory has 1095772304B. Total available memory is 1095772304B.
stacktrace: 
java.lang.RuntimeException: Not enough memory. Estimated minimal total memory is 1.983856E9B. Estimated minimal per node memory (assuming perfectly balanced datasets) is 1.983856E9B. Node with minimum memory has 1095772304B. Total available memory is 1095772304B.
	at hex.ContributionsWithBackgroundFrameTask.runAndGetOutput(ContributionsWithBackgroundFrameTask.java:178)
	at hex.tree.SharedTreeModelWithContributions.scoreContributions(SharedTreeModelWithContributions.java:134)
	at hex.ensemble.StackedEnsembleModel.baseLineContributions(StackedEnsembleModel.java:189)
	at hex.ensemble.StackedEnsembleModel.lambda$scoreContributions$8d51a081$1(StackedEnsembleModel.java:302)
	at hex.ensemble.StackedEnsembleModel.lambda$scoreContributions$1617fc2c$1(StackedEnsembleModel.java:322)
	at water.SplitToChunksApplyCombine.splitApplyCombine(SplitToChunksApplyCombine.java:52)
	at hex.ensemble.StackedEnsembleModel.scoreContributions(StackedEnsembleModel.java:322)
	at water.api.ModelMetricsHandler$1.compute2(ModelMetricsHandler.java:549)
	at water.H2O$H2OCountedCompleter.compute(H2O.java:1689)
	at jsr166y.CountedCompleter.exec(CountedCompleter.java:468)
	at jsr166y.ForkJoinTask.doExec(ForkJoinTask.java:263)
	at jsr166y.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:976)
	at jsr166y.ForkJoinPool.runWorker(ForkJoinPool.java:1479)
	at jsr166y.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:104)

[Image: SHAP]

I don't know how to go about this anymore :(

tomasfryda commented on June 12, 2024

@MoonCapture You can get the data for the PDP by setting plot=False, e.g., model.partial_plot(df, cols=.., plot=False); then you can use any library to draw the plot. Customizing the h2o plots would probably be harder than creating the plot from scratch.
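
A minimal sketch of drawing the PDP yourself with matplotlib, assuming you have already turned the table returned by partial_plot(..., plot=False) into plain arrays of grid values, mean responses and standard deviations (for example via the table's as_data_frame() method; check the column names in your h2o version). The values below are made up, and the rc_context settings are only examples of style options you might change.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_pdp(x, mean_response, std_response, feature_name):
    """Draw a 1D partial dependence curve with a +/- 1 std band and return the Figure."""
    x = np.asarray(x, dtype=float)
    mean_response = np.asarray(mean_response, dtype=float)
    std_response = np.asarray(std_response, dtype=float)
    # Temporarily override a few rcParams for this figure only.
    with plt.rc_context({"font.size": 14, "axes.grid": True}):
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(x, mean_response, color="tab:blue", linewidth=2, label="mean response")
        ax.fill_between(x, mean_response - std_response, mean_response + std_response,
                        color="tab:blue", alpha=0.2, label="+/- 1 std")
        ax.set_xlabel(feature_name)
        ax.set_ylabel("Partial dependence")
        ax.legend()
        fig.tight_layout()
    return fig


# Made-up PDP data for the 'WOB' column.
fig = plot_pdp([0, 1, 2, 3], [1.0, 1.5, 1.7, 1.8], [0.2, 0.2, 0.3, 0.3], "WOB")
```

Because the function returns the Figure, you can keep adjusting it afterwards or save it with fig.savefig("pdp_wob.png", dpi=150).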

If you have 32GB of RAM, maybe the problem is that Java is not allowed to use more memory. I would suggest starting h2o with something like max_mem_size="24g", e.g., h2o.init(max_mem_size="24g"), where 24g corresponds to 24GB. You can increase or decrease that amount, but the error message says it needs ~2GB of free memory, so I think allowing 24GB should be more than enough.
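
As a quick check on the numbers in the error above, here is a small hypothetical helper (not part of h2o) that pulls the per-node requirement and the available node memory out of the message. With the message from the stack trace, the node had about 1.1 GB while about 2.0 GB was needed, which fits the explanation that the Java heap, not the machine's RAM, was the limit.

```python
import re

# Patterns match the wording of the "Not enough memory" message quoted in this thread.
_REQUIRED = re.compile(r"Estimated minimal per node memory .*? is ([0-9.Ee+]+)B")
_AVAILABLE = re.compile(r"Node with minimum memory has ([0-9.Ee+]+)B")


def memory_gap(message):
    """Return (required_bytes, available_bytes) parsed from the error message."""
    required = float(_REQUIRED.search(message).group(1))
    available = float(_AVAILABLE.search(message).group(1))
    return required, available


msg = ("Not enough memory. Estimated minimal total memory is 1.983856E9B. "
       "Estimated minimal per node memory (assuming perfectly balanced datasets) is 1.983856E9B. "
       "Node with minimum memory has 1095772304B. Total available memory is 1095772304B.")
required, available = memory_gap(msg)
print(f"needed ~{required / 1e9:.2f} GB, node had ~{available / 1e9:.2f} GB, "
      f"short by ~{(required - available) / 1e9:.2f} GB")
# needed ~1.98 GB, node had ~1.10 GB, short by ~0.89 GB
```

Note that if an H2O cluster is already running, h2o.init() may simply connect to it, so the old cluster may need to be shut down before a new max_mem_size takes effect.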

MoonCapture commented on June 12, 2024

I got the SHAP local explanation by running best_model.shap_explain_row_plot(df_test, row_index=0, background_frame=df_train).
[Image: SHAPLE]

But it still doesn't work for the result of best_model.shap_explain_row_plot(df_test, row_index=0, background_frame=df_train). I switched to a different computer and still got the out-of-memory error, so I still haven't been able to solve the problem. I wonder whether the calculation is overflowing the stack.
I got the data via plot=False, thanks! For the following plots, is it also possible to get the data? It seems that plot=False is not available for all of them:

pd_plot = model.pd_plot(df_test, 'WOB')
ra_plot = model.varimp_plot()

and other explanation plots. I did not find this in the documentation. Thank you very much for your answer!

tomasfryda commented on June 12, 2024

@MoonCapture pd_plot uses the data from partial_plot, and for variable importance you can use varimp().
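
A minimal sketch of plotting variable importance yourself, assuming model.varimp() returns a list of (variable, relative_importance, scaled_importance, percentage) tuples (check this against your h2o version; varimp(use_pandas=True) returns a DataFrame instead). The rows below are made up so the sketch runs without an H2O cluster.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def plot_varimp(varimp_rows, top_n=10):
    """Horizontal bar chart of scaled importance from varimp()-style tuples."""
    rows = sorted(varimp_rows, key=lambda r: r[2], reverse=True)[:top_n]
    # Reverse so the most important variable ends up as the top bar.
    names = [r[0] for r in rows][::-1]
    scaled = [r[2] for r in rows][::-1]
    fig, ax = plt.subplots(figsize=(6, 0.5 * len(rows) + 1))
    ax.barh(names, scaled, color="tab:green")
    ax.set_xlabel("Scaled importance")
    ax.set_xlim(0, 1.05)
    fig.tight_layout()
    return fig


# Made-up rows in the (variable, relative, scaled, percentage) layout.
rows = [("feat_a", 120.0, 1.0, 0.5), ("feat_b", 72.0, 0.6, 0.3), ("feat_c", 48.0, 0.4, 0.2)]
fig = plot_varimp(rows)
```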
