ktrain

ktrain is a lightweight wrapper for the deep learning library Keras to help build, train, and deploy neural networks. With only a few lines of code, ktrain allows you to easily and quickly:

  • estimate an optimal learning rate for your model given your data using a Learning Rate Finder
  • utilize learning rate schedules such as the triangular policy, the 1cycle policy, and SGDR to effectively minimize loss and improve generalization
  • employ fast and easy-to-use pre-canned models for both text classification (e.g., NBSVM, fastText, logreg) and image classification (e.g., ResNet, Wide ResNet, Inception)
  • load and preprocess text and image data from a variety of formats
  • inspect data points that were misclassified to help improve your model
  • leverage a simple prediction API for saving and deploying both models and data-preprocessing steps to make predictions on new raw data
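
To make the learning rate schedules named above more concrete, here is a minimal sketch of two of them (the triangular policy and SGDR-style cosine annealing with warm restarts) written as plain Python functions. This is not ktrain's implementation: the function names `triangular_lr` and `sgdr_lr` and their parameters are hypothetical, chosen only for illustration, and the SGDR variant assumes a fixed cycle length.

```python
import math


def triangular_lr(step, step_size, base_lr, max_lr):
    """Triangular policy (illustrative): the rate rises linearly from base_lr
    to max_lr over step_size steps, then falls back linearly, and repeats."""
    cycle_pos = step % (2 * step_size)
    # frac is 0 at the start of a cycle, 1 at the peak, 0 again at the end
    frac = 1.0 - abs(cycle_pos / step_size - 1.0)
    return base_lr + (max_lr - base_lr) * frac


def sgdr_lr(step, cycle_len, min_lr, max_lr):
    """SGDR-style schedule (illustrative): cosine decay from max_lr toward
    min_lr within each cycle, with a restart to max_lr every cycle_len steps."""
    t = (step % cycle_len) / cycle_len  # position within the current cycle, in [0, 1)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * t))


# Print a few points from each schedule (hypothetical hyperparameters)
for s in (0, 50, 100, 150, 200):
    print(s, triangular_lr(s, 100, 1e-4, 1e-2), sgdr_lr(s, 100, 1e-5, 1e-3))
```

Both functions map a training step to a learning rate, so they could be evaluated once per batch by whatever training loop or callback applies the schedule.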

Tutorial Notebooks

Please see the project's tutorial notebooks for a guide on how to use ktrain in your projects.

Tasks such as text classification and image classification can be accomplished easily with only a few lines of code.

Example: Classifying Images of Dogs and Cats Using ktrain

import ktrain
from ktrain import vision as vis

# load data
(train_data, val_data, preproc) = vis.images_from_folder(
    datadir='data/dogscats',
    data_aug=vis.get_data_aug(horizontal_flip=True),
    train_test_names=['train', 'valid'],
    target_size=(224, 224),
    color_mode='rgb')

# load model
model = vis.image_classifier('pretrained_resnet50', train_data, val_data, freeze_layers=80)

# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model=model, train_data=train_data, val_data=val_data, 
                             workers=8, use_multiprocessing=False, batch_size=64)

# find good learning rate
learner.lr_find()             # briefly simulate training to find good learning rate
learner.lr_plot()             # visually identify best learning rate

# train using triangular policy with ModelCheckpoint and implicit ReduceLROnPlateau and EarlyStopping
learner.autofit(1e-4, checkpoint_folder='/tmp') 
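
The `lr_find()` and `lr_plot()` steps in the example rely on the idea of a learning rate range test. The sketch below illustrates that idea on a toy problem and is not ktrain's code: it runs gradient descent on the one-parameter loss `w**2` while growing the learning rate exponentially, stops once the loss diverges, and then picks the rate at which the loss fell fastest. The function names, the toy loss, and the "steepest drop" heuristic are all assumptions made for illustration; in the example above, the rate is chosen by inspecting the plot.

```python
import math


def lr_range_test(start_lr=1e-4, end_lr=10.0, num_steps=100, stop_factor=4.0):
    """Sweep the learning rate exponentially on a toy loss f(w) = w**2,
    recording (lr, loss) pairs until the loss diverges (illustrative only)."""
    w = 5.0
    growth = (end_lr / start_lr) ** (1.0 / (num_steps - 1))
    lrs, losses = [], []
    best = float("inf")
    lr = start_lr
    for _ in range(num_steps):
        loss = w * w
        if loss > stop_factor * best:  # loss has blown up: stop the sweep
            break
        best = min(best, loss)
        lrs.append(lr)
        losses.append(loss)
        w -= lr * 2.0 * w  # one gradient step; the gradient of w**2 is 2*w
        lr *= growth
    return lrs, losses


def steepest_drop_lr(lrs, losses):
    """Return the learning rate whose step produced the most negative
    slope of loss with respect to log10(lr)."""
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log10(lrs[i + 1]) - math.log10(lrs[i]))
        for i in range(len(losses) - 1)
    ]
    steepest = min(range(len(slopes)), key=slopes.__getitem__)
    return lrs[steepest]


lrs, losses = lr_range_test()
print("suggested learning rate:", steepest_drop_lr(lrs, losses))
```

On a real model the recorded losses are noisy, so implementations typically smooth them before looking at the curve. The toy problem here is deterministic and skips that step.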

Installation

pip3 install ktrain

Requirements

The following software/libraries should be installed:

This code was tested on Ubuntu 18.04 LTS using Keras 2.2.4 with a TensorFlow 1.10 backend. There are a few portions of the code that may explicitly depend on TensorFlow, but such dependencies are kept to a minimum.
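
To compare a local environment against the tested configuration above, a small check like the following may help. It is a sketch that assumes Python 3.8 or later (for `importlib.metadata`) and assumes the packages are installed under the distribution names `keras`, `tensorflow`, and `ktrain`; a GPU build of TensorFlow, for example, may be registered under a different name. The helper `installed_versions` is hypothetical and not part of ktrain.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version string,
    or to None when the distribution is not installed."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Distribution names below are assumptions; adjust them to your setup
report = installed_versions(["keras", "tensorflow", "ktrain"])
for name, ver in report.items():
    print(f"{name}: {ver if ver is not None else 'not installed'}")
```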


Creator: Arun S. Maiya

Email: arun [at] maiya [dot] net

Version: 0.1.9
