blob: e4c86861c54c981a6b30427a303f1d1435e677c7 [file] [log] [blame] [edit]
package state_machine
/*
Simple state machine implementation.
*/
import (
	"context"
	"errors"
	"fmt"
	"sort"
	"sync"

	"cloud.google.com/go/storage"
	"go.skia.org/infra/go/gcs"
)
// TransitionFn is a function to run when attempting to transition from one
// State to another. A non-nil error returned by the function aborts the
// transition and leaves the current state unchanged. It is okay to give nil as
// a no-op TransitionFn when registering one with Builder.F.
type TransitionFn func(ctx context.Context) error
// transition is a single edge of the state machine, as registered via
// Builder.T.
type transition struct {
// from is the name of the state the transition starts in.
from string
// to is the name of the state the transition ends in.
to string
// fn is the name of the transition function to run; it is looked up in
// Builder.funcs when the StateMachine is built.
fn string
}
// Builder is a helper struct used for constructing StateMachines. Register
// transition functions with F, transitions with T and the initial state with
// SetInitial, then call Build.
type Builder struct {
// funcs maps a function name to the transition function registered under it.
funcs map[string]TransitionFn
// initialState is the state the machine starts in when no previously
// persisted state is found.
initialState string
// transitions lists every registered transition in registration order.
transitions []transition
}
// NewBuilder creates an empty Builder: no transition functions, no
// transitions, and an empty initial state.
func NewBuilder() *Builder {
	b := new(Builder)
	b.funcs = make(map[string]TransitionFn)
	b.transitions = make([]transition, 0)
	return b
}
// T registers a transition from state "from" to state "to" which runs the
// transition function registered under the name fn. Nothing is validated
// here; Build reports duplicate transitions and unknown function names.
func (b *Builder) T(from, to, fn string) {
	edge := transition{from: from, to: to, fn: fn}
	b.transitions = append(b.transitions, edge)
}
// F registers the transition function fn under the given name, replacing any
// function previously registered with that name. A nil fn is allowed and
// behaves as a no-op transition which always succeeds.
func (b *Builder) F(name string, fn TransitionFn) {
	// Always store a non-nil wrapper so that callers of the registered
	// function never have to nil-check it themselves.
	wrapped := func(ctx context.Context) error {
		if fn == nil {
			return nil
		}
		return fn(ctx)
	}
	b.funcs[name] = wrapped
}
// SetInitial sets the initial state: the state the StateMachine starts in when
// Build finds no previously persisted state.
func (b *Builder) SetInitial(s string) {
b.initialState = s
}
// Build validates the transitions and functions registered on the Builder and
// returns a StateMachine whose current state is persisted at gsPath.
//
// Validation fails if two transitions share the same (from, to) pair, if a
// transition names a function which was never registered via F, if the
// initial state does not appear in any transition, if any state has no
// outgoing transition, or if any state other than the initial state has no
// incoming transition. All validation happens before any GCS access.
//
// If gsPath already exists, its contents are used as the current state in
// place of the initial state; the contents must name a state known to this
// machine, otherwise an error is returned. The current state is written back
// to gsPath before returning, in case the file did not exist before.
func (b *Builder) Build(ctx context.Context, gcsClient gcs.GCSClient, gsPath string) (*StateMachine, error) {
	// Validate everything up front so that a misconfigured Builder fails
	// fast without performing any I/O.
	transitions, err := b.buildTransitionTable()
	if err != nil {
		return nil, err
	}

	// Get the previous state (if any) from the file.
	current := b.initialState
	contents, err := gcsClient.GetFileContents(ctx, gsPath)
	switch {
	case err == nil:
		current = string(contents)
	case errors.Is(err, storage.ErrObjectNotExist):
		// Nothing persisted yet; start from the initial state. errors.Is
		// is used because the sentinel may arrive wrapped.
	default:
		return nil, fmt.Errorf("Unable to read file for persistentStateMachine: %w", err)
	}

	// Every known state is a key of the transition table (each state is
	// required to have an outgoing transition), so a missing key means the
	// file holds a state this machine could never leave.
	if _, ok := transitions[current]; !ok {
		return nil, fmt.Errorf("Persisted state %q read from %q is not defined", current, gsPath)
	}

	// Create and return the StateMachine.
	sm := &StateMachine{
		current:     current,
		gcs:         gcsClient,
		funcs:       b.funcs,
		transitions: transitions,
		file:        gsPath,
	}
	// Write initial state back to GCS, in case it wasn't there before.
	if err := sm.gcs.SetFileContents(ctx, sm.file, gcs.FILE_WRITE_OPTS_TEXT, []byte(current)); err != nil {
		return nil, err
	}
	return sm, nil
}

