
Easy Neural Network Experiments with PyTorch

Project description


A quick and easy way to start running PyTorch experiments within a few minutes.



Installation

  1. Install the latest PyTorch and torchvision from the official PyTorch website (https://pytorch.org).
  2. pip install easytorch
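A quick way to confirm that the three packages are importable after installation. This is a small standard-library sketch; `check_installed` is a hypothetical helper, not part of easytorch.

```python
import importlib.util


def check_installed(packages=("torch", "torchvision", "easytorch")):
    """Return a dict mapping each package name to True if it can be found for import."""
    # find_spec returns None for a top-level module that is not installed.
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    for name, ok in check_installed().items():
        print(f"{name}: {'found' if ok else 'MISSING'}")
```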

'How to use?' you ask!

1. Define your trainer

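import torch
import torch.nn.functional as F
from easytorch import ETTrainer


class MyTrainer(ETTrainer):
  def __init__(self, args):
    super().__init__(args)

  def _init_nn_model(self):
    # NeuralNetModel is your own torch.nn.Module (not provided by easytorch);
    # num_channel and num_class are the extra arguments passed to EasyTorch.
    self.nn['model'] = NeuralNetModel(self.args['num_channel'], self.args['num_class'])

  def iteration(self, batch):
    # Move the mini-batch to the training device.
    inputs = batch['input'].to(self.device['gpu']).float()
    labels = batch['label'].to(self.device['gpu']).long()

    out = self.nn['model'](inputs)
    loss = F.cross_entropy(out, labels)  # cross_entropy expects raw logits
    out = F.softmax(out, 1)              # convert logits to class probabilities

    _, pred = torch.max(out, 1)          # predicted class index per sample
    sc = self.new_metrics()
    sc.add(pred, labels)

    avg = self.new_averages()
    avg.add(loss.item(), len(inputs))    # batch loss, weighted by batch size

    return {'loss': loss, 'averages': avg, 'output': out, 'metrics': sc, 'predictions': pred}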
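`iteration` returns a loss average weighted by batch size (`avg.add(loss.item(), len(inputs))`) and a metrics object filled from predictions and labels (`sc.add(pred, labels)`). The plain-Python sketch below illustrates what such accumulators typically compute for the two-class example. `RunningAverage` and `BinaryConfusion` are hypothetical stand-ins, not easytorch's own classes, whose internals may differ.

```python
class RunningAverage:
    """Batch-size weighted running mean (hypothetical stand-in for new_averages())."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def add(self, value, n=1):
        # value is a per-sample mean (e.g. loss.item()), so weight it by batch size n.
        self.total += value * n
        self.count += n

    @property
    def average(self):
        return self.total / self.count if self.count else 0.0


class BinaryConfusion:
    """Counts TP/FP/TN/FN for 0/1 predictions (hypothetical stand-in for new_metrics())."""

    def __init__(self):
        self.tp = self.fp = self.tn = self.fn = 0

    def add(self, preds, labels):
        # Works on lists of ints or on 1-D tensors (int() converts each element).
        for p, t in zip(preds, labels):
            p, t = int(p), int(t)
            if p == 1 and t == 1:
                self.tp += 1
            elif p == 1 and t == 0:
                self.fp += 1
            elif p == 0 and t == 0:
                self.tn += 1
            else:
                self.fn += 1

    def accuracy(self):
        total = self.tp + self.fp + self.tn + self.fn
        return (self.tp + self.tn) / total if total else 0.0

    def f1(self):
        denom = 2 * self.tp + self.fp + self.fn
        return 2 * self.tp / denom if denom else 0.0
```

For example, adding a loss of 0.5 over 4 samples and 1.0 over 2 samples gives an average of 4/6, not the unweighted 0.75.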

2. Use a custom or PyTorch-based Dataset class.

Define a specification for each of your datasets:

import os

MYDATA = {
    'name': 'mydata',
    'data_dir': os.path.join('MYDATA', 'images'),
    'label_dir': os.path.join('MYDATA', 'labels'),
    # Maps an image file name to the name of its label file.
    'label_getter': lambda file_name: file_name.split('_')[0] + 'label.csv'
}

MyOTHERDATA = {
    'name': 'otherdata',
    'data_dir': os.path.join('OTHERDATA', 'images'),
    'label_dir': os.path.join('OTHERDATA', 'labels'),
    'label_getter': lambda file_name: file_name.split('_')[0] + 'label.csv'
}
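To see what `label_getter` produces, the sketch below applies the same lambda to a made-up file name, `patient1_scan03.png`, and joins the result with `label_dir`. Note that the lambda concatenates the prefix and `label.csv` directly, so the result is `patient1label.csv` with no separator; adjust it if your label files are named differently. `label_path` is a hypothetical helper for illustration only.

```python
import os

# Spec mirroring MYDATA above; the file name used below is hypothetical.
spec = {
    'name': 'mydata',
    'data_dir': os.path.join('MYDATA', 'images'),
    'label_dir': os.path.join('MYDATA', 'labels'),
    'label_getter': lambda file_name: file_name.split('_')[0] + 'label.csv',
}


def label_path(dataspec, file_name):
    # label_getter maps an image file name to its label file name;
    # label_dir says where that label file lives.
    return os.path.join(dataspec['label_dir'], dataspec['label_getter'](file_name))


print(spec['label_getter']('patient1_scan03.png'))  # patient1label.csv
print(label_path(spec, 'patient1_scan03.png'))
```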

Define how to load each data item:

import torchvision
from easytorch import ETDataset


class MyDataset(ETDataset):
    def __init__(self, **kw):
        super().__init__(**kw)

    def __getitem__(self, index):
        dataset_name, file = self.indices[index]
        # All the info (data_dir, label_dir, label_getter, ...) defined above is in dataspec.
        dataspec = self.dataspecs[dataset_name]

        image = ...  # TODO: load the file/image, e.g. from os.path.join(dataspec['data_dir'], file).
        label = ...  # TODO: load the label, e.g. from os.path.join(dataspec['label_dir'], dataspec['label_getter'](file)).
        # Extra preprocessing, if needed.
        # Apply transforms, e.g. image = self.transforms(image).

        return {'indices': self.indices[index],
                'input': image,
                'label': label}

    @property
    def transforms(self):
        # Put your transforms in the list, e.g. torchvision.transforms.ToTensor().
        return torchvision.transforms.Compose([])
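`__getitem__` above unpacks each entry of `self.indices` as a `(dataset_name, file)` pair and looks up the matching dataspec. The sketch below shows one way such an index could be built from the dataspecs by listing each `data_dir`. `build_indices` and `item_path` are hypothetical helpers that only mirror that shape; they are not easytorch's own code, which may filter or order files differently.

```python
import os
import tempfile


def build_indices(dataspecs):
    """Return a list of (dataset_name, file_name) pairs, one per file in each data_dir."""
    indices = []
    for spec in dataspecs:
        for file_name in sorted(os.listdir(spec['data_dir'])):
            if os.path.isfile(os.path.join(spec['data_dir'], file_name)):
                indices.append((spec['name'], file_name))
    return indices


def item_path(indices, dataspecs, index):
    # Same first steps as __getitem__: find the (dataset, file) pair, then its spec.
    by_name = {spec['name']: spec for spec in dataspecs}
    dataset_name, file_name = indices[index]
    return os.path.join(by_name[dataset_name]['data_dir'], file_name)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        data_dir = os.path.join(root, 'images')
        os.makedirs(data_dir)
        for name in ('a_01.png', 'b_01.png'):
            open(os.path.join(data_dir, name), 'w').close()
        specs = [{'name': 'mydata', 'data_dir': data_dir}]
        indices = build_indices(specs)
        print(indices)  # [('mydata', 'a_01.png'), ('mydata', 'b_01.png')]
        print(item_path(indices, specs, 1))
```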

3. Entry point

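from easytorch import EasyTorch

# Extra keyword arguments (num_channel, num_class) are available as self.args in the trainer.
runner = EasyTorch([MYDATA, MyOTHERDATA],
                   phase="train", batch_size=4, epochs=21,
                   num_channel=1, num_class=2)

if __name__ == "__main__":
    runner.run(MyDataset, MyTrainer)         # run the experiment on each dataset separately
    runner.run_pooled(MyDataset, MyTrainer)  # run once on all datasets pooled together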

Complete Examples

Feature Highlights

  • Advanced training with multiple networks and complex training steps.
  • Implement your own custom metrics.
  • Minimal configuration to set up a new experiment.
  • Use your choice of neural network architecture.
  • Automatic k-fold cross-validation / automatic dataset split.
  • Automatic logging/plotting and model checkpointing.
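The automatic dataset split and k-fold cross-validation can be pictured with the sketch below, which follows the `-rt/--split_ratio` rule listed under the default arguments (three values give train/validation/test, two give train/test, one gives train only). `ratio_split` and `kfold_splits` are hypothetical helpers written for illustration; easytorch's actual assignment of files to splits and folds may differ.

```python
import random


def ratio_split(items, ratios, seed=0):
    """Split items by ratios: 1 value -> (train,), 2 -> (train, test), 3 -> (train, val, test)."""
    if not 1 <= len(ratios) <= 3:
        raise ValueError("expected 1 to 3 ratios")
    items = list(items)
    random.Random(seed).shuffle(items)
    splits, start = [], 0
    for i, r in enumerate(ratios):
        if i == len(ratios) - 1:
            end = len(items)  # the last split takes the remainder, so rounding drops nothing
        else:
            end = min(len(items), start + int(round(r * len(items))))
        splits.append(items[start:end])
        start = end
    return tuple(splits)


def kfold_splits(items, k, seed=0):
    """Yield (train, test) pairs; each of the k folds serves as the test set exactly once."""
    items = list(items)
    random.Random(seed).shuffle(items)
    folds = [items[i::k] for i in range(k)]
    for i in range(k):
        train = [x for j, fold in enumerate(folds) if j != i for x in fold]
        yield train, folds[i]


if __name__ == "__main__":
    train, val, test = ratio_split(range(10), [0.6, 0.2, 0.2])
    print(len(train), len(val), len(test))  # 6 2 2
```

Shuffling with a fixed seed keeps the split reproducible across runs.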

Default arguments [default value]. Custom arguments are easy to add.

  • -ph/--phase [Required]
    • Phase to run: 'train' (runs the train, validation, and test steps) or 'test' (runs only the test step).
  • -b/--batch_size [32]
  • -ep/--epochs [51]
  • -lr/--learning_rate [0.001]
  • -gpus/--gpus [0]
    • List of GPUs to use, e.g. [0], [1], [0, 1].
  • -nw/--num_workers [4]
    • Number of data-loading workers, so that the CPU can keep up with the GPU when loading mini-batches.
  • -lim/--load-limit [inf]
    • Limits the number of images/files to load, for debugging the pipeline.
  • -nf/--num_folds [None]
    • Number of folds for k-fold cross-validation (an integer such as 5 or 10).
  • -rt/--split_ratio [0.6 0.2 0.2]
    • Split ratio for the train, validation, and test sets if three values are given; train and test if two are given; train only if one is given.
  • ...and more.
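The defaults listed above can be mirrored with the standard `argparse` module, which also shows how a custom argument slots in. This is a hypothetical sketch that reproduces only the listed flags and defaults; easytorch builds its own parser, whose types, choices, and help text may differ, and `--num_class` is included purely as an illustrative custom argument.

```python
import argparse
import math


def build_parser():
    # Hypothetical re-creation of the listed defaults, not easytorch's own parser.
    ap = argparse.ArgumentParser(description="easytorch-style experiment arguments (sketch)")
    ap.add_argument('-ph', '--phase', required=True, choices=['train', 'test'])
    ap.add_argument('-b', '--batch_size', type=int, default=32)
    ap.add_argument('-ep', '--epochs', type=int, default=51)
    ap.add_argument('-lr', '--learning_rate', type=float, default=0.001)
    ap.add_argument('-gpus', '--gpus', type=int, nargs='*', default=[0])
    ap.add_argument('-nw', '--num_workers', type=int, default=4)
    ap.add_argument('-lim', '--load-limit', type=float, default=math.inf)  # stored as args.load_limit
    ap.add_argument('-nf', '--num_folds', type=int, default=None)
    ap.add_argument('-rt', '--split_ratio', type=float, nargs='*', default=[0.6, 0.2, 0.2])
    # A custom argument is one more add_argument call (the name is illustrative).
    ap.add_argument('--num_class', type=int, default=2)
    return ap


if __name__ == "__main__":
    args = build_parser().parse_args(['-ph', 'train', '-b', '8', '-gpus', '0', '1'])
    print(args.batch_size, args.gpus, args.epochs)  # 8 [0, 1] 51
```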

All the best with whatever you are working on. Cheers!

Please star the repository or cite it if you find it useful.

@misc{easytorch,
  author = {Khanal, Aashis},
  title = {Easy Torch},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  url = {https://github.com/sraashis/easytorch}
}

Project details



This version

2.0.3

Download files


Source Distribution

easytorch-2.0.3.tar.gz (24.3 kB)


File details

Details for the file easytorch-2.0.3.tar.gz.

File metadata

  • Download URL: easytorch-2.0.3.tar.gz
  • Upload date:
  • Size: 24.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.6.1 requests/2.25.1 setuptools/49.2.1 requests-toolbelt/0.9.1 tqdm/4.55.1 CPython/3.9.1

File hashes

Hashes for easytorch-2.0.3.tar.gz
Algorithm Hash digest
SHA256 8814d119e49ab042ae770bbc21b71fc24a697306b747a6a7155afa235e1b0e8a
MD5 ab890ccf12b1f40790e7b14222fab578
BLAKE2b-256 9a803bdd04ffc37ee37d4e4d731171eaad25f862137a5552ffd9b0d9d7944dba

