train

package
v0.9.1

Published: Apr 20, 2024 License: Apache-2.0

Documentation

Overview

Package train holds tools to help run a training loop.

It provides various levels of tooling, from supporting a single training step to running a full training loop. It should also serve as an example from which users who need more flexibility can start to create their own training loops.

Constants

const (
	// TrainerAbsoluteScope used for Context parameters related to the trainer.
	TrainerAbsoluteScope = context.ScopeSeparator + "trainer"

	// TrainerLossGraphParamKey is the key to the global params that holds the
	// losses added to the model.
	TrainerLossGraphParamKey = "trainer_loss"

	// TrainerPerStepUpdateGraphFnParamKey is used by AddPerStepUpdateGraphFn.
	TrainerPerStepUpdateGraphFnParamKey = "trainer_per_step_update_graph_fn"
)

Variables

var DefaultMaxExecutors = 20

DefaultMaxExecutors is the default maximum number of executors used by Trainer objects. Each distinct `spec` value yielded by a Dataset triggers the creation of a new executor.

Functions

func AddLoss

func AddLoss(ctx *context.Context, loss *graph.Node)

AddLoss adds the given loss to the context's Params. This is the loss used by the trainer to optimize the model.

This function can be called multiple times, and the losses are accumulated.

If `loss` is not scalar (often one doesn't reduce the batch axis), it is automatically reduced with `graph.ReduceAllMean`.

If you are providing all loss terms with AddLoss, you can pass nil as the `lossFn` parameter to the Trainer.
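
For illustration, a minimal sketch of a model function that adds an L2 penalty as an extra loss term. The graph ops used (`graph.Square`, `graph.ReduceAllSum`, `graph.MulScalar`) and `buildNetwork` are assumptions, not part of this package:

func modelWithPenalty(ctx *context.Context, spec any, inputs []*graph.Node) []*graph.Node {
	logits := buildNetwork(ctx, inputs[0]) // buildNetwork is a hypothetical model body.

	// Small L2 penalty on the logits, added as an extra loss term; the trainer
	// accumulates it with whatever the main LossFn produces.
	penalty := graph.MulScalar(graph.ReduceAllSum(graph.Square(logits)), 1e-4)
	train.AddLoss(ctx, penalty)
	return []*graph.Node{logits}
}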

func AddPerStepUpdateGraphFn

func AddPerStepUpdateGraphFn(ctx *context.Context, g *graph.Graph, fn ContextGraphFn)

AddPerStepUpdateGraphFn registers the given function to be executed at every training step, after the optimizer updates the variables with the gradients.

This allows one, for instance, to implement variable constraints.

There can be one ContextGraphFn registered per scope per graph. If one wants to register more than one such function, use different context scopes.

One thing to observe: this is executed after the optimizer, and therefore also after the loss is calculated. Any changes made to the model weights won't be reflected in the loss returned by the training step, nor in most of the metrics: the metrics are updated after this hook, but they typically use the predictions that were generated earlier in the training step.
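
For illustration, a sketch that clips every trainable variable after each optimizer update, called from inside a graph-building function where `ctx` and `g` are available. The surrounding calls (`ctx.EnumerateVariables`, `Variable.ValueGraph`, `Variable.SetValueGraph`, `graph.ClipScalar`) are assumptions about the context and graph APIs; check the actual names in your version:

train.AddPerStepUpdateGraphFn(ctx, g, func(ctx *context.Context, g *graph.Graph) {
	// Constrain all trainable variables to [-1, 1] after the gradient update.
	ctx.EnumerateVariables(func(v *context.Variable) {
		if !v.Trainable {
			return
		}
		clipped := graph.ClipScalar(v.ValueGraph(g), -1.0, 1.0)
		v.SetValueGraph(clipped)
	})
})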

func EveryNSteps

func EveryNSteps(loop *Loop, n int, name string, priority Priority, fn OnStepFn)

EveryNSteps registers an OnStep hook on the loop that is called every n steps.

Notice that it does not call `fn` at the last step (except by coincidence).
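
A typical use is reporting progress, as in this sketch (priority 0 is the default; lower priorities run first):

train.EveryNSteps(loop, 100, "progress", 0,
	func(loop *train.Loop, metrics []tensor.Tensor) error {
		// metrics[0] is the batch loss -- see Trainer.TrainStep.
		fmt.Printf("step %d of [%d, %d): batch loss=%v\n",
			loop.LoopStep, loop.StartStep, loop.EndStep, metrics[0])
		return nil
	})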

func ExponentialCallback added in v0.4.0

func ExponentialCallback(loop *Loop, startStep int, exponentialFactor float64, callOnEnd bool,
	name string, priority Priority, fn OnStepFn)

ExponentialCallback registers an `OnStep` hook on the loop that is called at exponentially increasing intervals of steps, starting at startStep and growing by the geometric factor exponentialFactor.

If callOnEnd is set, it will also call at the end of the loop.

Example: This will call at steps 100, 100+100*1.2 = 220, 220+100*1.2^2 = 364, ...

ExponentialCallback(loop, 100, 1.2, false, "my_callback", 100, myCallback)

func GetLosses

func GetLosses(ctx *context.Context, g *graph.Graph) (loss *graph.Node)

GetLosses returns the sum of all loss terms added with AddLoss, or nil if none was set.

Usually this is used by the trainer after all losses are accounted for, but it can be used by arbitrary modeling functions -- in particular after the optimizer update, see AddPerStepUpdateGraphFn.

func NTimesDuringLoop

func NTimesDuringLoop(loop *Loop, n int, name string, priority Priority, fn OnStepFn)

NTimesDuringLoop registers an OnStep hook on the loop that is called at most n times, split evenly across all steps.

For Loop.RunEpochs the spacing is not perfectly even, at least until the exact number of steps is known -- it may even call OnStepFn more than n times.

It always calls `fn` at the very last step.

func PeriodicCallback added in v0.4.0

func PeriodicCallback(loop *Loop, period time.Duration, callOnEnd bool, name string, priority Priority, fn OnStepFn)

PeriodicCallback registers an `OnStep` hook on the loop that is called at a given period of time. The period is counted from after the execution of `OnStep`: this discounts the time to run `OnStep` itself (in case it is expensive), and it discounts periods where execution is paused. On the other hand, it means `OnStep` is not executed at exactly every `period` of time.

If callOnEnd is set, it will also call at the end of the loop.

Types

type ContextGraphFn

type ContextGraphFn func(ctx *context.Context, g *graph.Graph)

ContextGraphFn is a generic graph building function.

type Dataset

type Dataset interface {
	// Name identifies the dataset. Used for debugging, pretty-printing and plots.
	Name() string

	// Yield returns one "batch" (or whatever is the unit for a training step), or an error. It returns a
	// `spec` for the dataset, a slice of `inputs`, and a slice of `labels` tensors (even when there is only
	// one tensor for each of them).
	//
	// In the simplest case, `inputs` and `labels` keep the same number of elements and the same
	// shapes (including `dtype`) across every yield.
	//
	// If the `inputs` or `labels` change shapes during training or evaluation, it will trigger the creation
	// of a new computation graph and new JIT (just-in-time) compilation. There is a finite-sized cache,
	// and this can become inefficient -- it may spend more time JIT compiling than executing code. Consider
	// instead using padding for things that have variable length. And if there are various such elements,
	// consider padding to powers of 2 (or some other base) to limit the number of shapes that will be used.
	//
	// Yield also returns an opaque `spec` object that is normally simply passed to the model function
	// -- it can simply be nil. The `spec` usually is static (always the same) for a dataset. E.g.: the field names
	// and types of a generic CSV file reader.
	//
	// **Important**:
	// 1. For train.Trainer the `spec` is converted to string as a key of a `map[string]` for different computation
	//    graphs, so each time the `spec` changes, the model graph is regenerated and re-compiled. Just like with
	//    `inputs` or `labels` of different shapes.
	// 2. The number of `inputs` and `labels` should not change for the same `spec` -- the train.Trainer will return
	//    an error if they do. Their shape can vary (at the cost of creating a new JIT-compiled graph for each
	//    different combination of shapes). If the number of `inputs` or `labels` needs changing, a new `spec`
	//    needs to be given.
	//
	// If using Loop.RunSteps for training, having an infinite dataset stream is ok. But be careful
	// not to use Loop.RunEpochs on a dataset configured to loop indefinitely.
	//
	// Optionally it can return an error. If the error is `io.EOF` the training/evaluation terminates
	// normally, as it indicates end of data for finite datasets -- maybe the end of the epoch.
	//
	// Any other errors should interrupt the training/evaluation and be returned to the user.
	Yield() (spec any, inputs []tensor.Tensor, labels []tensor.Tensor, err error)

	// Reset restarts the dataset from the beginning. Can be called after io.EOF is reached,
	// for instance when running another evaluation on a test dataset.
	Reset()
}

Dataset for a train.Trainer provides the data, one batch at a time. Each batch (the unit of data for a training step) consists of a slice of tensor.Tensor for `inputs` and another slice for `labels`.

A Dataset also has to provide a Dataset.Name() and a dataset `spec`, which usually is the same for the whole dataset, but can vary per batch if the Dataset yields different types of data.

For a static Dataset that always provides the same kind of data, the `spec` can simply be nil.
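
As an illustration, a minimal in-memory Dataset sketch: it yields one pre-built (input, label) pair per call, returns io.EOF at the end of the data, and uses a nil `spec` since the data format never changes:

type sliceDataset struct {
	inputs, labels []tensor.Tensor // Pre-built example tensors, pairwise aligned.
	next           int
}

func (ds *sliceDataset) Name() string { return "in-memory" }

func (ds *sliceDataset) Yield() (spec any, inputs, labels []tensor.Tensor, err error) {
	if ds.next >= len(ds.inputs) {
		return nil, nil, nil, io.EOF // End of an epoch.
	}
	i := ds.next
	ds.next++
	// spec is nil: the data format is the same for the whole dataset.
	return nil, []tensor.Tensor{ds.inputs[i]}, []tensor.Tensor{ds.labels[i]}, nil
}

func (ds *sliceDataset) Reset() { ds.next = 0 }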

type GraphType added in v0.4.1

type GraphType int

GraphType can be TrainType or EvalType, when there needs to be a distinction.

const (
	TrainType GraphType = iota
	EvalType
)

type Loop

type Loop struct {
	// Trainer associated with this loop. In particular Trainer.TrainMetrics() and
	// Trainer.EvalMetrics() can be of interest.
	// TODO: make Trainer an interface, so Loop can work on custom trainers.
	Trainer *Trainer

	// LoopStep currently being executed.
	// Defaults to the current context's `GlobalStep`, which will be 0 for a new context.
	LoopStep int

	// StartStep is the value of LoopStep at the start of a run (RunSteps or RunEpochs).
	StartStep int

	// EndStep is one-past the last step to be executed. If -1 the end step is not known (if
	// running till the end of the dataset). When running for multiple epochs (Loop.RunEpochs) it can
	// change during the run (after the first epoch, the value is extrapolated based on how many steps
	// have been run so far).
	EndStep int

	// Epoch is set when running Loop.RunEpochs() to the current running epoch, starting from 0.
	Epoch int

	// SharedData allows for cross-tools to publish and consume information. Keys (strings)
	// and semantics/type of their values are not specified by loop.
	SharedData map[string]any

	// trainStepDurations collected during training
	TrainStepDurations []time.Duration
	// contains filtered or unexported fields
}

Loop runs a training loop, invoking Trainer.TrainStep at every step and calling the appropriate hooks.

It also converts graph-building errors thrown with `panic` and returns them as normal errors instead.

By itself it doesn't do much, but one can attach functionality to it, like checkpointing, plotting tools, early-stopping strategies, etc. It is simple and flexible enough to allow arbitrary tools to be attached to the training loop.

The public attributes are meant for reading only; don't change them, or behavior may become undefined.

func NewLoop

func NewLoop(trainer *Trainer) *Loop

NewLoop creates a new training loop for the given trainer.

func (*Loop) MedianTrainStepDuration added in v0.4.0

func (loop *Loop) MedianTrainStepDuration() time.Duration

MedianTrainStepDuration returns the median duration of each training step. It returns 1 millisecond if no training step was recorded (to avoid potential division by 0).

It sorts and mutates loop.TrainStepDurations.

func (*Loop) OnEnd

func (loop *Loop) OnEnd(name string, priority Priority, fn OnEndFn)

OnEnd adds a hook with given priority and name (for error reporting) to the end of a loop, after the last call to `Trainer.TrainStep`.

func (*Loop) OnStart

func (loop *Loop) OnStart(name string, priority Priority, fn OnStartFn)

OnStart adds a hook with given priority and name (for error reporting) to the start of a loop.

func (*Loop) OnStep

func (loop *Loop) OnStep(name string, priority Priority, fn OnStepFn)

OnStep adds a hook with given priority and name (for error reporting) to each step of a loop. The function `fn` is called after each `Trainer.TrainStep`.

func (*Loop) RunEpochs

func (loop *Loop) RunEpochs(ds Dataset, epochs int) (metrics []tensor.Tensor, err error)

RunEpochs runs the given number of epochs. StartStep is adjusted to the current LoopStep, so it can be called multiple times, and it will simply pick up where it left off the last time. Loop.Epoch is set to the currently running epoch. EndStep starts as -1 and is adjusted to the expected end step after the first epoch, once the number of steps per epoch is known. Dataset.Reset is called after each epoch (including the last).

func (*Loop) RunSteps

func (loop *Loop) RunSteps(ds Dataset, steps int) (metrics []tensor.Tensor, err error)

RunSteps runs the given number of steps. StartStep and EndStep are adjusted to the current LoopStep, so it can be called multiple times, and it will simply pick up where it left off the last time.
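
Putting it together, a sketch of a typical run -- `trainer` and `ds` are assumed to have been created earlier, and `saveCheckpoint` is a hypothetical OnStepFn:

loop := train.NewLoop(trainer)
train.EveryNSteps(loop, 1000, "checkpoint", 100, saveCheckpoint)
metrics, err := loop.RunSteps(ds, 10000)
if err != nil {
	log.Fatalf("training failed: %+v", err)
}
fmt.Printf("metrics after training: %v\n", metrics)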

type LossFn

type LossFn func(labels, predictions []*graph.Node) *graph.Node

LossFn takes the output of ModelFn (called predictions, but they could be the logits) and the labels (coming out of Dataset.Yield), and outputs the scalar loss that can be used for training.

For some types of self-supervised models for which there are no labels, the labels can be empty.

Most of the predefined losses in package `gomlx/ml/train/losses` assume labels and predictions are both of length one. For multi-head models, it's easy to write a small custom LossFn that splits the slices and sends each label/prediction pair to a predefined loss, as in the sketch below.
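
This two-head sketch assumes the losses package provides `losses.MeanSquaredError` and `losses.BinaryCrossentropyLogits` with LossFn-compatible signatures; substitute whatever losses your model needs:

func twoHeadLoss(labels, predictions []*graph.Node) *graph.Node {
	// Head 0 is a regression output, head 1 a binary classifier (logits).
	regression := losses.MeanSquaredError(labels[:1], predictions[:1])
	classification := losses.BinaryCrossentropyLogits(labels[1:], predictions[1:])
	return graph.Add(regression, classification) // Both terms are scalars.
}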

type ModelFn

type ModelFn func(ctx *context.Context, spec any, inputs []*graph.Node) (predictions []*graph.Node)

ModelFn is a computation graph building function that takes as input a `spec` and a slice of `inputs` (even if just one) generated by a Dataset, and outputs a slice (even if only one) of `predictions` (or sometimes the logits).

The `predictions` output by ModelFn is fed to a LossFn and to MetricsFn during training.

Notice `spec` is opaque to the train package: it's passed from the train.Dataset to the ModelFn, and its meaning is determined by the train.Dataset used. For the static case (where the data format is always the same) it can simply be nil. Each distinct value of `spec` is mapped to a different computation graph by the train.Trainer.
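
A minimal ModelFn sketch for a linear model -- `layers.Dense` and its signature are an assumption about the layers package; any graph-building code that maps inputs to predictions works here:

func modelFn(ctx *context.Context, spec any, inputs []*graph.Node) []*graph.Node {
	_ = spec // Static dataset: spec is nil and unused.
	logits := layers.Dense(ctx, inputs[0], /* useBias */ true, /* outputDim */ 1)
	return []*graph.Node{logits}
}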

type OnEndFn

type OnEndFn func(loop *Loop, metrics []tensor.Tensor) error

OnEndFn is the type of OnEnd hooks.

type OnExecFn added in v0.4.1

type OnExecFn func(exec *context.Exec, graphType GraphType)

OnExecFn is a handler that can be called when executors are created. See Trainer.OnExecCreation.

type OnStartFn

type OnStartFn func(loop *Loop, ds Dataset) error

OnStartFn is the type of OnStart hooks.

type OnStepFn

type OnStepFn func(loop *Loop, metrics []tensor.Tensor) error

OnStepFn is the type of OnStep hooks.

type Priority

type Priority int

Priority for hooks: hooks with lower values run first. It defaults to 0; negative values are ok.

type Trainer

type Trainer struct {
	// contains filtered or unexported fields
}

Trainer is a helper object to orchestrate a training step and evaluation.

Given the inputs and labels, it deals with executing a training step (TrainStep) and evaluation (EvalStep and Eval), calling the loss and optimizer and running metrics.

See Loop for a flexible and extensible (different UIs) way to run this in a training loop.

func NewTrainer

func NewTrainer(manager *graph.Manager, ctx *context.Context,
	modelFn ModelFn, lossFn LossFn, optimizer optimizers.Interface,
	trainMetrics, evalMetrics []metrics.Interface) *Trainer

NewTrainer constructs a trainer that can be used for training steps and evaluation. The given Context will hold the model variables, hyperparameters and other information, and it can be changed by the user.

Its arguments are:

  • manager needed to create and compile computation graphs.

  • ctx will hold the variables, hyperparameters and related information for the model.

  • modelFn builds the graph that transforms inputs into predictions (or logits).

  • lossFn takes the predictions (the output of modelFn) and the labels, and outputs the loss. If the returned loss is not a scalar, it is reduced to one with ReduceAllMean. There are several standard losses available in the gomlx/ml/train/losses package; they can be used as is, or called by arbitrary custom losses. It can also be set to nil if one is providing the loss terms with `AddLoss` -- e.g., for unsupervised training.

  • optimizer (e.g., optimizers.StochasticGradientDescent) is the methodology used to improve the model variables (aka. parameters or weights) to minimize the loss (the output of lossFn), typically using gradient descent.

  • trainMetrics are output by Trainer.TrainStep after each step. It's recommended to use moving-average types of metrics here, since the model is changing, so a plain mean wouldn't make sense. The mean loss of the batch and a moving average of the loss are always included as the first two metrics. It's ok for this to be empty (nil).

  • evalMetrics are output by Trainer.EvalStep and Trainer.Eval. It's recommended to use mean metrics here, since the model is presumably frozen and sees each example exactly once. The mean loss over the dataset is always provided as the first metric. It's ok for this to be empty (nil).
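
A sketch of wiring a Trainer together -- `manager`, `ctx` and `modelFn` are assumed to exist, and the exact constructors are assumptions to be checked against the optimizers and losses packages:

trainer := train.NewTrainer(manager, ctx,
	modelFn, // ModelFn: builds predictions from inputs.
	losses.BinaryCrossentropyLogits,        // LossFn; may be nil if using only AddLoss.
	optimizers.StochasticGradientDescent(), // Assumed constructor for an optimizers.Interface.
	nil, // trainMetrics: batch loss and its moving average are always included.
	nil) // evalMetrics: the mean loss is always included.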

func (*Trainer) Context

func (r *Trainer) Context() *context.Context

Context returns the current Context. See SetContext to change it.

func (*Trainer) Eval

func (r *Trainer) Eval(ds Dataset) (lossAndMetrics []tensor.Tensor)

Eval computes the loss and metrics over the given dataset. The dataset has to be finite (yield io.EOF at the end). The function resets the dataset at the start.

func (*Trainer) EvalMetrics

func (r *Trainer) EvalMetrics() []metrics.Interface

EvalMetrics returns the eval metrics objects (not the actual values, just the objects that implement them).

func (*Trainer) EvalStep

func (r *Trainer) EvalStep(spec any, inputs, labels []tensor.Tensor) (metrics []tensor.Tensor)

EvalStep runs one eval step and returns the metrics, the first one being the mean loss.

The parameters are the output of a Dataset.Yield call. The same as TrainStep.

It returns the current value for the registered eval metrics.

Errors are thrown using `panic` -- they are usually informative and include a stack-trace.

func (*Trainer) InDevice

func (r *Trainer) InDevice(deviceNum int) *Trainer

InDevice sets the device number to be used when executing graphs. This should be called before any invocations of TrainStep. It returns a reference to itself so calls can be cascaded.

TODO: Add support for training across multiple devices -- maybe with a different Trainer for that; in principle it should be simple.

func (*Trainer) Metrics

func (r *Trainer) Metrics() []metrics.Interface

Metrics returns the list of registered eval metrics, including the loss metric that is added automatically.

func (*Trainer) OnExecCreation added in v0.4.1

func (r *Trainer) OnExecCreation(handler OnExecFn)

OnExecCreation registers a handler to be called each time an executor (`context.Exec`) is created by the trainer. Different executors are created for training and for eval, and for different `spec` values received from the Dataset. The `handler` is also given the mode (TrainType or EvalType) the executor is created for.

func (*Trainer) ResetTrainMetrics

func (r *Trainer) ResetTrainMetrics() error

ResetTrainMetrics calls Metrics.Reset on all train metrics. Usually called before a training session.

func (*Trainer) SetContext

func (r *Trainer) SetContext(ctx *context.Context) *Trainer

SetContext associates the given Context to the trainer. It should be called before any calls to TrainStep or Eval. Notice that after the first time the context is used to build a graph, it is set to Reuse. If the Context variables were already created, it should be marked with Context.Reuse. It returns a reference to itself so calls can be cascaded.

func (*Trainer) TrainMetrics

func (r *Trainer) TrainMetrics() []metrics.Interface

TrainMetrics returns the train metrics objects (not the actual values, just the objects that implement them).

func (*Trainer) TrainStep

func (r *Trainer) TrainStep(spec any, inputs, labels []tensor.Tensor) (metrics []tensor.Tensor)

TrainStep runs one step and returns the metrics.

All arguments usually come from `Dataset.Yield`; see a more detailed description there. In short:

  • spec: provided by the dataset, often just nil. Each distinct value triggers the creation of a different computation graph. It is normally a static value (for the dataset) used to describe the inputs. See the longer discussion in `train.Dataset`.
  • inputs: always a slice, even though it's common to have only one input tensor in it. There must always be at least one input. For each `spec` value, the number of inputs and labels must remain constant; it will return an error otherwise.
  • labels: also always a slice, even if commonly with only one tensor.

It returns a slice of metrics whose first two entries are the batch loss and an exponential moving average of the batch loss, followed by the other `trainMetrics` configured during the creation of the Trainer.

Errors are thrown using `panic` -- they are usually informative and include a stack-trace.
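
For illustration, a sketch of driving training manually, without a Loop -- `trainer`, `ds` and `numSteps` are assumed to exist, and a real version would also recover the panics TrainStep may throw:

for step := 0; step < numSteps; step++ {
	spec, inputs, labels, err := ds.Yield()
	if err == io.EOF {
		ds.Reset() // Finite dataset exhausted: restart it and keep training.
		continue
	}
	if err != nil {
		log.Fatalf("dataset error: %+v", err)
	}
	metrics := trainer.TrainStep(spec, inputs, labels)
	_ = metrics // metrics[0] is the batch loss, metrics[1] its moving average.
}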

Directories

Path	Synopsis
commandline	Package commandline contains convenience UI training tools for the command line.
losses	Package losses has several standard losses that implement the train.LossFn interface.
metrics	Package metrics holds a library of metrics and defines the metrics.Interface they implement.
optimizers	Package optimizers implements a collection of ML optimizers that can be used by train.Trainer, or by themselves.
