
A Python module, mainly for PyTorch training.

Project description

Introduction

This is a standalone module, mainly for PyTorch CNN training.

It also provides some handy features: saving the model, recording the training process, plotting figures, and more.

Install

pip install Fau-tools

Usage

import

The following import style is recommended:

import Fau_tools
from Fau_tools import torch_tools

quick start

This tutorial uses a simple example to help you get started quickly!

The following example uses Fau-tools to train a model on the MNIST handwritten-digits dataset.

import torch
import torch.utils.data as tdata
import torchvision
from torch import nn

import Fau_tools
from Fau_tools import torch_tools

# A simple CNN network
class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Sequential(
      nn.Conv2d(1, 16, 3, 1, 1),  # -> (16, 28, 28)
      nn.ReLU(),
      nn.MaxPool2d(2),  # -> (16, 14, 14)

      nn.Conv2d(16, 32, 3, 1, 1),  # -> (32, 14, 14)
      nn.ReLU(),
      nn.MaxPool2d(2)  # -> (32, 7, 7)
    )
    self.output = nn.Linear(32 * 7 * 7, 10)

  def forward(self, x):
    x = self.conv(x)
    x = x.flatten(1)  # same as x = x.view(x.size(0), -1)
    return self.output(x)


# Hyper Parameters definition
EPOCH = 10
LR = 1E-3
BATCH_SIZE = 1024

# Load dataset
TRAIN_DATA = torchvision.datasets.MNIST('Datasets', train=True, transform=torchvision.transforms.ToTensor(), download=True)
TEST_DATA = torchvision.datasets.MNIST('Datasets', train=False, transform=torchvision.transforms.ToTensor())

# Get data loader
train_loader = tdata.DataLoader(TRAIN_DATA, BATCH_SIZE, shuffle=True)
test_loader = tdata.DataLoader(TEST_DATA, BATCH_SIZE)

# Initialize model, optimizer and loss function
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), LR)
loss_function = nn.CrossEntropyLoss()

# Train!
torch_tools.torch_train(model, train_loader, test_loader, optimizer, loss_function, EPOCH, name="MNIST")
# `name` is the file name used when saving the model and the training records.

Now we can run the Python file, and the training process will be visualized, as in the following picture.

(figure: training process visualization)

Three files named MNIST_9846.pth, MNIST_9846.csv and MNIST_9846.txt will be saved.
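The four-digit number in the file names appears to encode the best test accuracy (here 9846 for 98.46%). If so, a tiny helper could recover it — note this is an assumption about the naming scheme, not documented behavior:

```python
# Hypothetical helper: assumes the digits after the underscore in a saved
# file name like "MNIST_9846.pth" encode the best test accuracy in percent.
def accuracy_from_filename(filename):
  stem = filename.rsplit(".", 1)[0]   # drop the extension -> "MNIST_9846"
  digits = stem.rsplit("_", 1)[1]     # take the trailing digits -> "9846"
  return int(digits) / 100            # -> 98.46 (percent)
```

For example, `accuracy_from_filename("MNIST_9846.pth")` returns `98.46`.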

The first file is the trained model.

The second file records the training process, which you can visualize with matplotlib.

The third file saves some hyperparameters of the training run.


The above covers the primary usage of this tool; some other snazzy features will be introduced later.

END

Hope you like it! Issues and pull requests are welcome.
