🌿 Trava (originally short for TrainValidation)

A framework that helps you train models, compare them, and track parameters and metrics along the way. Works with tabular data only.

pip install trava

Compare models and keep track of metrics with ease!

While working on a project, we often experiment with different models and compare them on the same metrics. Some of those metrics can be logged as a single number, while others need a plot to make sense. It is also useful to save the metrics somewhere for future analysis, and the list goes on.

So why not use a unified interface for that?

Here is Trava's way:

1). Declare metrics you want to calculate:

# in this case, sk and sk_proba (provided by Trava) are just wrappers
# around sklearn's metrics, but you can use any metric implementation you want
from sklearn.metrics import log_loss, precision_score, recall_score, roc_auc_score

scorers = [
  sk_proba(log_loss),
  sk_proba(roc_auc_score),
  sk(recall_score),
  sk(precision_score),
]
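
For intuition, a wrapper such as sk or sk_proba can be thought of as an adapter that decides which model output a metric receives: hard labels from predict, or probabilities from predict_proba. The sketch below is not Trava's implementation. The names labels_scorer, proba_scorer and ConstantModel are hypothetical, and the two metrics are hand-written so the block runs without any dependencies.

```python
import math

def recall(y_true, y_pred):
    """Fraction of actual positives that were predicted positive."""
    hits = [p for t, p in zip(y_true, y_pred) if t == 1]
    return sum(hits) / len(hits) if hits else 0.0

def binary_log_loss(y_true, proba_pos, eps=1e-15):
    """Mean negative log-likelihood of the true labels."""
    total = 0.0
    for t, p in zip(y_true, proba_pos):
        p = min(max(p, eps), 1 - eps)  # avoid log(0)
        total += -math.log(p) if t == 1 else -math.log(1 - p)
    return total / len(y_true)

def labels_scorer(metric):
    """Hypothetical adapter: pass hard predictions to the metric."""
    def score(model, X, y):
        return metric(y, model.predict(X))
    score.__name__ = metric.__name__
    return score

def proba_scorer(metric):
    """Hypothetical adapter: pass positive-class probabilities to the metric."""
    def score(model, X, y):
        proba_pos = [row[1] for row in model.predict_proba(X)]
        return metric(y, proba_pos)
    score.__name__ = metric.__name__
    return score

class ConstantModel:
    """Toy sklearn-style model that always outputs the same probability."""
    def __init__(self, p=0.75):
        self.p = p

    def predict_proba(self, X):
        return [[1 - self.p, self.p] for _ in X]

    def predict(self, X):
        return [1 if self.p >= 0.5 else 0 for _ in X]

scorers = [proba_scorer(binary_log_loss), labels_scorer(recall)]
model = ConstantModel(p=0.75)
X, y = [[0], [1], [2], [3]], [1, 1, 0, 1]  # made-up toy data

# Every scorer shares one call signature, whatever output it needs.
results = {s.__name__: s(model, X, y) for s in scorers}
print(results)
```

Because every scorer ends up with the same call signature, the rest of the pipeline can treat label-based and probability-based metrics uniformly.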

2). What do you want to do with the metrics?

# let's log the metrics
logger_handler = LoggerHandler(scorers=scorers)

3). Initialize Trava

trava = TravaSV(results_handlers=[logger_handler])

4). Fit your model using Trava

# prepare your data
X_train, X_test, y_train, y_test = ...

split_result = SplitResult(X_train=X_train, 
                           y_train=y_train,
                           X_test=X_test,
                           y_test=y_test)

trava.fit_predict(raw_split_data=split_result,
                  model_type=GaussianNB, # we pass model class and parameters separately
                  model_init_params={},  # to be able to track them properly
                  model_id='gnb') # just a unique identifier for this model

The fit_predict call does roughly the same as:

from sklearn.naive_bayes import GaussianNB

gnb = GaussianNB()
gnb.fit(split_result.X_train, split_result.y_train)
gnb.predict(split_result.X_test)

But now you don't need to care how the metrics you declared are calculated. You just get them in your console! Btw, these metrics definitely need to be improved. :]

Model evaluation nb
* Results for gnb model *
Train metrics:
log_loss:
16.755867191506482
roc_auc_score:
0.7746522424770221
recall_score:
0.10468384074941452
precision_score:
0.9122448979591836


Test metrics:
log_loss:
16.94514025416013
roc_auc_score:
0.829444814485889
recall_score:
0.026041666666666668
precision_score:
0.7692307692307693
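
The reason for passing the model class and its parameters separately can be illustrated with a simplified stand-in for fit_predict. This is only a sketch under stated assumptions: Split, ThresholdModel, accuracy and fit_predict_sketch are hypothetical names, the data is made up, and Trava's real implementation does considerably more.

```python
from dataclasses import dataclass

@dataclass
class Split:
    """Hypothetical stand-in for SplitResult."""
    X_train: list
    y_train: list
    X_test: list
    y_test: list

class ThresholdModel:
    """Toy sklearn-style classifier: predicts 1 when x[0] > threshold."""
    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def fit(self, X, y):
        return self  # nothing to learn in this toy model

    def predict(self, X):
        return [1 if row[0] > self.threshold else 0 for row in X]

def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def fit_predict_sketch(split, model_type, model_init_params, model_id, metrics):
    # The class and its parameters arrive separately, so the parameters
    # can be recorded next to the metrics without inspecting the model.
    model = model_type(**model_init_params)
    model.fit(split.X_train, split.y_train)
    record = {"model_id": model_id, "params": dict(model_init_params)}
    for name, X, y in (("train", split.X_train, split.y_train),
                       ("test", split.X_test, split.y_test)):
        y_pred = model.predict(X)
        record[name] = {m.__name__: m(y, y_pred) for m in metrics}
    return record

split = Split(X_train=[[0.1], [0.9], [0.7], [0.2]], y_train=[0, 1, 1, 0],
              X_test=[[0.6], [0.3]], y_test=[0, 0])
record = fit_predict_sketch(split, ThresholdModel, {"threshold": 0.5},
                            model_id="thr", metrics=[accuracy])
print(record)
```

The same metrics are computed on both the train and the test parts, which mirrors the two sections of the console output above.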

After training multiple models, you can get all metrics for all models by calling:

trava.results

Get the full picture and more examples by going through the guide notebook!

Built-in handlers:

  • LoggerHandler - logs metrics
  • PlotHandler - plots metrics
  • MetricsDictHandler - returns all metrics wrapped in a dict
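
The handler idea can be sketched as follows: each handler owns a list of scorers and decides what to do with the computed values. This is an illustration only, not Trava's code. SketchHandler, PrintHandler and DictHandler are hypothetical classes, and the metrics are hand-written so the block has no dependencies.

```python
def precision(y_true, y_pred):
    """Share of predicted positives that are truly positive."""
    hits = [t for t, p in zip(y_true, y_pred) if p == 1]
    return sum(hits) / len(hits) if hits else 0.0

def recall(y_true, y_pred):
    """Share of actual positives that were predicted positive."""
    hits = [p for t, p in zip(y_true, y_pred) if t == 1]
    return sum(hits) / len(hits) if hits else 0.0

class SketchHandler:
    """Hypothetical base class: a handler owns the scorers it reports."""
    def __init__(self, scorers):
        self.scorers = scorers

    def compute(self, y_true, y_pred):
        return {s.__name__: s(y_true, y_pred) for s in self.scorers}

class PrintHandler(SketchHandler):
    """Writes each metric to the console."""
    def handle(self, model_id, y_true, y_pred):
        print(f"* Results for {model_id} model *")
        for name, value in self.compute(y_true, y_pred).items():
            print(f"{name}:\n{value}")

class DictHandler(SketchHandler):
    """Returns the metrics wrapped in a dict keyed by model id."""
    def handle(self, model_id, y_true, y_pred):
        return {model_id: self.compute(y_true, y_pred)}

# Different handlers may report different subsets of metrics.
handlers = [PrintHandler([recall]), DictHandler([precision, recall])]
y_true, y_pred = [1, 0, 1, 1], [1, 1, 0, 1]  # made-up labels and predictions
outputs = [h.handle("gnb", y_true, y_pred) for h in handlers]
```

Keeping the scorers inside the handler means one model fit can feed several outputs, each with its own set of metrics.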

Enable metrics autotracking. How cool is that?

Experiment tracking is a must in data science, so you shouldn't neglect it. You can integrate any tracking framework with Trava! Trava comes with a ready-to-go MLflow tracker. It can autotrack:

  • model's parameters
  • any metric
  • plots
  • serialized models

MLFlow example:

# get tracker's instance
tracker = MLFlowTracker(scorers=scorers)
# initialize Trava with it
trava = TravaSV(tracker=tracker)
# fit your model as before
trava.fit_predict(raw_split_data=split_result,
                  model_type=GaussianNB,
                  model_id='gnb')

Done. All model parameters and metrics are now tracked! Tracking is also supported for:

  • cross-validation case with nested tracking
  • eval results for common boosting libraries ( XGBoost, LightGBM, CatBoost )
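
To make the nested cross-validation case concrete, here is a sketch with an in-memory tracker: one parent run per model, one child run per fold, and the averaged metric stored on the parent. InMemoryTracker and track_cross_validation are hypothetical and the fold scores are made up. With MLflow itself, the comparable building blocks would be mlflow.start_run(nested=True), mlflow.log_param and mlflow.log_metric, though how Trava's tracker uses them is not shown here.

```python
class InMemoryTracker:
    """Hypothetical tracker that keeps runs in a list instead of a server."""
    def __init__(self):
        self.runs = []

    def start_run(self, name, parent=None):
        run = {"name": name, "parent": parent, "params": {}, "metrics": {}}
        self.runs.append(run)
        return run

    def children(self, name):
        return [r for r in self.runs if r["parent"] == name]

def track_cross_validation(tracker, model_id, params, fold_scores, metric_name):
    """Record one parent run plus one nested run per fold."""
    parent = tracker.start_run(model_id)
    parent["params"].update(params)
    for i, score in enumerate(fold_scores):
        child = tracker.start_run(f"{model_id}_fold_{i}", parent=model_id)
        child["metrics"][metric_name] = score
    # The parent run carries the average over all folds.
    parent["metrics"][metric_name] = sum(fold_scores) / len(fold_scores)
    return parent

tracker = InMemoryTracker()
parent = track_cross_validation(
    tracker,
    model_id="gnb",
    params={"var_smoothing": 1e-9},
    fold_scores=[0.5, 0.75, 1.0],  # made-up per-fold scores
    metric_name="roc_auc_score",
)
print(parent)
print([run["name"] for run in tracker.children("gnb")])
```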

Check out the detailed notebooks on how to track metrics & parameters and plots & serialized models.

General information

  • highly customizable training & evaluation processes ( see trava.fit_predictor.py.FitPredictor class and its subclasses )
  • built-in train/test/validation split logic
  • extensions for common boosting libraries ( for early stopping with validation sets )
  • tracks model parameters, metrics, plots and serialized models; it's easy to integrate any tracking framework of your choice
  • you can evaluate metrics after the fit_predict call, in case you forgot to add some metric
  • you can evaluate metrics even when your data and even the trained model are already unloaded ( depends on the metric used, true most of the time )
  • only supervised learning problems are supported for now, yet there is potential to extend it to unsupervised problems
  • unit-tested
  • I use it every day for my own needs, so I care about its quality and reliability
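
One way to picture evaluating a metric after the fact is a cache of true labels and predictions: once those are stored, a label-based metric can be computed later without the original features or the fitted model. This is a sketch of the idea only. PredictionCache and f1 are hypothetical, the data is made up, and Trava's actual mechanism may differ.

```python
class PredictionCache:
    """Hypothetical store of labels and predictions keyed by model id."""
    def __init__(self):
        self._store = {}

    def save(self, model_id, y_true, y_pred):
        self._store[model_id] = (list(y_true), list(y_pred))

    def evaluate(self, model_id, metric):
        y_true, y_pred = self._store[model_id]
        return metric(y_true, y_pred)

def f1(y_true, y_pred):
    """Harmonic mean of precision and recall for the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

cache = PredictionCache()
cache.save("gnb", y_true=[1, 0, 1, 1, 0], y_pred=[1, 1, 0, 1, 0])

# Later: the features and the model may be gone, but a metric that was
# forgotten at fit time can still be computed from the cached arrays.
late_score = cache.evaluate("gnb", f1)
print(late_score)
```

A metric that needs something that was not cached, such as the raw features, could not be recovered this way, which matches the "depends on the metric used" caveat in the list above.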

Only sklearn-style models are supported for the time being ( Trava relies on the fit, predict and predict_proba methods ).
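
As an illustration of what "sklearn-style" means here, the toy classifier below exposes exactly the three methods named above and nothing else. PriorClassifier is a hypothetical example, and whether Trava needs anything beyond these three methods is an assumption this sketch does not verify.

```python
class PriorClassifier:
    """Toy model that predicts from smoothed training label frequencies."""
    def __init__(self, smoothing=1.0):
        self.smoothing = smoothing

    def fit(self, X, y):
        labels = list(y)
        self.classes_ = sorted(set(labels))
        total = len(labels) + self.smoothing * len(self.classes_)
        self.priors_ = [(labels.count(c) + self.smoothing) / total
                        for c in self.classes_]
        return self  # returning self follows the sklearn convention

    def predict_proba(self, X):
        # Features are ignored: every row gets the class priors.
        return [list(self.priors_) for _ in X]

    def predict(self, X):
        best = self.classes_[self.priors_.index(max(self.priors_))]
        return [best for _ in X]

model = PriorClassifier().fit([[1], [2], [3], [4]], [0, 0, 0, 1])
labels = model.predict([[5], [6]])
proba = model.predict_proba([[5], [6]])
print(labels, proba)
```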

Requirements

pandas
numpy
python 3.7

It's also convenient to use the library with sklearn ( e.g. for taking metrics from there ). A couple of extensions are also based on sklearn classes.
