Comments (6)
@JackyP have you seen any example of how to use tabnet with categorical variables?
My go-to would be
  if (type == "factor") {
    tensorflow::tf$feature_column$indicator_column(
      tensorflow::tf$feature_column$categorical_column_with_vocabulary_list(id, levels[[1]][[id]])
    )
  }
but for the data, neither matrix(features[, get(x)]) nor converting to integer seems to work.
It would be great to have an example where it is used, but I cannot find anything I could work with.
from mlr3keras.
The original TabNet paper describes mapping categorical features with trainable embeddings, and both tf-TabNet (which we are using) and fast.ai tabular use embeddings to fit categorical variables.
- The fast.ai introduction is here: https://www.fast.ai/2018/04/29/categorical-embeddings/
- The tf-TabNet example in Python is here: https://github.com/titu1994/tf-TabNet/blob/master/examples/train_embedding.py
- An R example from RStudio is here: https://blogs.rstudio.com/tensorflow/posts/2018-11-26-embeddings-fun-and-profit/
- The original Google codebase (TensorFlow 1.1x) includes this (https://github.com/google-research/google-research/blob/master/tabnet/data_helper_covertype.py):
def get_columns():
  """Get the representations for all input columns."""
  columns = []
  if float_columns:
    columns += [tf.feature_column.numeric_column(ci) for ci in float_columns]
  if int_columns:
    columns += [tf.feature_column.numeric_column(ci) for ci in int_columns]
  if str_columns:
    # pylint: disable=g-complex-comprehension
    columns += [
        tf.feature_column.embedding_column(
            tf.feature_column.categorical_column_with_hash_bucket(
                ci, hash_bucket_size=int(3 * num)),
            dimension=1) for ci, num in zip(str_columns, str_nuniques)
    ]
  if bool_columns:
    # pylint: disable=g-complex-comprehension
    columns += [
        tf.feature_column.embedding_column(
            tf.feature_column.categorical_column_with_hash_bucket(
                ci, hash_bucket_size=3),
            dimension=1) for ci in bool_columns
    ]
  return columns
Your make_tf_feature_cols is very similar to the original Google code example; would the str_columns branch from it work?
EDIT: No it does not...
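To make the hash-bucket-plus-embedding idea in the Google snippet above more concrete, here is a minimal pure-Python sketch. It is not TensorFlow's implementation: the MD5-based hash_bucket, make_embedding_table, and embed are hypothetical helpers introduced only for illustration, and a real embedding table would be trained rather than left at its random initial values.

```python
import hashlib
import random


def hash_bucket(value, hash_bucket_size):
    """Map a string to a bucket index (assumption: MD5 stands in for TF's hash)."""
    digest = hashlib.md5(value.encode("utf-8")).hexdigest()
    return int(digest, 16) % hash_bucket_size


def make_embedding_table(hash_bucket_size, dimension, seed=0):
    """One vector of length `dimension` per bucket, randomly initialised."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.1) for _ in range(dimension)]
            for _ in range(hash_bucket_size)]


def embed(value, table):
    """Look up the embedding vector for a categorical value."""
    return table[hash_bucket(value, len(table))]


species = ["setosa", "versicolor", "virginica"]
# Same sizing rule as the snippet above: three buckets per unique value.
num_buckets = int(3 * len(species))
table = make_embedding_table(num_buckets, dimension=1)
embedded = {s: embed(s, table) for s in species}
```

Each category ends up as a dense vector of length dimension (here 1), so downstream layers see a numeric input. Because values are hashed, two categories can collide in one bucket; oversizing the table makes that less likely.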
So I understood that TabNet basically expects a list of feature_columns.
This seems to work as long as all feature_columns are numeric.
I then thought that an indicator_column (or embedding_column) would work for categorical variables.
Doing this, I can build the tabnet learner, but fitting fails.
Possible reasons:
- The 'x' data is in the wrong format
- TabNet does not work as I thought it would
Additional observation: tf.feature_column.numeric_column seems to have a shape, while tf.feature_column.embedding_column does not?
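On the first possible reason, the working example later in this thread passes x as a named list with one single-column matrix per feature, with the categorical column left as strings. A rough pure-Python sketch of that layout follows; to_feature_dict and the two sample rows are hypothetical and only illustrate the shape of the data, not any TabNet API.

```python
def to_feature_dict(rows, col_names):
    """Turn row records into {feature name: column of shape (n, 1)}."""
    return {name: [[row[name]] for row in rows] for name in col_names}


col_names = ["Sepal.Length", "Sepal.Width", "Petal.Length", "Species"]

# Two illustrative rows in the style of the iris data set.
rows = [
    {"Sepal.Length": 5.1, "Sepal.Width": 3.5, "Petal.Length": 1.4, "Species": "setosa"},
    {"Sepal.Length": 7.0, "Sepal.Width": 3.2, "Petal.Length": 4.7, "Species": "versicolor"},
]

x = to_feature_dict(rows, col_names)
```

The keys of x match the names given to the feature columns, and the Species column is kept as strings rather than converted to integers.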
My understanding of tensorflow falls a little short, so this might be grasping at straws, but:
- What type is indicator_column vs numeric_column?
- It looks like the embedding_column example in py-tabnet requires wrapping the whole thing in order to train the embedding...
Otherwise I'm also stuck.
Minimal working example:
library("reticulate")
library("tensorflow")
library("keras")

# keras::install_keras(extra_packages = c("tensorflow-hub", "tabnet==0.1.4.1"))
use_implementation("tensorflow")
tabnet <- import("tabnet")  # 0.1.4.1

col_names <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Species")
feature_columns <- lapply(col_names, function(x) {
  if (x == "Species") {
    tf$feature_column$indicator_column(
      tf$feature_column$categorical_column_with_vocabulary_list(
        x, levels(iris$Species))
    )
  } else {
    tf$feature_column$numeric_column(x)
  }
})

# The trick is that the one hot categorical counts as multiple columns for num_features
model <- tabnet$TabNetRegressor(
  feature_columns,
  num_regressors = 1,
  num_features = 3 + length(levels(iris$Species)),
  feature_dim = 4, output_dim = 4,
  num_decision_steps = 2, relaxation_factor = 1.0,
  sparsity_coefficient = 1e-5, batch_momentum = 0.98,
  virtual_batch_size = NULL, norm_type = "group",
  num_groups = 1
)

model %>% compile(
  loss = "mean_squared_error",
  optimizer = optimizer_adam()
)

x <- lapply(col_names, function(x) as.matrix(iris[x]))
names(x) <- col_names
y <- model.matrix(~ 0 + iris$Petal.Width)

model %>%
  fit(x, y, epochs = 100, verbose = 2)
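The num_features arithmetic in the example above can be sketched in plain Python. This is only a bookkeeping helper, not part of tabnet: count_num_features and the column-spec layout are hypothetical, and the rule that an embedding column contributes its dimension is an assumption that does not come from this thread.

```python
def count_num_features(columns):
    """Sum the width each feature column contributes to the dense input."""
    total = 0
    for spec in columns.values():
        if spec["type"] == "numeric":
            total += 1                    # one value per numeric column
        elif spec["type"] == "indicator":
            total += len(spec["levels"])  # one-hot: one slot per level
        elif spec["type"] == "embedding":
            total += spec["dimension"]    # assumption: width equals embedding size
        else:
            raise ValueError("unknown column type: " + spec["type"])
    return total


# Column layout matching the iris example above.
iris_columns = {
    "Sepal.Length": {"type": "numeric"},
    "Sepal.Width": {"type": "numeric"},
    "Petal.Length": {"type": "numeric"},
    "Species": {"type": "indicator",
                "levels": ["setosa", "versicolor", "virginica"]},
}

num_features = count_num_features(iris_columns)
```

For the iris layout this gives the same value as 3 + length(levels(iris$Species)) in the R code, that is, three numeric columns plus three one-hot slots.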
Thank you very much!
Perfect! I think I got it now, see #16