Comments (12)
It actually looks like it's still not working for me. Not sure what the deal is. The workaround I'm using is this serialization process:
# Serialization
library(readr)  # write_rds() / read_rds()
write_rds(wflw_tabnet_fit, "models/wflw_tabnet_fit.rds")
# the torch nn_module has to be saved separately with torch_save()
torch::torch_save(wflw_tabnet_fit$fit$fit$fit$fit$network, "models/torch_network")
# Loading
torch_network <- torch::torch_load("models/torch_network")
wflw_fit <- read_rds("models/wflw_tabnet_fit.rds")
# re-attach the restored network to the workflow
wflw_fit$fit$fit$fit$fit$network <- torch_network
wflw_fit
from tabnet.
You can save/load the underlying parsnip model; see https://blogs.rstudio.com/ai/posts/2021-02-11-tabnet/
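library(tidymodels)  # workflows, parsnip and the %>% pipe
library(tabnet)
# wf is a workflow with a tabnet() model spec; train is the training data
fitted_model <- wf %>% fit(train)
# access the underlying parsnip model and save it to RDS format
# depending on when you read this, a nice wrapper may exist
# see https://github.com/mlverse/tabnet/issues/27
fitted_model$fit$fit$fit %>% saveRDS("saved_model.rds")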
Ok, for me it throws an "Error in cpp_Tensor_storage()" when I try to load the model using read_rds():
library(readr)
write_rds(wflw_tabnet_fit$fit$fit$fit, "models/tabnet_fit.rds")
read_rds("models/tabnet_fit.rds")
# Error in cpp_Tensor_storage(self$ptr) : external pointer is not valid
Typically there's a serialization function that allows us to read the underlying model. That's what I normally have to do with TensorFlow and MXNet.
Ok, I've got it working. I had to save the "nn_module" using torch::torch_save(), which is in the "network" part of the fitted workflow. This could probably be wrapped into a save/load tabnet function.
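A minimal sketch of the save/load wrapper suggested above, built only from the workaround in this thread. The function names save_tabnet_workflow() and load_tabnet_workflow() are hypothetical (they are not part of tabnet or tidymodels), and the helpers assume the torch module lives at wflw$fit$fit$fit$fit$network, which may differ between package versions.

```r
library(readr)
library(torch)

# Hypothetical helpers, not part of tabnet or tidymodels. They assume the
# fitted nn_module sits at wflw$fit$fit$fit$fit$network, as in the workaround.
save_tabnet_workflow <- function(wflw, rds_path, network_path) {
  # Serialize the whole workflow with regular R serialization
  readr::write_rds(wflw, rds_path)
  # Save the nn_module separately: its tensors are backed by external
  # pointers, which regular R serialization cannot restore
  torch::torch_save(wflw$fit$fit$fit$fit$network, network_path)
  invisible(wflw)
}

load_tabnet_workflow <- function(rds_path, network_path) {
  wflw <- readr::read_rds(rds_path)
  # Replace the restored (invalid) module with the one saved by torch_save()
  wflw$fit$fit$fit$fit$network <- torch::torch_load(network_path)
  wflw
}

# Usage, with the objects and paths from this thread:
# save_tabnet_workflow(wflw_tabnet_fit, "models/wflw_tabnet_fit.rds", "models/torch_network")
# wflw_fit <- load_tabnet_workflow("models/wflw_tabnet_fit.rds", "models/torch_network")
```

According to the later comments, base::saveRDS() and base::readRDS() alone should work with the development versions of torch and tabnet, so a wrapper like this would only matter on affected versions.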
In theory, saveRDS should just work. The model object contains a serialized version of the weights, which we reload as needed. We test this here:
tabnet/tests/testthat/test-hardhat.R, lines 203 to 224 in commit 46ac191
There could be a bug depending on whether you are using the CRAN or dev version of torch.
Ok, so the problem was that I was not using the development versions of the torch and tabnet packages. I can save workflows that contain fitted tabnet models with readr::write_rds(). Thanks for your help!!
Solution:
# GITHUB INSTALLATIONS ----
# TORCH:
remotes::install_github("mlverse/torch")
# TABNET:
remotes::install_github("mlverse/tabnet")
OK! I'll investigate; I think I might have broken this with recent changes in torch.
Ok, for what it's worth, I'm getting pretty good results with tabnet. I'm doing some tests now, and I'll be showcasing this project during my Learning Lab 51 - Torch for Tabular Data. The saving/loading is a small inconvenience.
I can no longer reproduce this issue with the dev versions. The test should be more robust as of commit aa1f85f.
Let me know if this persists.
@mdancho84 You might be using readr::write_rds instead of base::saveRDS; I can confirm that base::saveRDS and base::readRDS work.
Ok, thank you!
Related Issues (20)
- move parsnip to depends
- tabnet() needs an engine value
- multiclass classification, why such a bad result?
- `autoplot(model_fit)` fails with `Can't unnest elements with missing names.`
- embedding-aware attention
- multilabel / multi-output example
- Release tabnet 0.4.0
- hierarchical multilabel classification head
- new `torch::load_state_dict()` make complex hardhat scenario fail with `Error in !self$..refer_to_state_dict.. : invalid argument type`
- `data` inputs and outputs is not robust against character column, and fails with `Error in torch_tensor_cpp(data, dtype, device, requires_grad, pin_memory) : R type not handled`
- multi-outcome recipe fails with wierd message @tab-network.R#707 `index out of range in self`
- Save and load model
- Save and load model with Tidymodel and .rds
- `node_to_df()` is not exported error
- Add support for Apple silicon GPUs
- nomanres
- Error in (function (self, nan, posinf, neginf) : The operator 'aten::nan_to_num.out'
- Release tabnet 0.5.0
- Implement InterpreTabNet,
- Add a hierarchical metric to C-HMCNN