// buildTransitionTable validates the registered transitions and returns them
// as a lookup table: from-state -> to-state -> transition function name.
func (b *Builder) buildTransitionTable() (map[string]map[string]string, error) {
	transitions := map[string]map[string]string{}
	states := make(map[string]bool, len(b.transitions))
	hasIncoming := make(map[string]bool, len(b.transitions))
	for _, t := range b.transitions {
		states[t.from] = true
		states[t.to] = true
		hasIncoming[t.to] = true
		toMap, ok := transitions[t.from]
		if !ok {
			toMap = map[string]string{}
			transitions[t.from] = toMap
		}
		if fn, ok := toMap[t.to]; ok {
			return nil, fmt.Errorf("Multiple defined transitions from %q to %q: %q, %q", t.from, t.to, fn, t.fn)
		}
		toMap[t.to] = t.fn
		if _, ok := b.funcs[t.fn]; !ok {
			return nil, fmt.Errorf("Function %q not defined.", t.fn)
		}
	}
	if !states[b.initialState] {
		return nil, fmt.Errorf("Initial state %q is not defined!", b.initialState)
	}
	// Every state must have transitions defined to and from it, even if
	// they are just self-transitions. Only the initial state may lack an
	// incoming transition.
	for state := range states {
		if _, ok := transitions[state]; !ok {
			return nil, fmt.Errorf("No transitions defined from state %q", state)
		}
	}
	for state := range states {
		if !hasIncoming[state] && state != b.initialState {
			return nil, fmt.Errorf("No transitions defined to state %q", state)
		}
	}
	return transitions, nil
}
// StateMachine is a simple state machine implementation which persists its
// current state to a file. Instances are created via Builder.Build.
type StateMachine struct {
// current is the name of the state the machine is currently in.
current string
// gcs is the client used to write the current state to file.
gcs gcs.GCSClient
// funcs maps a transition function name to the function itself.
funcs map[string]TransitionFn
// transitions maps from-state -> to-state -> name of the transition
// function to run for that transition.
transitions map[string]map[string]string
// file is the path, passed to the GCS client, at which the current state is
// persisted.
file string
// mtx guards current: readers take the read lock, Transition takes the
// write lock.
mtx sync.RWMutex
}
// Current returns the name of the state the StateMachine is currently in.
func (sm *StateMachine) Current() string {
	sm.mtx.RLock()
	state := sm.current
	sm.mtx.RUnlock()
	return state
}
// Transition attempts to move the StateMachine from its current state to dest.
//
// It looks up the transition function registered for (current, dest) and runs
// it. If no such transition exists, or the function returns an error, an
// error is returned and the current state is left unchanged; the function's
// error is wrapped so that callers can inspect it with errors.Is/errors.As.
// On success the current state becomes dest and is written to the backing
// file. If that write fails its error is returned, but the in-memory state
// has already advanced to dest.
//
// The write lock is held for the whole call, including while the transition
// function runs, so the transition function must not call methods on this
// StateMachine.
func (sm *StateMachine) Transition(ctx context.Context, dest string) error {
	sm.mtx.Lock()
	defer sm.mtx.Unlock()

	toMap, ok := sm.transitions[sm.current]
	if !ok {
		return fmt.Errorf("No transitions defined from state %q", sm.current)
	}
	fName, ok := toMap[dest]
	if !ok {
		return fmt.Errorf("No transition defined from state %q to state %q", sm.current, dest)
	}
	fn, ok := sm.funcs[fName]
	if !ok {
		return fmt.Errorf("Undefined transition function %q", fName)
	}

	// Run the transition func. Wrap with %w (not %s) so the cause stays
	// reachable through the error chain.
	if err := fn(ctx); err != nil {
		return fmt.Errorf("Failed to transition from %q to %q: %w", sm.current, dest, err)
	}
	sm.current = dest
	return sm.gcs.SetFileContents(ctx, sm.file, gcs.FILE_WRITE_OPTS_TEXT, []byte(sm.current))
}
// GetTransitionName returns the name of the transition function which would
// run when moving from the current state to dest. An error is returned if the
// current state has no outgoing transitions or none of them leads to dest.
func (sm *StateMachine) GetTransitionName(dest string) (string, error) {
	sm.mtx.RLock()
	defer sm.mtx.RUnlock()

	outgoing, known := sm.transitions[sm.current]
	if !known {
		return "", fmt.Errorf("No transitions defined from state %q", sm.current)
	}
	if name, found := outgoing[dest]; found {
		return name, nil
	}
	return "", fmt.Errorf("No transition defined from state %q to state %q", sm.current, dest)
}
// ListStates returns the names of all states in the StateMachine, sorted in
// increasing order.
func (sm *StateMachine) ListStates() []string {
	sm.mtx.RLock()
	defer sm.mtx.RUnlock()

	// Build requires every state to have at least one outgoing transition,
	// so the keys of the outer transitions map are exactly the states.
	names := make([]string, len(sm.transitions))
	idx := 0
	for name := range sm.transitions {
		names[idx] = name
		idx++
	}
	sort.Strings(names)
	return names
}