Comments (12)

mdancho84 commented on June 14, 2024

It actually looks like it's still not working for me. Not sure what the deal is. The workaround I'm using is this serialization process:

# Serialization
write_rds(wflw_tabnet_fit, "models/wflw_tabnet_fit.rds")

torch::torch_save(wflw_tabnet_fit$fit$fit$fit$fit$network, "models/torch_network")

# Loading
torch_network <- torch::torch_load("models/torch_network")

wflw_fit <- read_rds("models/wflw_tabnet_fit.rds")

wflw_fit$fit$fit$fit$fit$network <- torch_network

wflw_fit

skeydan commented on June 14, 2024

You can save/load the underlying parsnip model; see https://blogs.rstudio.com/ai/posts/2021-02-11-tabnet/



fitted_model <- wf %>% fit(train)

# access the underlying parsnip model and save it to RDS format
# depending on when you read this, a nice wrapper may exist
# see https://github.com/mlverse/tabnet/issues/27  
fitted_model$fit$fit$fit %>% saveRDS("saved_model.rds")


mdancho84 commented on June 14, 2024

Ok, for me it throws an "Error in cpp_Tensor_storage()" when I try to load the model using read_rds().

write_rds(wflw_tabnet_fit$fit$fit$fit, "models/tabnet_fit.rds")
read_rds("models/tabnet_fit.rds")
# Error in cpp_Tensor_storage(self$ptr) : external pointer is not valid

mdancho84 commented on June 14, 2024

Typically there's a serialization function that allows us to read the underlying model. That's what I normally have to do with TensorFlow and MXNet.

mdancho84 commented on June 14, 2024

Ok, I've got it working. I had to save the "nn_module" using torch::torch_save(), which is in the 'network' part of the fitted workflow. This could probably be wrapped into a save/load tabnet function.

# Serialization
write_rds(wflw_tabnet_fit, "models/wflw_tabnet_fit.rds")

torch::torch_save(wflw_tabnet_fit$fit$fit$fit$fit$network, "models/torch_network")

# Loading
torch_network <- torch::torch_load("models/torch_network")

wflw_fit <- read_rds("models/wflw_tabnet_fit.rds")

wflw_fit$fit$fit$fit$fit$network <- torch_network

wflw_fit
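The save/load wrapper suggested above could look roughly like this. It is a minimal sketch, not part of tabnet: save_tabnet_workflow() and load_tabnet_workflow() are hypothetical names, the fit$fit$fit$fit$network path is copied from the workaround in this thread and may differ in other setups, and base saveRDS()/readRDS() stand in for readr's write_rds()/read_rds(). The save_network/load_network arguments default to torch::torch_save()/torch::torch_load() and exist only so the helpers can be exercised without torch.

```r
# Hypothetical helpers wrapping the two-file workaround from this thread.
# Assumes the torch nn_module sits at wflw$fit$fit$fit$fit$network.
save_tabnet_workflow <- function(wflw, dir,
                                 save_network = torch::torch_save) {
  dir.create(dir, showWarnings = FALSE, recursive = TRUE)
  # 1. Save the R-side workflow object as an RDS file
  saveRDS(wflw, file.path(dir, "workflow.rds"))
  # 2. Save the torch network separately with the torch serializer
  save_network(wflw$fit$fit$fit$fit$network, file.path(dir, "torch_network"))
  invisible(dir)
}

load_tabnet_workflow <- function(dir, load_network = torch::torch_load) {
  wflw <- readRDS(file.path(dir, "workflow.rds"))
  # Replace the network restored from RDS with the one loaded by torch
  wflw$fit$fit$fit$fit$network <- load_network(file.path(dir, "torch_network"))
  wflw
}
```

Usage would be save_tabnet_workflow(wflw_tabnet_fit, "models/tabnet") and later wflw_fit <- load_tabnet_workflow("models/tabnet"). As later comments in this thread note, with the dev versions of torch and tabnet a plain saveRDS()/readRDS() round trip should make this unnecessary.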

dfalbel commented on June 14, 2024

In theory, saveRDS should just work. The model object contains a serialized copy of the weights, so we reload them as needed. We test this here:

test_that("serialization with saveRDS just works", {
  data("ames", package = "modeldata")
  x <- ames[-which(names(ames) == "Sale_Price")]
  y <- ames$Sale_Price
  fit <- tabnet_fit(x, y, epochs = 1)
  tmp <- tempfile("model", fileext = ".rds")
  saveRDS(fit, tmp)
  fit2 <- readRDS(tmp)
  expect_equal(
    predict(fit, ames),
    predict(fit2, ames)
  )
  expect_equal(as.numeric(fit2$fit$network$.check), 1)
})

There could be a bug depending on whether you are using the CRAN or dev version of torch.
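To see which versions are installed, something like the following works. It is a small sketch: pkg_versions() is a hypothetical helper built on base R's system.file() and utils::packageVersion(), and it does not load either package. GitHub dev builds conventionally carry a .9000-style version suffix, but that is a convention, not a guarantee.

```r
# Hypothetical helper: installed versions of torch and tabnet,
# NA when a package is not installed.
pkg_versions <- function(pkgs = c("torch", "tabnet")) {
  vapply(pkgs, function(p) {
    # system.file() returns "" when the package is not installed
    if (nzchar(system.file(package = p))) {
      as.character(utils::packageVersion(p))
    } else {
      NA_character_
    }
  }, character(1))
}

pkg_versions()
```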

mdancho84 commented on June 14, 2024

Ok, so the problem was that I was not using the development versions of the torch and tabnet packages. With them installed, I can now save workflows that contain fitted tabnet models with readr::write_rds(). Thanks for your help!!

Solution:

# GITHUB INSTALLATIONS ----
# TORCH: 
remotes::install_github("mlverse/torch")
# TABNET: 
remotes::install_github("mlverse/tabnet")

dfalbel commented on June 14, 2024

OK! I'll investigate; I think I might have broken this with recent changes in torch.

mdancho84 commented on June 14, 2024

Ok, for what it's worth, I'm getting pretty good results with tabnet. I'm doing some tests now, and I'll be showcasing this project during my Learning Lab 51 - Torch for Tabular Data. The saving/loading is a small inconvenience.

dfalbel commented on June 14, 2024

I can no longer reproduce this issue with the dev versions. The test should be more robust as of commit aa1f85f.
Let me know if this persists.

3SMMZRjWgS commented on June 14, 2024

@mdancho84 You might be using readr::write_rds instead of base::saveRDS. I can confirm that base::saveRDS and base::readRDS work.

mdancho84 commented on June 14, 2024

Ok thank you!
