TF2 Keras implementation of TabNet

TabNet is a deep learning architecture for tabular data. It reasons in multiple decision steps and uses sequential attention to select which features to use at each step. You can find more information about it in the original research paper.

Installation

$ pip install tabnet_keras

Usage

import tensorflow as tf
from tabnet_keras import TabNetRegressor, TabNetClassifier

tabnet_params = {
    "decision_dim": 16,
    "attention_dim": 16,
    "n_steps": 3,
    "n_shared_glus": 2,
    "n_dependent_glus": 2,
    "relaxation_factor": 1.3,
    "epsilon": 1e-15,
    "momentum": 0.98,
    "mask_type": "sparsemax",  # can be 'sparsemax' or 'softmax'
    "lambda_sparse": 1e-3,
    "virtual_batch_splits": 8,  # number of splits for ghost batch normalization; ideally it should evenly divide batch_size
}

### Regression 
model = TabNetRegressor(n_regressors = 1, **tabnet_params)
model.compile(loss = 'mean_squared_error', optimizer = tf.keras.optimizers.Adam(0.01), 
             metrics = [tf.keras.metrics.RootMeanSquaredError()])
model.fit(X, y, epochs = 100, batch_size = 1024)

### Classification
# X and y are assumed to be prepared beforehand; 'categorical_crossentropy' expects one-hot encoded labels in y
model = TabNetClassifier(n_classes = 10, out_activation = None, **tabnet_params)
model.compile(loss = 'categorical_crossentropy', optimizer = tf.keras.optimizers.Adam(0.01))
model.fit(X, y, epochs = 100, batch_size = 1024)
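The `virtual_batch_splits` comment above says the split count should evenly divide the batch size. The following pure-Python sketch illustrates why: ghost batch normalization normalizes each chunk ("virtual batch") with its own mean and variance, so with `batch_size = 1024` and 8 splits each chunk holds 128 rows. This is a simplified illustration, not the package's implementation: the helper `ghost_batch_norm` is hypothetical, it handles a single feature column, it omits the learned scale/offset and moving averages of a real batch normalization layer, and it rejects uneven splits outright as a simplifying assumption.

```python
from statistics import fmean, pvariance


def ghost_batch_norm(column, virtual_batch_splits, epsilon=1e-5):
    """Normalize one feature column chunk by chunk (illustrative only)."""
    if len(column) % virtual_batch_splits != 0:
        # Simplifying assumption of this sketch: require an even split.
        raise ValueError("virtual_batch_splits must evenly divide the batch size")
    chunk_size = len(column) // virtual_batch_splits
    normalized = []
    for start in range(0, len(column), chunk_size):
        chunk = column[start:start + chunk_size]
        mean = fmean(chunk)
        variance = pvariance(chunk, mu=mean)
        # Each virtual batch uses its own statistics, not the full batch's.
        normalized.extend((x - mean) / (variance + epsilon) ** 0.5 for x in chunk)
    return normalized


rows_per_virtual_batch = 1024 // 8  # 128 rows per chunk for the settings above
toy_column = [0.0, 2.0, 10.0, 14.0]
toy_normalized = ghost_batch_norm(toy_column, virtual_batch_splits=2)
```

In the toy example the two chunks `[0.0, 2.0]` and `[10.0, 14.0]` are each centered on their own mean, so both map to roughly `[-1, 1]` even though their raw scales differ.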

Acknowledgment

Most of the code is taken with minor changes from this repository.
