# torchpack
[torchpack on PyPI](https://pypi.python.org/pypi/torchpack)
Torchpack is a set of interfaces to simplify the usage of PyTorch.
Documentation is ongoing.
## Installation
- Install with pip.
```shell
pip install torchpack
```
- Install from source.
```shell
git clone https://github.com/hellock/torchpack.git
cd torchpack
python setup.py install
```
**Note**: To visualize the training process in TensorBoard, you need to install TensorFlow ([installation guide](https://www.tensorflow.org/install/install_linux)) and tensorboardX (`pip install tensorboardX`).
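Before enabling `TensorboardLoggerHook`, you can check whether the optional logging dependency is importable by probing for it with the standard library. This is a generic sketch, not part of torchpack. The helper names `has_module` and `check_optional_deps` are hypothetical.

```python
import importlib.util


def has_module(name):
    """Return True if the top-level module `name` can be imported here."""
    return importlib.util.find_spec(name) is not None


def check_optional_deps(names=("torchpack", "tensorboardX")):
    """Print and return which of the given packages are importable."""
    status = {name: has_module(name) for name in names}
    for name, found in status.items():
        print(f"{name}: {'found' if found else 'missing'}")
    return status


if __name__ == "__main__":
    check_optional_deps()
```

If `tensorboardX` is reported missing, drop the `TensorboardLoggerHook` entry from `log_cfg` and keep only `TextLoggerHook`.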
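## What can torchpack do?
Torchpack aims to help users start training with less code while staying flexible and configurable. It provides a `Runner` that drives the training and validation loop, together with a set of `Hooks` that plug into it for tasks such as learning-rate scheduling, checkpointing, and logging.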
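The `Runner`/`Hooks` split follows a callback pattern: the runner owns the loop, and each registered hook is called at fixed points such as the start of an epoch or the end of an iteration. The sketch below is a toy illustration of that pattern under our own assumptions. `ToyRunner`, `Hook`, `AverageLoggerHook` and the method names `before_epoch`/`after_iter`/`after_epoch` are hypothetical and are not torchpack's actual classes or hook API. The sketch also shows one plausible way to consume the dict returned by the `batch_processor` in the example (`loss`, `log_vars`, `num_samples`): each logged value is averaged over an epoch, weighted by batch size.

```python
from collections import OrderedDict


class Hook:
    """Base hook: each callback point is a no-op unless a subclass overrides it."""

    def before_epoch(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def after_epoch(self, runner):
        pass


class ToyRunner:
    """Owns the training loop and fires every registered hook at fixed points."""

    def __init__(self, batch_processor, model=None):
        self.batch_processor = batch_processor
        self.model = model
        self.hooks = []
        self.epoch = 0
        self.outputs = None  # latest dict returned by batch_processor

    def register_hook(self, hook):
        self.hooks.append(hook)

    def call_hook(self, name):
        for hook in self.hooks:
            getattr(hook, name)(self)

    def train(self, data_loader, max_epoch):
        for epoch in range(max_epoch):
            self.epoch = epoch
            self.call_hook("before_epoch")
            for batch in data_loader:
                self.outputs = self.batch_processor(self.model, batch, train_mode=True)
                # a real training runner would typically backprop outputs['loss']
                # and step the optimizer here (or delegate that to a hook)
                self.call_hook("after_iter")
            self.call_hook("after_epoch")


class AverageLoggerHook(Hook):
    """Averages log_vars over an epoch, weighting each batch by num_samples."""

    def __init__(self):
        self.history = []
        self.sums = OrderedDict()
        self.count = 0

    def before_epoch(self, runner):
        self.sums, self.count = OrderedDict(), 0

    def after_iter(self, runner):
        n = runner.outputs["num_samples"]
        for key, value in runner.outputs["log_vars"].items():
            self.sums[key] = self.sums.get(key, 0.0) + value * n
        self.count += n

    def after_epoch(self, runner):
        self.history.append({k: v / self.count for k, v in self.sums.items()})


if __name__ == "__main__":
    def mean_processor(model, data, train_mode):
        # stand-in for a real batch_processor: "loss" is just the batch mean
        loss = sum(data) / len(data)
        return dict(loss=loss, log_vars=OrderedDict(loss=loss), num_samples=len(data))

    runner = ToyRunner(mean_processor)
    logger = AverageLoggerHook()
    runner.register_hook(logger)
    runner.train([[1.0, 3.0], [6.0]], max_epoch=2)
    print(logger.history)  # per-epoch averages; loss is (1 + 3 + 6) / 3 each epoch
```

Weighting by `num_samples` keeps the epoch average correct when the last batch is smaller than the others.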
## Example
```python
######################## file1: config.py #######################
work_dir = './demo'  # directory for log files and checkpoints
optimizer = dict(
    algorithm='SGD', args=dict(lr=0.001, momentum=0.9, weight_decay=5e-4))
workflow = [('train', 2), ('val', 1)]  # repeat: train for 2 epochs, then validate for 1 epoch
max_epoch = 16
lr_policy = dict(policy='step', step=12)  # divide the learning rate by 10 every 12 epochs
checkpoint_cfg = dict(interval=1)  # save a checkpoint every epoch
log_cfg = dict(
    # log every 50 iterations
    interval=50,
    # two logging hooks: one prints to the terminal, one writes TensorBoard logs
    hooks=[
        ('TextLoggerHook', {}),
        ('TensorboardLoggerHook', dict(log_dir=work_dir + '/log'))
    ])
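######################### file2: main.py ########################
from collections import OrderedDict

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18

from torchpack import Config, Runner


def get_accuracy(pred, label):
    # top-1 accuracy of a batch
    return (pred.argmax(dim=1) == label).float().mean()


# define how to process a batch; return a dict with loss, log_vars and num_samples
def batch_processor(model, data, train_mode):
    img, label = data
    # track gradients only in training mode (PyTorch >= 0.4; replaces `volatile=True`)
    with torch.set_grad_enabled(train_mode):
        pred = model(img)
        loss = F.cross_entropy(pred, label)
    accuracy = get_accuracy(pred, label)
    log_vars = OrderedDict()
    log_vars['loss'] = loss.item()
    log_vars['accuracy'] = accuracy.item()
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=img.size(0))
    return outputs


cfg = Config.from_file('config.py')  # or config.yaml/config.json
model = resnet18(num_classes=10)
# placeholder data so the script runs end to end; replace FakeData with your own datasets
dataset = datasets.FakeData(size=256, num_classes=10, transform=transforms.ToTensor())
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset, batch_size=32)
runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
runner.register_default_hooks(cfg.lr_policy, cfg.checkpoint_cfg, cfg.log_cfg)
runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)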
```
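The two schedules in `config.py` can be made concrete with a small calculation. The sketch below is an assumption about their semantics, not torchpack's code. It treats `workflow` as a cycle that repeats until `max_epoch` training epochs have run, so `[('train', 2), ('val', 1)]` with `max_epoch = 16` gives 16 training and 8 validation epochs. It models `lr_policy = dict(policy='step', step=12)` as multiplying the learning rate by `gamma = 0.1` every 12 epochs, counted from 0. The function names `expand_workflow` and `step_lr` and the `gamma` parameter are hypothetical.

```python
def expand_workflow(workflow, max_epoch):
    """Unroll a [(mode, epochs), ...] workflow into the sequence of phases run.

    Assumptions: the workflow repeats until `max_epoch` training epochs are done,
    only 'train' phases count toward `max_epoch`, and the run stops as soon as
    another training epoch would exceed it.
    """
    if not any(mode == "train" and epochs > 0 for mode, epochs in workflow):
        raise ValueError("workflow needs at least one 'train' phase")
    schedule, trained = [], 0
    while trained < max_epoch:
        for mode, epochs in workflow:
            for _ in range(epochs):
                if mode == "train":
                    if trained >= max_epoch:
                        return schedule
                    trained += 1
                schedule.append(mode)
    return schedule


def step_lr(base_lr, epoch, step, gamma=0.1):
    """Step policy: multiply base_lr by gamma once every `step` epochs (0-indexed)."""
    return base_lr * gamma ** (epoch // step)


if __name__ == "__main__":
    schedule = expand_workflow([("train", 2), ("val", 1)], max_epoch=16)
    print(schedule.count("train"), schedule.count("val"))  # 16 8
    for epoch in (0, 11, 12, 15):
        print(epoch, step_lr(0.001, epoch, step=12))
```

Under these assumptions, the learning rate drops from 0.001 to 0.0001 at epoch 12 and does not drop again before `max_epoch = 16`.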
## Download files

| File | Type | Size |
|---|---|---|
| `torchpack-0.0.9.tar.gz` | Source distribution | 11.7 kB |
| `torchpack-0.0.9-py3-none-any.whl` | Built distribution (Python 3) | 17.6 kB |

Neither file was uploaded using Trusted Publishing.

### File hashes

`torchpack-0.0.9.tar.gz`:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `fc98b5b2a17caa3176d0e424c83af4bc46af33875791c2dcd80a8b55045c57e0` |
| MD5 | `f680e9281876e4c81687cdcbe584a551` |
| BLAKE2b-256 | `9a8dd39501f9e3eb94f03be809095ff4aebb7726e6cd6ea57cb5579e2918d5ed` |

`torchpack-0.0.9-py3-none-any.whl`:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `e7e1b09b24e3b38632fb2acd1eaad9bcf21b9ee710070b2332c5549b928631b2` |
| MD5 | `5ad4724317d755b09ea34207ec210eef` |
| BLAKE2b-256 | `13dfaff001ea0603ac5813a1dbbf432d20ad4eb853b3e7f0b8591a4063d0e8c8` |
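To check a downloaded file against the published digests, you can hash it locally with Python's standard `hashlib`. This is a minimal sketch. The names `HASHERS`, `file_digest` and `verify` are hypothetical, and it assumes "BLAKE2b-256" means BLAKE2b configured for a 32-byte digest.

```python
import hashlib

# Constructors for the algorithms listed in the hash tables.
HASHERS = {
    "sha256": hashlib.sha256,
    "md5": hashlib.md5,
    # BLAKE2b-256: BLAKE2b configured for a 32-byte (256-bit) digest
    "blake2b-256": lambda: hashlib.blake2b(digest_size=32),
}


def file_digest(path, algorithm="sha256", chunk_size=1 << 16):
    """Hex digest of a file, read in chunks so large files are not loaded at once."""
    h = HASHERS[algorithm.lower()]()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected, algorithm="sha256"):
    """True if the file's digest matches the published hex string."""
    return file_digest(path, algorithm) == expected.strip().lower()
```

For example, `verify("torchpack-0.0.9.tar.gz", "fc98b5b2a17caa3176d0e424c83af4bc46af33875791c2dcd80a8b55045c57e0")` should return `True` for an intact source archive.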