# torchpack

[![PyPI Version](https://img.shields.io/pypi/v/torchpack.svg)](https://pypi.python.org/pypi/torchpack)

Torchpack is a set of interfaces to simplify the usage of PyTorch.

Documentation is in progress.


## Installation

- Install with pip.
```shell
pip install torchpack
```
- Install from source.
```shell
git clone https://github.com/hellock/torchpack.git
cd torchpack
python setup.py install
```

**Note**: To visualize the training process with TensorBoard, you need to
install TensorFlow ([installation guide](https://www.tensorflow.org/install/install_linux)) and tensorboardX (`pip install tensorboardX`).

## What can torchpack do

Torchpack aims to help users start training with less code while staying
flexible and configurable. It provides a `Runner` with many `Hooks`.
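
To make the runner-and-hooks idea concrete, here is a minimal sketch of the general pattern in plain Python. It is not torchpack's implementation: `Hook`, `RecordLossHook`, `ToyRunner`, and every method name below are hypothetical names chosen for illustration, and the "training step" is a stand-in that just averages the numbers in a batch.

```python
class Hook:
    """Base class: every callback is a no-op unless a subclass overrides it."""

    def before_epoch(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def after_epoch(self, runner):
        pass


class RecordLossHook(Hook):
    """Hypothetical hook that records the loss every `interval` iterations."""

    def __init__(self, interval):
        self.interval = interval
        self.messages = []

    def after_iter(self, runner):
        if (runner.iteration + 1) % self.interval == 0:
            self.messages.append(
                f"iter {runner.iteration + 1}: loss={runner.loss:.3f}")


class ToyRunner:
    """Owns the loop and calls each registered hook at fixed points."""

    def __init__(self):
        self.hooks = []
        self.epoch = 0
        self.iteration = 0
        self.loss = None

    def register_hook(self, hook):
        self.hooks.append(hook)

    def call_hooks(self, name):
        for hook in self.hooks:
            getattr(hook, name)(self)

    def run(self, batches, max_epoch):
        for epoch in range(max_epoch):
            self.epoch = epoch
            self.call_hooks("before_epoch")
            for batch in batches:
                # stand-in for a real forward/backward pass
                self.loss = sum(batch) / len(batch)
                self.call_hooks("after_iter")
                self.iteration += 1
            self.call_hooks("after_epoch")


runner = ToyRunner()
logger = RecordLossHook(interval=2)
runner.register_hook(logger)
runner.run(batches=[[1.0, 3.0], [2.0, 2.0], [0.0, 1.0], [4.0, 0.0]], max_epoch=1)
print(logger.messages)
```

The point of the pattern is that logging, checkpointing, and learning-rate updates can each live in their own hook, so the training loop itself does not change when one of them is added or removed.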

## Example

```python
######################## file1: config.py #######################
work_dir = './demo'  # directory for log files and checkpoints
optimizer = dict(
    algorithm='SGD', args=dict(lr=0.001, momentum=0.9, weight_decay=5e-4))
# alternate between training for 2 epochs and validating for 1 epoch
workflow = [('train', 2), ('val', 1)]
max_epoch = 16
# decrease the learning rate by a factor of 10 every 12 epochs
lr_policy = dict(policy='step', step=12)
checkpoint_cfg = dict(interval=1)  # save a checkpoint every epoch
log_cfg = dict(
    # log every 50 iterations
    interval=50,
    # two logging hooks: one prints to the terminal, one writes TensorBoard logs
    hooks=[
        ('TextLoggerHook', {}),
        ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log'))
    ])

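######################### file2: main.py ########################
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torchpack import Config, Runner
from torchvision.models import resnet18


def get_accuracy(pred, label):
    # illustrative helper, not part of torchpack: top-1 accuracy as a 0-dim tensor
    return (pred.argmax(dim=1) == label).float().mean()


# define how to process a batch and return a dict
def batch_processor(model, data, train_mode):
    img, label = data
    # inputs and labels go to the GPU, so the model must live there as well
    img = img.cuda(non_blocking=True)
    label = label.cuda(non_blocking=True)
    pred = model(img)
    loss = F.cross_entropy(pred, label)
    accuracy = get_accuracy(pred, label)
    log_vars = OrderedDict()
    log_vars['loss'] = loss.item()
    log_vars['accuracy'] = accuracy.item()
    outputs = dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))
    return outputs


cfg = Config.from_file('config.py')  # or config.yaml/config.json
model = resnet18().cuda()
runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
runner.register_default_hooks(lr_config=cfg.lr_policy,
                              checkpoint_config=cfg.checkpoint_cfg,
                              log_config=cfg.log_cfg)

# train_loader and val_loader are data loaders you build yourself
# (e.g. torch.utils.data.DataLoader instances); they are not defined here
runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)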
```

For a full example of training on ImageNet, please see `examples/train_imagenet.py`.

```shell
python examples/train_imagenet.py examples/config.py
```
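
The `lr_policy = dict(policy='step', step=12)` entry in the example config describes a step decay schedule. The sketch below shows what such a schedule typically computes, using the config's base learning rate of 0.001 and `max_epoch = 16`. It is an assumption, not torchpack's code: `step_lr` is a hypothetical helper, and the decay factor `gamma=0.1` is inferred from the config comment rather than confirmed from the library.

```python
def step_lr(base_lr, epoch, step, gamma=0.1):
    """Learning rate at a 0-based `epoch` under step decay.

    Assumes the rate is multiplied by `gamma` once every `step` epochs.
    """
    return base_lr * gamma ** (epoch // step)


# values taken from the example config: lr=0.001, step=12, max_epoch=16
schedule = [step_lr(0.001, epoch, step=12) for epoch in range(16)]

for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

Under these assumptions, epochs 0 through 11 train at the base rate and epochs 12 through 15 train at one tenth of it.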
