A Python module whose main purpose is PyTorch training.

Project description

Introduction

This is a standalone module, intended mainly for PyTorch CNN training.

It also provides some handy features: saving the model, recording the training process, plotting figures, and more.

Install

pip install fau-tools

Usage

import

The following import style is recommended.

import fau_tools
from fau_tools import torch_tools

quick start

This tutorial uses a simple example to help you get started quickly!

The following example uses Fau-tools to train a model on the MNIST handwritten digits dataset.

import torch
import torch.utils.data as tdata
import torchvision
from torch import nn

import fau_tools
from fau_tools import torch_tools

# A simple CNN network
class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Sequential(
      nn.Conv2d(1, 16, 3, 1, 1),  # -> (16, 28, 28)
      nn.ReLU(),
      nn.MaxPool2d(2),  # -> (16, 14, 14)

      nn.Conv2d(16, 32, 3, 1, 1),  # -> (32, 14, 14)
      nn.ReLU(),
      nn.MaxPool2d(2)  # -> (32, 7, 7)
    )
    self.output = nn.Linear(32 * 7 * 7, 10)

  def forward(self, x):
    x = self.conv(x)
    x = x.flatten(1)  # same as x = x.view(x.size(0), -1)
    return self.output(x)


# Hyper Parameters definition
EPOCH = 10
LR = 1E-3
BATCH_SIZE = 1024

# Load dataset
TRAIN_DATA = torchvision.datasets.MNIST('Datasets', train=True, transform=torchvision.transforms.ToTensor(), download=True)
TEST_DATA = torchvision.datasets.MNIST('Datasets', train=False, transform=torchvision.transforms.ToTensor())

# Get data loaders
train_loader = tdata.DataLoader(TRAIN_DATA, batch_size=BATCH_SIZE, shuffle=True)
test_loader = tdata.DataLoader(TEST_DATA, batch_size=BATCH_SIZE)

# Initialize model, optimizer and loss function
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), LR)
loss_function = nn.CrossEntropyLoss()

# Train!
torch_tools.torch_train(model, train_loader, test_loader, optimizer, loss_function, EPOCH, name="MNIST")
# `name` is used when naming the saved model and training-record files.

Now we can run the Python file, and the training process will be visualized, as in the following picture.

(figure: training process visualization)

Three files named MNIST_9846.pth, MNIST_9846.csv and MNIST_9846.txt will be saved.

The first file is the trained model.
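To reuse the trained model later, you can load the saved file with `torch.load`. Here is a minimal sketch, under the assumption that the `.pth` file stores the model's `state_dict` (if Fau-tools saves the entire model object instead, `torch.load` alone returns it); the file name `demo.pth` and the stand-in model are hypothetical:

```python
import torch
from torch import nn

# Stand-in model; in practice use the CNN class defined above.
model = nn.Linear(4, 2)

# Simulate a saved checkpoint (Fau-tools would have written e.g. MNIST_9846.pth).
torch.save(model.state_dict(), "demo.pth")

# Restore the weights into a fresh instance and switch to inference mode.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("demo.pth"))
restored.eval()
```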

The second file records the training process; you can visualize it with matplotlib.

The third file saves some hyperparameters of the training run.
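For example, the training-record CSV could be plotted like this. This is a sketch under assumptions: the column names in the Fau-tools CSV may differ, so the header row here (`epoch,loss,accuracy`) and the inline sample data are hypothetical stand-ins for the real file:

```python
import csv
import io

import matplotlib
matplotlib.use("Agg")  # render off-screen, no display needed
import matplotlib.pyplot as plt

# Hypothetical CSV content standing in for MNIST_9846.csv;
# the real column names may differ.
csv_text = """epoch,loss,accuracy
1,0.52,0.91
2,0.31,0.94
3,0.22,0.96
"""

rows = list(csv.DictReader(io.StringIO(csv_text)))
epochs = [int(r["epoch"]) for r in rows]
loss = [float(r["loss"]) for r in rows]
acc = [float(r["accuracy"]) for r in rows]

# Loss and accuracy share the x-axis but use separate y-axes.
fig, ax_loss = plt.subplots()
ax_loss.plot(epochs, loss, label="loss")
ax_loss.set_xlabel("epoch")
ax_loss.set_ylabel("loss")
ax_acc = ax_loss.twinx()
ax_acc.plot(epochs, acc, linestyle="--", label="accuracy")
ax_acc.set_ylabel("accuracy")
fig.savefig("training_curve.png")
```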


The above covers the primary usage of this tool; some other handy features will be introduced later.

END

We hope you like it! Issues and pull requests are welcome.

Project details


Download files

Download the file for your platform.

Source Distribution

fau_tools-1.5.0.tar.gz (11.8 kB)

Built Distribution

fau_tools-1.5.0-py3-none-any.whl (13.4 kB)

File details

Details for the file fau_tools-1.5.0.tar.gz.

File metadata

  • Download URL: fau_tools-1.5.0.tar.gz
  • Upload date:
  • Size: 11.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.4.2 CPython/3.11.3 Darwin/22.5.0

File hashes

Hashes for fau_tools-1.5.0.tar.gz
  • SHA256: 3221479d7f9ce7c6e150ecbc0de7b65d8b822907db30762168de771644d1e338
  • MD5: 8c6300e73cf9e652786149863374d37c
  • BLAKE2b-256: 305392dd0153a336be336853e72dc6b7e8839760d9b19b1c521ac77e9a42afe5

File details

Details for the file fau_tools-1.5.0-py3-none-any.whl.

File metadata

  • Download URL: fau_tools-1.5.0-py3-none-any.whl
  • Upload date:
  • Size: 13.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.4.2 CPython/3.11.3 Darwin/22.5.0

File hashes

Hashes for fau_tools-1.5.0-py3-none-any.whl
  • SHA256: a2178d14a5cbdb1a8dec77db87c65c1260bdb9306ebfa653566a374ad621d120
  • MD5: ae9954df9a435e58aee4df053ecaa410
  • BLAKE2b-256: 79eab47bde4fa803bc43a0d244c5443d25dc283b03fd2d1bb12b8a446ee77afe
