Documentation ¶
Overview ¶
Package checkpoints implements checkpoint management: saving and loading of checkpoints.
The main object is the Handler. Create it by calling Build, setting the desired options, and finally calling Config.Done. Once created, if a previously saved checkpoint exists, it automatically loads the variables and parameters for your model into the Context. As the model trains, one can call Handler.Save() at any time to save a new checkpoint; typically this is done inside train.EveryNSteps().
Example: After creating the Context, the code checks whether a checkpoint directory was set (`*flagCheckpoint`) and, if so, creates a checkpoints.Handler that saves a checkpoint every 100 steps, keeping the last `*flagCheckpointKeep` checkpoints.
```
…
ctx := context.NewContext(manager)
ctx.SetParam(optimizers.ParamLearningRate, *flagLearningRate)

var checkpoint *checkpoints.Handler
if *flagCheckpoint != "" {
	var err error
	checkpoint, err = checkpoints.Build(ctx).Dir(*flagCheckpoint).Keep(*flagCheckpointKeep).Done()
	Must(err) // Panics if err != nil.
}
…
// Build training loop.
loop := train.NewLoop(trainer)
commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
if checkpoint != nil {
	const priority = 100 // Large number here, means it runs last.
	train.EveryNSteps(loop, 100, "checkpointing", priority, checkpoint.OnStepFn)
}
…
```
TODO:
- Compress checkpoints.
- Allow specifying which parts of the model to load, and the scope they should be loaded into, for transfer learning.
Index ¶
- Variables
- type Config
- func (c *Config) Dir(dir string) *Config
- func (c *Config) DirFromBase(dir, baseDir string) *Config
- func (c *Config) Done() (*Handler, error)
- func (c *Config) ExcludeParams() *Config
- func (c *Config) ExcludeVarsFromSaving(vars ...*context.Variable) *Config
- func (c *Config) Immediate() *Config
- func (c *Config) Keep(n int) *Config
- func (c *Config) MustDone() *Handler
- func (c *Config) TakeMean(n int) *Config
- func (c *Config) TempDir(dir, pattern string) *Config
- type Handler
- func (h *Handler) Dir() string
- func (h *Handler) HasCheckpoints() (bool, error)
- func (h *Handler) ListCheckpoints() (checkpoints []string, err error)
- func (h *Handler) LoadVariable(ctx *context.Context, scope, name string) (value tensor.Tensor, found bool)
- func (h *Handler) LoadedVariables() map[string]tensor.Tensor
- func (h *Handler) OnStepFn(_ *train.Loop, _ []tensor.Tensor) error
- func (h *Handler) Save() error
- func (h *Handler) String() string
Variables ¶
var (
	// DirPermMode is the default directory creation permission (before umask) used.
	DirPermMode = os.FileMode(0770)
)
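As an illustration of what "before umask" means, the following sketch creates a directory with the same mode using only the standard library. The helper names `ensureDir` and `createDemoDir` are hypothetical and are not part of this package; how the package itself creates directories is not shown in this documentation.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// dirPermMode mirrors the documented default; the name is local to this sketch.
var dirPermMode = os.FileMode(0770)

// ensureDir is a hypothetical helper that creates dir (and any missing parents).
// The operating system clears the process umask bits from the requested mode.
func ensureDir(dir string) error {
	return os.MkdirAll(dir, dirPermMode)
}

// createDemoDir creates a directory under a temporary base and reports the
// permission bits it ended up with. The temporary base is removed on return.
func createDemoDir() (os.FileMode, error) {
	base, err := os.MkdirTemp("", "perm_sketch_")
	if err != nil {
		return 0, err
	}
	defer os.RemoveAll(base)

	dir := filepath.Join(base, "checkpoints")
	if err := ensureDir(dir); err != nil {
		return 0, err
	}
	info, err := os.Stat(dir)
	if err != nil {
		return 0, err
	}
	return info.Mode().Perm(), nil
}

func main() {
	perm, err := createDemoDir()
	if err != nil {
		panic(err)
	}
	fmt.Printf("requested %v, got %v\n", dirPermMode, perm)
}
```

On a Unix system with a typical umask of 022, the resulting directory has fewer permission bits than requested, since the group write bit is masked out.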
Types ¶
type Config ¶
type Config struct {
// contains filtered or unexported fields
}
Config for the checkpoints Handler to be created. This is created with Build() and configured with the various methods. Once finished, call Done() and it will output a checkpoints.Handler that loads (if there are any previously saved checkpoints) and saves checkpoints.
func Build ¶
Build a configuration for building a checkpoints.Handler. After configuring the Config object returned, call `Done` to get the configured checkpoints.Handler.
func (*Config) Dir ¶
Dir sets the directory where to save / load the checkpoints.
One of Dir, DirFromBase or TempDir must be set before building the checkpoints.Handler.
func (*Config) DirFromBase ¶ added in v0.5.0
DirFromBase sets the directory where to save / load the checkpoints. If `dir` is not an absolute path, assumes it is a subdirectory of baseDir.
One of Dir, DirFromBase or TempDir must be set before building the checkpoints.Handler.
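The path rule described for DirFromBase can be sketched with the standard library as follows. `resolveDir` is a hypothetical helper written for this illustration; the package's actual implementation may handle additional cases that are not documented here.

```go
package main

import (
	"fmt"
	"path/filepath"
)

// resolveDir is a hypothetical helper: absolute paths are used as given,
// relative paths are treated as subdirectories of baseDir.
func resolveDir(dir, baseDir string) string {
	if filepath.IsAbs(dir) {
		return dir
	}
	return filepath.Join(baseDir, dir)
}

func main() {
	// Example paths are made up for this sketch (Unix-style).
	fmt.Println(resolveDir("my_model", "/data/checkpoints"))
	fmt.Println(resolveDir("/tmp/my_model", "/data/checkpoints"))
}
```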
func (*Config) Done ¶
Done creates a Handler with the current configuration. It returns an error if the configuration is invalid, or if it's missing information.
func (*Config) ExcludeParams ¶
ExcludeParams configures the Handler to exclude the Context parameters (values usually read and written with Context.GetParam and Context.SetParam).
By default, Params are loaded and set into Context the moment Handler is created (when Done() is called), overriding values already present in the Context.
func (*Config) ExcludeVarsFromSaving ¶ added in v0.9.0
ExcludeVarsFromSaving enumerates variables to be excluded from saving. The function can be called multiple times; each call adds to the set of excluded variables.
func (*Config) Immediate ¶ added in v0.9.0
Immediate forces all variables to be loaded immediately, as opposed to loading them dynamically from the checkpoint as they are used while building the model.
func (*Config) Keep ¶
Keep configures the number of checkpoint files to keep. If set to -1, it will never erase older checkpoints. The default is 1.
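A retention rule of this kind can be sketched as below. `checkpointsToDelete` and the file names are hypothetical; the sketch treats any negative `keep` as "keep everything", while the documentation only specifies the behavior for -1.

```go
package main

import "fmt"

// checkpointsToDelete is a hypothetical helper. names must be ordered oldest
// first. It returns the checkpoints that fall outside the newest `keep` ones.
// A negative keep means nothing is ever deleted (assumption: only -1 is documented).
func checkpointsToDelete(names []string, keep int) []string {
	if keep < 0 || len(names) <= keep {
		return nil
	}
	return names[:len(names)-keep]
}

func main() {
	// Made-up checkpoint names, oldest first.
	names := []string{"ckpt-001", "ckpt-002", "ckpt-003", "ckpt-004"}
	fmt.Println(checkpointsToDelete(names, 2))  // the two oldest
	fmt.Println(checkpointsToDelete(names, -1)) // nothing to delete
}
```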
func (*Config) MustDone ¶
MustDone constructs the checkpoints.Handler. It panics if there was an error.
func (*Config) TakeMean ¶ added in v0.4.1
TakeMean loads the mean of the last `n` checkpoints. If `n <= 0`, it takes the mean of all available checkpoints. Notice that only trainable variables are averaged. Variables that have integer values or are not marked as trainable (e.g. the global step) are taken from the most recent checkpoint instead.
The default is 1, so only the most recent checkpoint is loaded.
Notice the mean is taken one tensor at a time, so at any time there is only one copy of the model weights in memory, plus the tensor being merged.
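One way to average a variable across checkpoints while holding only the running result plus the tensor being merged is an incremental mean. The sketch below uses plain `[]float64` slices as stand-ins for tensors; `accumulateMean` and `meanOf` are hypothetical names, and the package may compute the mean differently (for example by summing and then dividing).

```go
package main

import "fmt"

// accumulateMean folds next into mean in place. After the k-th call
// (k counted from 1), mean holds the average of the first k inputs.
func accumulateMean(mean, next []float64, k int) {
	for i := range mean {
		mean[i] += (next[i] - mean[i]) / float64(k)
	}
}

// meanOf averages one variable's values across checkpoints, visiting
// one checkpoint at a time. All inputs must have the same length.
func meanOf(checkpoints [][]float64) []float64 {
	if len(checkpoints) == 0 {
		return nil
	}
	mean := make([]float64, len(checkpoints[0]))
	for k, values := range checkpoints {
		accumulateMean(mean, values, k+1)
	}
	return mean
}

func main() {
	// Three made-up checkpoints of the same two-element variable.
	checkpoints := [][]float64{{1, 2}, {3, 4}, {5, 6}}
	fmt.Println(meanOf(checkpoints)) // prints [3 4]
}
```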
func (*Config) TempDir ¶
TempDir creates a temporary directory under dir, named according to pattern, and uses this directory to load / save checkpoints. It's a convenience wrapper around os.MkdirTemp.
If dir is the empty string, MkdirTemp uses the default directory for temporary files, as returned by os.TempDir.
The new directory's name is generated by adding a random string to the end of pattern. If `pattern` includes a "*", the random string replaces the last "*" instead (see os.MkdirTemp).
Any errors are reported by the later call to Done.
One of Dir, DirFromBase or TempDir must be set before building the checkpoints.Handler.
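The naming rule comes from os.MkdirTemp, so it can be observed with the standard library alone. In the sketch below, `makeTempBase` and `hasAffixes` are hypothetical helpers and the patterns are made up; the random part of the name differs on every run.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
)

// makeTempBase creates a temporary directory in the default temporary
// location, removes it again, and returns only its base name.
func makeTempBase(pattern string) (string, error) {
	dir, err := os.MkdirTemp("", pattern)
	if err != nil {
		return "", err
	}
	defer os.RemoveAll(dir)
	return filepath.Base(dir), nil
}

// hasAffixes reports whether name starts with prefix and ends with suffix.
func hasAffixes(name, prefix, suffix string) bool {
	return strings.HasPrefix(name, prefix) && strings.HasSuffix(name, suffix)
}

func main() {
	for _, pattern := range []string{"ckpt_", "ckpt_*_run"} {
		name, err := makeTempBase(pattern)
		if err != nil {
			panic(err)
		}
		fmt.Printf("pattern %q -> %q\n", pattern, name)
	}
}
```

With the pattern `ckpt_` the random string is appended at the end; with `ckpt_*_run` it replaces the `*`, so the name keeps the `_run` suffix.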
type Handler ¶
type Handler struct {
// contains filtered or unexported fields
}
Handler handles saving and loading of checkpoints for a context.Context. See example in package documentation.
It is created and configured using Build(), followed by options setting and then calling Config.Done().
Loading data into Handler happens at its creation time: it loads from the latest checkpoint. (Hyper-)parameters are loaded into the context immediately (unless Config.ExcludeParams was set), but the loaded variable values are only "consumed" (used) one at a time, as the variables are created during graph building (e.g., when building the model).
Saving of checkpoints is explicit, by calling Handler.Save(). Usually this is done by configuring train.Loop to call it using train.EveryNSteps or train.NTimesDuringLoop. When saving, all variables in the Context are saved, along with any variables previously loaded by the Handler that were not used by the Context, and the `Params` for all scopes (including changed values).
There can be more than one Handler attached to a Context. They are used for loading in the order they were created (so the first one created takes priority). A setup with multiple Handlers can be used, for instance, for transfer learning, where parts of the model are loaded from somewhere else.
A Handler can only be "attached" to one context.Context. To load the same checkpoint into two different contexts, another Handler object needs to be created. This is because once a variable is loaded, it is transferred to the Context, and the Handler does not keep it.
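The two rules above (first-created Handler wins, and a loaded variable leaves the Handler) can be modeled with a small standalone sketch. The `loader` type, `loadFirst`, and the variable keys are hypothetical stand-ins and do not reflect the package's internal data structures.

```go
package main

import "fmt"

// loader is a hypothetical stand-in for the values a Handler has read
// from a checkpoint but not yet handed over to a context.
type loader struct {
	name   string
	values map[string][]float32
}

// load returns the value for key and removes it from the loader,
// modeling the transfer of ownership to the caller.
func (l *loader) load(key string) ([]float32, bool) {
	v, ok := l.values[key]
	if ok {
		delete(l.values, key)
	}
	return v, ok
}

// loadFirst asks the loaders in creation order; the first one holding
// key provides the value.
func loadFirst(loaders []*loader, key string) (value []float32, from string, found bool) {
	for _, l := range loaders {
		if v, ok := l.load(key); ok {
			return v, l.name, true
		}
	}
	return nil, "", false
}

func main() {
	pretrained := &loader{name: "pretrained", values: map[string][]float32{"/encoder/w": {1, 2}}}
	current := &loader{name: "current", values: map[string][]float32{"/encoder/w": {9, 9}, "/head/w": {3}}}
	loaders := []*loader{pretrained, current} // creation order

	_, from, _ := loadFirst(loaders, "/encoder/w")
	fmt.Println("/encoder/w loaded from", from)
	_, from, _ = loadFirst(loaders, "/head/w")
	fmt.Println("/head/w loaded from", from)

	// The value was handed over, so the first loader no longer has it.
	_, stillThere := pretrained.load("/encoder/w")
	fmt.Println("pretrained still holds /encoder/w:", stillThere)
}
```

In this model, loading the same checkpoint into a second consumer requires a second `loader`, which mirrors why a separate Handler is needed per context.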
func (*Handler) Dir ¶
Dir returns the directory the Handler is configured to use. It cannot be changed once the Handler has been created.
It returns "" (empty) if the Handler is `nil`.
func (*Handler) HasCheckpoints ¶ added in v0.4.0
HasCheckpoints returns whether there are any checkpoints saved.
func (*Handler) ListCheckpoints ¶
ListCheckpoints returns the base file name of the checkpoints in the directory in time order (older first).
func (*Handler) LoadVariable ¶
func (h *Handler) LoadVariable(ctx *context.Context, scope, name string) (value tensor.Tensor, found bool)
LoadVariable implements context.Loader. This is called by context.Context when the variable is used for the first time. The user may want to use this function to inspect loaded values for testing.
func (*Handler) LoadedVariables ¶ added in v0.4.1
LoadedVariables returns the loaded variables for inspection. These are the values loaded, but they are not necessarily available in the context yet, since they are only used when a model asks for the variable.
The Handler owns the returned map. Do not modify it; the behavior is undefined if you do.
func (*Handler) OnStepFn ¶ added in v0.4.0
OnStepFn implements `train.OnStepFn`, which makes it convenient to attach the Handler to a training loop. It simply calls Save.
func (*Handler) Save ¶
Save creates a new checkpoint and saves the context variables and (optionally) Params.
All variables in the context are saved, as well as those previously loaded -- this allows one to load the variables only for a part of the model, update that part and save again with everything.
Params are (de-)serialized with package json.
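One practical consequence of JSON serialization is worth keeping in mind. The sketch below assumes that "package json" refers to the standard `encoding/json` and that parameters are held in a generic map; both are assumptions, and the package may convert types back after loading. `roundTrip` and the parameter names are hypothetical.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// roundTrip serializes params to JSON and reads them back into a generic map,
// as a stand-in for saving and then loading a checkpoint's parameters.
func roundTrip(params map[string]any) (map[string]any, error) {
	data, err := json.Marshal(params)
	if err != nil {
		return nil, err
	}
	var loaded map[string]any
	if err := json.Unmarshal(data, &loaded); err != nil {
		return nil, err
	}
	return loaded, nil
}

func main() {
	// Made-up parameter names and values.
	params := map[string]any{"learning_rate": 0.001, "num_layers": 4, "activation": "relu"}
	loaded, err := roundTrip(params)
	if err != nil {
		panic(err)
	}
	for key, value := range loaded {
		fmt.Printf("%s: %v (%T)\n", key, value, value)
	}
}
```

When decoding into a generic value, `encoding/json` represents every JSON number as `float64`, so an `int` stored before saving comes back as a `float64` in this sketch.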