package trials

Version: v0.0.0-...-3511abf
Warning

This package is not in the latest version of its module.

Published: Nov 2, 2023 | License: Apache-2.0 | Imports: 21 | Imported by: 0

Documentation

Constants

This section is empty.

Variables

This section is empty.

Functions

func CanGetTrialsExperimentAndCheckCanDoAction

func CanGetTrialsExperimentAndCheckCanDoAction(ctx context.Context,
	trialID int, actionFunc func(context.Context, model.User, *model.Experiment) error,
) error

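CanGetTrialsExperimentAndCheckCanDoAction is a utility function that generalizes RBAC checks for trials and experiments: it checks that the current user can view the experiment that owns the trial, then calls actionFunc to check whether the user may perform the specific action.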
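The two-step check described above can be sketched in isolation as follows. This is an illustrative sketch only: User, Experiment, ActionFunc, trialToExperiment, canGetExperiment, checkTrialAction, allowAll and denyAll are hypothetical names, the user is passed explicitly rather than read from ctx, and reporting "not found" when view access fails is an assumption about the design, not a description of the real function.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// User and Experiment are hypothetical stand-ins for model.User and model.Experiment.
type User struct{ ID int }

type Experiment struct {
	ID      int
	OwnerID int
}

// ActionFunc has the same shape as the actionFunc parameter.
type ActionFunc func(context.Context, User, *Experiment) error

var errNotFound = errors.New("not found")

// trialToExperiment is a hypothetical in-memory lookup standing in for the
// database query that resolves a trial to its experiment.
var trialToExperiment = map[int]*Experiment{1: {ID: 10, OwnerID: 7}}

// canGetExperiment is a hypothetical view check: only the owner may view.
func canGetExperiment(_ context.Context, u User, e *Experiment) error {
	if u.ID != e.OwnerID {
		return errors.New("unauthorized")
	}
	return nil
}

// checkTrialAction resolves the trial's experiment, requires view access to
// it, and only then runs the action-specific check.
func checkTrialAction(ctx context.Context, u User, trialID int, action ActionFunc) error {
	exp, ok := trialToExperiment[trialID]
	if !ok {
		return fmt.Errorf("trial %d: %w", trialID, errNotFound)
	}
	if err := canGetExperiment(ctx, u, exp); err != nil {
		// Assumption: report "not found" so users cannot probe for hidden trials.
		return fmt.Errorf("trial %d: %w", trialID, errNotFound)
	}
	if err := action(ctx, u, exp); err != nil {
		return fmt.Errorf("permission denied: %w", err)
	}
	return nil
}

// allowAll and denyAll are example action checks.
func allowAll(context.Context, User, *Experiment) error { return nil }
func denyAll(context.Context, User, *Experiment) error  { return errors.New("action not allowed") }

func main() {
	ctx := context.Background()
	fmt.Println(checkTrialAction(ctx, User{ID: 7}, 1, allowAll)) // <nil>
	fmt.Println(checkTrialAction(ctx, User{ID: 8}, 1, allowAll)) // trial 1: not found
	fmt.Println(checkTrialAction(ctx, User{ID: 7}, 1, denyAll))  // permission denied: action not allowed
}
```

Passing the action as a function keeps the lookup and view check in one place, while each caller supplies only the permission that is specific to its endpoint.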

func CreateTrialSourceInfo

func CreateTrialSourceInfo(ctx context.Context, tsi *trialv1.TrialSourceInfo,
) (*apiv1.ReportTrialSourceInfoResponse, error)

CreateTrialSourceInfo creates a TrialSourceInfo object, which tracks the link between an inference or fine-tuning trial and its checkpoint or model version.

func GetMetricsForTrialSourceInfoQuery

func GetMetricsForTrialSourceInfoQuery(
	ctx context.Context, q *bun.SelectQuery,
	groupName *string,
) ([]*trialv1.MetricsReport, error)

GetMetricsForTrialSourceInfoQuery takes a bun.SelectQuery on the trial_source_infos table and fetches the metrics for each of the connected trials.

func LatestCheckpointForTrialTx

func LatestCheckpointForTrialTx(ctx context.Context, idb bun.IDB, trialID int) (
	*model.Checkpoint, error,
)

LatestCheckpointForTrialTx finds the latest completed checkpoint for a trial, returning nil if none exists.
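The "latest completed, else nil" contract can be illustrated with an in-memory sketch. The real function queries idb, presumably filtering on completed checkpoints and ordering by report time; that ordering, the Checkpoint fields shown here, and the latestCompleted helper are assumptions for illustration, with time simplified to Unix seconds.

```go
package main

import "fmt"

// Checkpoint is a hypothetical, trimmed-down stand-in for model.Checkpoint.
// ReportTime is simplified to Unix seconds.
type Checkpoint struct {
	UUID       string
	State      string
	ReportTime int64
}

// latestCompleted returns the most recently reported checkpoint in state
// "COMPLETED", or nil when there is none.
func latestCompleted(ckpts []Checkpoint) *Checkpoint {
	var latest *Checkpoint
	for i := range ckpts {
		c := &ckpts[i]
		if c.State != "COMPLETED" {
			continue
		}
		if latest == nil || c.ReportTime > latest.ReportTime {
			latest = c
		}
	}
	return latest
}

func main() {
	ckpts := []Checkpoint{
		{UUID: "a", State: "COMPLETED", ReportTime: 100},
		{UUID: "b", State: "COMPLETED", ReportTime: 200},
		{UUID: "c", State: "ACTIVE", ReportTime: 300},
	}
	fmt.Println(latestCompleted(ckpts).UUID)      // b
	fmt.Println(latestCompleted(ckpts[2:]) == nil) // true
}
```

Returning a nil pointer rather than an error for "no checkpoint yet" lets callers treat a fresh trial as a normal case instead of a failure.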

func MarkLostTrials

func MarkLostTrials(ctx context.Context) error

MarkLostTrials marks as errored the trials that have not sent a heartbeat for more than 5 minutes.
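The heartbeat rule can be sketched as a pure function over trial records. This is an illustration under stated assumptions: the trial type, the markLost helper, the "RUNNING"/"ERROR" strings, and the choice to consider only running trials are hypothetical, and timestamps are Unix seconds for brevity. The real function works against the database and its exact filter may differ.

```go
package main

import (
	"fmt"
	"time"
)

// lostThreshold is the 5-minute heartbeat window from the docs.
const lostThreshold = 5 * time.Minute

// trial is a hypothetical, trimmed-down trial record; LastActivity is the
// time of the last heartbeat in Unix seconds.
type trial struct {
	ID           int
	State        string
	LastActivity int64
}

// markLost sets State to "ERROR" on running trials whose last heartbeat is
// more than lostThreshold before now, and returns the IDs it changed.
func markLost(trials []trial, now int64) []int {
	var lost []int
	for i := range trials {
		t := &trials[i]
		if t.State != "RUNNING" {
			continue
		}
		age := time.Duration(now-t.LastActivity) * time.Second
		if age > lostThreshold {
			t.State = "ERROR"
			lost = append(lost, t.ID)
		}
	}
	return lost
}

func main() {
	trials := []trial{
		{ID: 1, State: "RUNNING", LastActivity: 0},   // silent for 10 minutes
		{ID: 2, State: "RUNNING", LastActivity: 500}, // heartbeat 100 s ago
		{ID: 3, State: "COMPLETED", LastActivity: 0}, // already terminal
	}
	fmt.Println(markLost(trials, 600)) // [1]
}
```

A trial exactly 5 minutes past its last heartbeat is not marked, matching "more than 5 minutes".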

func MarkLostTrialsWorker

func MarkLostTrialsWorker(ctx context.Context)

MarkLostTrialsWorker runs `MarkLostTrials` every 5 minutes.
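A periodic worker like this is commonly written with a time.Ticker and a select on ctx.Done(). The sketch below shows that pattern; runEvery and countRuns are hypothetical helpers, and logging failures with fmt.Println and continuing is an assumed error policy, not necessarily what MarkLostTrialsWorker does.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// runEvery calls fn once per interval until ctx is cancelled. A failed pass
// is logged and the loop continues, so one error does not stop the worker.
func runEvery(ctx context.Context, interval time.Duration, fn func(context.Context) error) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := fn(ctx); err != nil {
				fmt.Println("periodic task failed:", err)
			}
		}
	}
}

// countRuns runs a no-op task every interval until timeout elapses and
// reports how many times it ran; it exists only to exercise runEvery.
func countRuns(timeout, interval time.Duration) int {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	runs := 0
	runEvery(ctx, interval, func(context.Context) error {
		runs++
		return nil
	})
	return runs
}

func main() {
	// The documented worker would correspond to runEvery(ctx, 5*time.Minute, MarkLostTrials);
	// short durations keep this demo fast.
	fmt.Println("runs:", countRuns(350*time.Millisecond, 100*time.Millisecond))
}
```

Because a Ticker first fires after one full interval, this sketch does not run the task at startup; whether the real worker also does an initial pass is not stated here.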

func MetricsTimeSeries

func MetricsTimeSeries(trialID int32, startTime time.Time,
	metricNames []string,
	startBatches int, endBatches int, xAxisMetricLabels []string,
	maxDatapoints int, timeSeriesColumn string,
	timeSeriesFilter *commonv1.PolymorphicFilter, metricGroup model.MetricGroup) (
	metricMeasurements []db.MetricMeasurements, err error,
)

MetricsTimeSeries returns a time series of the specified metrics for the specified trial.
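The startBatches, endBatches and maxDatapoints parameters suggest a range filter followed by downsampling. The sketch below shows one plausible way to apply them in memory; the point type, the sampleRange helper, and the even-stride sampling strategy are assumptions, and the real function queries the database and may sample differently.

```go
package main

import "fmt"

// point is a hypothetical (batch, value) sample for one metric.
type point struct {
	Batches int
	Value   float64
}

// sampleRange keeps points with startBatches <= Batches <= endBatches and, if
// more than maxDatapoints remain, picks evenly spaced ones that include the
// first and last point. Input is assumed sorted by Batches.
func sampleRange(pts []point, startBatches, endBatches, maxDatapoints int) []point {
	var inRange []point
	for _, p := range pts {
		if p.Batches >= startBatches && p.Batches <= endBatches {
			inRange = append(inRange, p)
		}
	}
	n := len(inRange)
	if maxDatapoints <= 0 || n <= maxDatapoints {
		return inRange
	}
	if maxDatapoints == 1 {
		return inRange[n-1:] // keep only the most recent point
	}
	out := make([]point, 0, maxDatapoints)
	for i := 0; i < maxDatapoints; i++ {
		// Spread indices evenly over [0, n-1]; they are strictly increasing
		// because n > maxDatapoints.
		idx := i * (n - 1) / (maxDatapoints - 1)
		out = append(out, inRange[idx])
	}
	return out
}

func main() {
	var pts []point
	for b := 100; b <= 1000; b += 100 {
		pts = append(pts, point{Batches: b, Value: float64(b) / 1000})
	}
	fmt.Println(sampleRange(pts, 200, 900, 3)) // [{200 0.2} {500 0.5} {900 0.9}]
}
```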

func ProtoGetTrialsPlusTx

func ProtoGetTrialsPlusTx(
	ctx context.Context, idb bun.IDB, trialIDs []int,
) ([]*trialv1.Trial, error)

ProtoGetTrialsPlusTx runs the `proto_get_trials_plus` query against idb and returns the trials with the given IDs as trialv1.Trial protos.

func UpdateUnmanagedExperimentStatesTx

func UpdateUnmanagedExperimentStatesTx(
	ctx context.Context, tx bun.IDB, experiments []*model.Experiment,
) error

UpdateUnmanagedExperimentStatesTx updates the state of each [unmanaged] experiment in experiments according to the states of its constituent trials.
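The docs do not spell out how trial states map to an experiment state, so the rule below is purely hypothetical: an experiment stays active while any trial is active (or before any trial exists), becomes errored if any trial errored, and is otherwise completed. State, its constants, and experimentState are illustrative names, not model.State.

```go
package main

import "fmt"

// State and its values are simplified, hypothetical stand-ins for model.State.
type State string

const (
	Active    State = "ACTIVE"
	Completed State = "COMPLETED"
	Errored   State = "ERROR"
)

// experimentState derives an experiment state from its trial states using a
// hypothetical precedence: active > errored > completed.
func experimentState(trialStates []State) State {
	if len(trialStates) == 0 {
		return Active // no trials reported yet: leave the experiment active (assumption)
	}
	sawError := false
	for _, s := range trialStates {
		switch s {
		case Active:
			return Active
		case Errored:
			sawError = true
		}
	}
	if sawError {
		return Errored
	}
	return Completed
}

func main() {
	fmt.Println(experimentState([]State{Completed, Errored})) // ERROR
}
```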

Types

type Trial

type Trial struct {
	bun.BaseModel         `bun:"table:trials"`
	ID                    int            `bun:"id,pk,autoincrement"`
	ExperimentID          int            `bun:"experiment_id"`
	State                 model.State    `bun:"state"`
	StartTime             time.Time      `bun:"start_time"`
	EndTime               *time.Time     `bun:"end_time"`
	Hparams               map[string]any `bun:"hparams"`
	WarmStartCheckpointID *int           `bun:"warm_start_checkpoint_id"`
	Seed                  int            `bun:"seed"`
	RequestID             *string        `bun:"request_id"`
	BestValidationID      *int           `bun:"best_validation_id"`
	// TODO(ilia): enum for training/validating/checkpointing.
	RunnerState string `bun:"runner_state"`
	RunID       int    `bun:"run_id"`
	Restarts    int    `bun:"restarts"`
	// Note: Tags map values are always "".
	Tags                      map[string]string `bun:"tags"`
	CheckpointSize            int               `bun:"checkpoint_size"`
	CheckpointCount           int               `bun:"checkpoint_count"`
	SearcherMetricValue       *float64          `bun:"searcher_metric_value"`
	SearcherMetricValueSigned *float64          `bun:"searcher_metric_value_signed"`
	TotalBatches              int               `bun:"total_batches"`
	// TODO(ilia): better typing for SummaryMetrics.
	SummaryMetrics          map[string]any `bun:"summary_metrics"`
	SummaryMetricsTimestamp *time.Time     `bun:"summary_metrics_timestamp"`
	LatestValidationID      int            `bun:"latest_validation_id"`
	LastActivity            *time.Time     `bun:"last_activity"`
	ExternalTrialID         *string        `bun:"external_trial_id"`
}

Trial is a better bun trial model than the one in pkg/model/experiment.go.
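Because the Tags map values are always "", the map effectively acts as a set of tag names. The hypothetical helpers below (tagsToList and listToTags, not part of this package) show how such a map can be converted to and from a sorted list.

```go
package main

import (
	"fmt"
	"sort"
)

// tagsToList returns the tag names from a Trial.Tags-style map in sorted
// order. Only the keys carry information; the values are always "".
func tagsToList(tags map[string]string) []string {
	names := make([]string, 0, len(tags))
	for name := range tags {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// listToTags builds a Trial.Tags-style map, storing "" for every value.
// Duplicate names collapse into one entry, as in a set.
func listToTags(names []string) map[string]string {
	tags := make(map[string]string, len(names))
	for _, n := range names {
		tags[n] = ""
	}
	return tags
}

func main() {
	tags := listToTags([]string{"prod", "baseline", "prod"})
	fmt.Println(len(tags), tagsToList(tags)) // 2 [baseline prod]
}
```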

type TrialSourceInfoAPIServer

type TrialSourceInfoAPIServer struct{}

TrialSourceInfoAPIServer is an empty struct used for dependency injection into the API server. It allows apiServer methods to be defined in sub-packages.

func (*TrialSourceInfoAPIServer) ReportTrialSourceInfo

ReportTrialSourceInfo creates a TrialSourceInfo, which links a trial to the checkpoint it uses so that fine-tuning and inference runs can be tracked.
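The dependency-injection pattern described for TrialSourceInfoAPIServer relies on Go struct embedding: methods defined on an empty struct in a sub-package are promoted to the server that embeds it. The sketch below assumes hypothetical ReportRequest/ReportResponse types, a hypothetical apiServer, and a one-method reporter interface standing in for the generated gRPC service interface; the method body is a placeholder.

```go
package main

import (
	"context"
	"fmt"
)

// ReportRequest and ReportResponse are hypothetical stand-ins for the
// generated apiv1 request and response messages.
type ReportRequest struct{ TrialID int }
type ReportResponse struct{ OK bool }

// TrialSourceInfoAPIServer mirrors the documented pattern: an empty struct
// whose methods live in the sub-package.
type TrialSourceInfoAPIServer struct{}

// ReportTrialSourceInfo is a placeholder body for illustration only.
func (*TrialSourceInfoAPIServer) ReportTrialSourceInfo(
	_ context.Context, req *ReportRequest,
) (*ReportResponse, error) {
	return &ReportResponse{OK: req.TrialID > 0}, nil
}

// apiServer is a hypothetical top-level server. Embedding promotes
// ReportTrialSourceInfo, so *apiServer satisfies reporter without a
// hand-written forwarding method.
type apiServer struct {
	TrialSourceInfoAPIServer
}

// reporter is a hypothetical slice of the generated service interface.
type reporter interface {
	ReportTrialSourceInfo(context.Context, *ReportRequest) (*ReportResponse, error)
}

func main() {
	var srv reporter = &apiServer{}
	resp, err := srv.ReportTrialSourceInfo(context.Background(), &ReportRequest{TrialID: 1})
	fmt.Println(resp.OK, err) // true <nil>
}
```

The same mechanism applies to TrialsAPIServer: the top-level server gains StartTrial simply by embedding the struct.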

type TrialsAPIServer

type TrialsAPIServer struct{}

TrialsAPIServer is an API server struct meant to be embedded in the main API server.

func (*TrialsAPIServer) StartTrial

StartTrial is called when a Core API context is entered in detached mode.
