EASIER-net

Feng, Jean, and Noah Simon. 2022. “Ensembled Sparse‐input Hierarchical Networks for High‐dimensional Datasets.” Statistical Analysis and Data Mining, March. https://doi.org/10.1002/sam.11579.

Python code for fitting EASIER-nets and reproducing all results from the paper. The Python code uses PyTorch.

R code for fitting EASIER-net is available at https://github.com/jjfeng/easier_net_R.

Quick-start

Set up a Python virtual environment (the code runs on Python 3.6) and install the packages listed in requirements.txt.

Simulate data by following the tutorial notebook, or save your own data in .npz format with x and y arrays. You can also run GridSearchCV by following the tutorial.
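A minimal sketch of preparing your own data file with NumPy, assuming the scripts read the arrays stored under the keys `x` and `y` as described above. The file name `my_data.npz` and the simulated data are placeholders, and whether `y` should be a 1-D vector (as here) or a column vector is an assumption to check against the tutorial notebook.

```python
import numpy as np

# Placeholder data: 200 observations, 1000 features (a high-dimensional setting).
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 1000))
# Placeholder regression response that depends on the first 5 features only.
y = x[:, :5].sum(axis=1) + rng.normal(scale=0.1, size=200)

# np.savez stores each keyword argument as a named array in the .npz archive.
np.savez("my_data.npz", x=x, y=y)

# Reload to confirm the arrays are stored under the keys "x" and "y".
data = np.load("my_data.npz")
print(sorted(data.files))
print(data["x"].shape, data["y"].shape)
```

For classification, `y` would instead hold the class labels; the exact label encoding the scripts expect is not specified here.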

To fit an EASIER-net, run

python fit_easier_net.py --n-estimators <N_ESTIMATORS> --input-filter-layer <INPUT_FILTER_LAYER> --n-layers <N_LAYERS> --n-hidden <N_HIDDEN> --input-pen <INPUT_PEN> --full-tree-pen <FULL_TREE_PEN> --batch-size <BATCH_SIZE> --num-classes <NUM_CLASSES>  --weight <WEIGHT> --max-iters <MAX_ITERS> --max-prox-iters <MAX_PROX_ITERS> --model-fit-params-file <MODEL_FIT_PARAMS_FILE>

where:

  • N_ESTIMATORS is the size of the ensemble, i.e. the number of SIER-nets being ensembled
  • INPUT_FILTER_LAYER specifies whether to scale the inputs by the parameter β
  • N_LAYERS is the number of hidden layers
  • N_HIDDEN is the number of hidden nodes per layer
  • INPUT_PEN specifies $\lambda_1$ in the paper; controls the input sparsity
  • FULL_TREE_PEN specifies $\lambda_2$ in the paper; controls the number of active layers and hidden nodes
  • BATCH_SIZE specifies the size of the mini-batches for Adam
  • NUM_CLASSES should be 0 for regression, or the number of classes for binary or multi-class classification
  • WEIGHT is a list of weights for the classes
  • MAX_ITERS is the number of epochs to run Adam
  • MAX_PROX_ITERS is the number of epochs to run batch proximal gradient descent
  • MODEL_FIT_PARAMS_FILE is a JSON file specifying the hyperparameters. If given, its values override the arguments passed on the command line.
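To build intuition for how INPUT_PEN ($\lambda_1$) can remove entire inputs during the proximal gradient epochs (MAX_PROX_ITERS), here is a generic proximal step for a group-lasso-style penalty, where each group is the set of first-layer weights attached to one input feature. This is an illustrative assumption, not the package's code: the exact penalty EASIER-net uses is defined in the paper, and the helper name `group_soft_threshold` is hypothetical.

```python
import numpy as np


def group_soft_threshold(w, step_size, lam):
    """Proximal operator of lam * sum_j ||w[:, j]||_2 (hypothetical helper).

    w: first-layer weight matrix of shape (n_hidden, n_inputs); column j holds
       all weights leaving input feature j.
    Each column's Euclidean norm is shrunk by step_size * lam; columns whose
    norm is at or below that threshold become exactly zero, which removes the
    corresponding input from the network.
    """
    norms = np.linalg.norm(w, axis=0)  # one norm per input feature
    threshold = step_size * lam
    # Guard against dividing by zero for columns that are already zero.
    scale = np.maximum(0.0, 1.0 - threshold / np.maximum(norms, 1e-12))
    return w * scale  # broadcast the per-column scale over the rows


w = np.array([[3.0, 0.1, 0.0],
              [4.0, 0.1, 0.0]])
# Column 0 (norm 5) is scaled by 0.8; the two small columns become exactly zero.
print(group_soft_threshold(w, step_size=1.0, lam=1.0))
```

Under this assumed penalty, a larger $\lambda_1$ raises the threshold, so more input columns are zeroed out and the fitted network uses fewer inputs.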

To perform cross-validation, run fit_easier_net.py separately for each candidate set of penalty parameter values, then select the best values using collate_best_param.py.
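A minimal sketch of that workflow, assuming fit_easier_net.py accepts the flags listed above: it builds one command per candidate (INPUT_PEN, FULL_TREE_PEN) pair. The grid and fixed flag values are placeholders; options for the data file, output location, --input-filter-layer and --weight are left out because their value formats are not documented here, so add them according to the script's own help. The arguments of collate_best_param.py are likewise not shown.

```python
import itertools
import shlex
import subprocess

# Placeholder candidate values for the two penalty parameters.
INPUT_PENS = [0.001, 0.01, 0.1]
FULL_TREE_PENS = [0.001, 0.01]

# Placeholder values for other flags documented above.
FIXED_ARGS = {
    "--n-estimators": 5,
    "--n-layers": 3,
    "--n-hidden": 100,
    "--batch-size": 32,
    "--num-classes": 0,  # 0 means regression
    "--max-iters": 100,
    "--max-prox-iters": 50,
}


def build_commands():
    """Return one fit_easier_net.py argument list per penalty combination."""
    commands = []
    for input_pen, tree_pen in itertools.product(INPUT_PENS, FULL_TREE_PENS):
        cmd = ["python", "fit_easier_net.py",
               "--input-pen", str(input_pen),
               "--full-tree-pen", str(tree_pen)]
        for flag, value in FIXED_ARGS.items():
            cmd += [flag, str(value)]
        commands.append(cmd)
    return commands


if __name__ == "__main__":
    for cmd in build_commands():
        # Print a shell-safe version of each command for inspection.
        print(" ".join(shlex.quote(part) for part in cmd))
        # Uncomment to launch each fit sequentially:
        # subprocess.run(cmd, check=True)
```

Running the fits as independent processes keeps each candidate separate, so they can also be dispatched in parallel on a cluster before collating.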

