A Flax trainer

Project description

XTRAIN: a tiny library for training Flax models.

Design goals:

  • Help avoid boilerplate code
  • Minimal functionality and dependencies
  • Agnostic to hardware configuration (e.g. switching from GPU to TPU)

General workflow

Step 1: define your model

import flax.linen as nn

class MyFlaxModule(nn.Module):
  def __call__(self, x):
    ...  # your model's forward pass goes here

Step 2: define a loss function

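def my_loss_func(**kwargs):
    model_out = kwargs["preds"]   # the model's output
    labels = kwargs["labels"]     # the labels supplied by the data iterator
    loss = ...                    # compute a scalar loss from model_out and labels
    return loss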
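For a regression task, the loss function could be a plain mean-squared error. This sketch assumes, as in the template above, that the model output arrives as the `preds` keyword and the targets as `labels`; the name `mse_loss` is hypothetical.

```python
import jax.numpy as jnp

def mse_loss(**kwargs):
    # model output and targets, passed in as keyword arguments (assumed names)
    preds = kwargs["preds"]
    labels = kwargs["labels"]
    # average squared error over all elements -> scalar loss
    return jnp.mean((preds - labels) ** 2)
```

Accepting `**kwargs` means the function silently ignores any extra keyword arguments the trainer may pass, for example `mse_loss(preds=p, labels=y, inputs=x)` still works.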

Step 3: create an iterator that supplies training data

import itertools

# sequence_of_inputs / sequence_of_labels: your batched data, one entry per step
my_data = itertools.cycle(
    zip(sequence_of_inputs, sequence_of_labels)
)
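`itertools.cycle` saves a copy of every element on its first pass, so the whole dataset stays in memory and batches repeat in the same order every epoch. A hypothetical alternative is a generator that reshuffles each epoch. The name `infinite_batches` and its parameters are illustrative, and the sketch assumes the trainer accepts any iterator of `(inputs, labels)` pairs, as in the example above.

```python
import numpy as np

def infinite_batches(inputs, labels, batch_size, seed=0):
    """Yield (inputs, labels) batches forever, reshuffling every epoch.

    The last partial batch of each epoch is dropped so every batch has the
    same shape, which avoids recompiling jit-compiled training steps.
    """
    n = len(inputs)
    if batch_size > n:
        raise ValueError("batch_size must not exceed the dataset size")
    rng = np.random.default_rng(seed)
    while True:
        perm = rng.permutation(n)  # new sample order for each epoch
        for start in range(0, n - batch_size + 1, batch_size):
            idx = perm[start:start + batch_size]
            yield inputs[idx], labels[idx]

# toy data: 10 samples with one feature each
inputs = np.arange(10, dtype=np.float32).reshape(10, 1)
labels = np.arange(10, dtype=np.float32)
my_data = infinite_batches(inputs, labels, batch_size=4)
```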

Step 4: train

import xtrain

# create and initialize a Trainer object
trainer = xtrain.Trainer(
  model=MyFlaxModule(),
  losses=my_loss_func,
)
trainer.initialize(my_data, tx=my_optax_optimizer)

train_iter = trainer.train(my_data)  # returns a Python iterator

# iterating over train_iter trains the model
for step in range(train_steps):
    avg_loss = next(train_iter)
    if step % 1000 == 0:  # report every 1000 steps
        print(avg_loss)
        trainer.reset()  # presumably clears the running loss average

Download files

Source Distribution

xtrain-0.1.1.tar.gz (9.0 kB)

Built Distribution

xtrain-0.1.1-py3-none-any.whl (10.8 kB, Python 3)