EASIER-net
Feng, Jean, and Noah Simon. 2022. “Ensembled Sparse‐input Hierarchical Networks for High‐dimensional Datasets.” Statistical Analysis and Data Mining, March. https://doi.org/10.1002/sam.11579.
This package provides Python code for fitting EASIER-nets and reproducing all results from the paper. The code uses PyTorch.
R code for fitting EASIER-net is available at https://github.com/jjfeng/easier_net_R.
Quick-start
Set up a Python virtual environment (the code runs on Python 3.6) and install the packages listed in requirements.txt.
Simulate data by following the tutorial notebook, or store your own data in npz format with x and y attributes. You can also run GridSearchCV by following the tutorial.
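As a rough illustration of the npz layout described above, the following sketch simulates a small regression dataset with NumPy and saves it with x and y attributes. The file name my_data.npz, the data-generating model, and the column-vector shape of y are assumptions made for this example; check the tutorial notebook for the exact shapes the fitting script expects.

```python
import numpy as np

# Hypothetical example data: 100 observations, 50 features.
rng = np.random.default_rng(0)
n_obs, n_features = 100, 50
x = rng.normal(size=(n_obs, n_features))

# The response depends on the first two features only (a sparse signal).
# Storing y as a column vector is an assumption, not a documented requirement.
noise = 0.1 * rng.normal(size=n_obs)
y = (x[:, 0] - 2.0 * x[:, 1] + noise).reshape(-1, 1)

# Save with the attribute names "x" and "y" mentioned in the quick-start.
np.savez("my_data.npz", x=x, y=y)

# Reload to confirm that both attributes are present.
loaded = np.load("my_data.npz")
print(loaded["x"].shape, loaded["y"].shape)  # (100, 50) (100, 1)
```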
To fit an EASIER-net, run
python fit_easier_net.py --n-estimators <N_ESTIMATORS> --input-filter-layer <INPUT_FILTER_LAYER> --n-layers <N_LAYERS> --n-hidden <N_HIDDEN> --input-pen <INPUT_PEN> --full-tree-pen <FULL_TREE_PEN> --batch-size <BATCH_SIZE> --num-classes <NUM_CLASSES> --weight <WEIGHT> --max-iters <MAX_ITERS> --max-prox-iters <MAX_PROX_ITERS> --model-fit-params-file <MODEL_FIT_PARAMS_FILE>
where:
- N_ESTIMATORS is the size of the ensemble, i.e. the number of SIER-nets being ensembled.
- INPUT_FILTER_LAYER is whether to scale the inputs by the parameter β.
- N_LAYERS is the number of hidden layers.
- N_HIDDEN is the number of hidden nodes per layer.
- INPUT_PEN specifies $\lambda_1$ in the paper; it controls the input sparsity.
- FULL_TREE_PEN specifies $\lambda_2$ in the paper; it controls the number of active layers and hidden nodes.
- BATCH_SIZE specifies the size of the mini-batches for Adam.
- NUM_CLASSES should be 0 for regression and the number of classes for binary or multi-class classification.
- WEIGHT is a list of weights for the classes.
- MAX_ITERS is the number of epochs to run Adam.
- MAX_PROX_ITERS is the number of epochs to run batch proximal gradient descent.
- MODEL_FIT_PARAMS_FILE is a JSON file that specifies the hyperparameters. If given, it overrides the arguments passed on the command line.
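The NUM_CLASSES convention (0 for regression, the class count otherwise) can be captured in a small helper. This is only a sketch and not part of the package: infer_num_classes is a hypothetical name, and counting distinct labels assumes that every class appears at least once in y.

```python
import numpy as np


def infer_num_classes(y, is_classification):
    """Return a value suitable for --num-classes.

    Hypothetical helper: returns 0 for regression, otherwise the number
    of distinct labels observed in y.
    """
    if not is_classification:
        return 0
    return int(np.unique(np.asarray(y)).size)


print(infer_num_classes([0.3, 1.7, 2.2], is_classification=False))  # 0
print(infer_num_classes([0, 1, 1, 0], is_classification=True))      # 2
print(infer_num_classes([0, 2, 1, 2], is_classification=True))      # 3
```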
To perform cross-validation, run a separate fit_easier_net.py job for each candidate set of penalty parameter values.
Then select the best penalty parameter values using collate_best_param.py.
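One way to organize the separate runs is to generate one command line per combination of candidate penalties. The sketch below only builds and prints the commands; the candidate values and the shared settings are made up for illustration, and any flags for input data or output files are omitted because they are not listed above. The arguments of collate_best_param.py are not documented here, so that step is not shown.

```python
import itertools
import shlex

# Hypothetical candidate values for lambda_1 and lambda_2.
input_pens = [0.001, 0.01, 0.1]
full_tree_pens = [0.0001, 0.001]

# Hypothetical settings shared by every run (flags taken from the usage line).
shared_args = [
    "--n-estimators", "5",
    "--n-layers", "3",
    "--n-hidden", "20",
    "--num-classes", "0",
]


def build_commands(input_pens, full_tree_pens, shared_args):
    """Return one shell command per (input_pen, full_tree_pen) pair."""
    commands = []
    for input_pen, full_tree_pen in itertools.product(input_pens, full_tree_pens):
        argv = ["python", "fit_easier_net.py"] + list(shared_args) + [
            "--input-pen", str(input_pen),
            "--full-tree-pen", str(full_tree_pen),
        ]
        # Quote each argument so the command is safe to paste into a shell.
        commands.append(" ".join(shlex.quote(arg) for arg in argv))
    return commands


commands = build_commands(input_pens, full_tree_pens, shared_args)
for command in commands:
    print(command)
```

Each printed line can then be submitted as its own job, for example through a cluster scheduler or a simple shell loop.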
File details
Details for the file EASIER-net-0.0.8.tar.gz.
File metadata
- Download URL: EASIER-net-0.0.8.tar.gz
- Upload date:
- Size: 25.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 660b4f231a0e3452234e05f1a33f6a735905543355afdfd9e91fa963314d2218 |
| MD5 | 1fc2cc32b70fb8a9c97e607362c8d964 |
| BLAKE2b-256 | b179a55fe7e2c1e714485629031c251feab71c4d197dcfe660abe1bd7817e348 |
File details
Details for the file EASIER_net-0.0.8-py3-none-any.whl.
File metadata
- Download URL: EASIER_net-0.0.8-py3-none-any.whl
- Upload date:
- Size: 26.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.9.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f04b3635242a51127460fe01c357e61528d6b1ffbb744c6e590064d9da8d8d84 |
| MD5 | ae9cffa3d061291316cac227f7658f77 |
| BLAKE2b-256 | b60b7e2ef9f7801ee2b4317687f455fbef4fa3ace673cbf893dbfb4275dc5730 |
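A downloaded file can be checked against the SHA256 digests listed above with Python's standard hashlib module. The helper names below are hypothetical, and the sketch assumes the file has already been downloaded to a local path.

```python
import hashlib


def sha256_of_file(path, chunk_size=65536):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path, expected_hex):
    """Return True if the file's SHA256 digest equals the expected value."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# Example usage, assuming the wheel sits in the current directory:
# matches_expected(
#     "EASIER_net-0.0.8-py3-none-any.whl",
#     "f04b3635242a51127460fe01c357e61528d6b1ffbb744c6e590064d9da8d8d84",
# )
```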