blob: 4a50f59a3ac67cd5c0c01b466d7aeaec2f95dfa4 [file] [log] [blame]
package state_machine
/*
Package state_machine provides a simple state machine implementation whose
current state is persisted to a file.
*/
import (
"fmt"
"io/ioutil"
"os"
"sync"
)
// TransitionFn is a function which runs when the StateMachine attempts to
// transition from one state to another. If it returns a non-nil error, the
// transition is aborted and the machine stays in its current state. A nil
// function may be passed to Builder.F, which registers it as a no-op.
type TransitionFn func() error
// transition describes a single allowed edge between two states, as
// registered via Builder.T.
type transition struct {
// from is the name of the state the transition starts in.
from string
// to is the name of the state the transition ends in.
to string
// fn is the name of the transition function to run. Builder.Build requires
// it to be a key of Builder.funcs.
fn string
}
// Builder is a helper struct used for constructing StateMachines. Register
// transition functions with F, transitions with T and the starting state with
// SetInitial, then call Build.
type Builder struct {
// funcs maps a transition function name to the function registered under it.
funcs map[string]TransitionFn
// initialState is the state used when the state file does not exist yet.
initialState string
// transitions holds the registered transitions in the order they were added.
transitions []transition
}
// NewBuilder creates an empty Builder: no transition functions, no
// transitions, and an empty initial state.
func NewBuilder() *Builder {
	b := new(Builder)
	// initialState is left at its zero value, the empty string.
	b.funcs = make(map[string]TransitionFn)
	b.transitions = make([]transition, 0)
	return b
}
// T registers a transition from state "from" to state "to" which runs the
// transition function registered under the name fn. Nothing is validated
// here; duplicate transitions and unknown function names are reported by
// Build.
func (b *Builder) T(from, to, fn string) {
	edge := transition{
		from: from,
		to:   to,
		fn:   fn,
	}
	b.transitions = append(b.transitions, edge)
}
// F registers fn as the transition function with the given name, replacing
// any function previously registered under that name. A nil fn is allowed
// and is registered as a no-op which always succeeds.
func (b *Builder) F(name string, fn func() error) {
	registered := TransitionFn(fn)
	if registered == nil {
		// Substitute a no-op so that callers never invoke a nil function.
		registered = func() error { return nil }
	}
	b.funcs[name] = registered
}
// SetInitial sets the state the machine starts in when the state file does
// not exist yet. Build returns an error if s does not appear in any
// registered transition.
func (b *Builder) SetInitial(s string) {
b.initialState = s
}
// Build validates the configured transitions and returns a StateMachine which
// persists its current state to the given file.
//
// If the file already exists, its contents are used as the current state;
// otherwise the machine starts in the initial state. In both cases the
// current state is written back to the file before Build returns.
//
// Build returns an error if:
//   - two transitions share the same from and to states,
//   - a transition names a function which was not registered with F,
//   - the file exists but cannot be read, or the state cannot be written back,
//   - the initial state does not appear in any transition,
//   - the state stored in the file does not appear in any transition,
//   - any state has no outgoing transition, or
//   - any state other than the initial state has no incoming transition.
func (b *Builder) Build(file string) (*StateMachine, error) {
	// Build the from -> to -> function-name lookup table, validating each
	// transition as it is added.
	transitions := map[string]map[string]string{}
	states := make(map[string]bool, len(b.transitions))
	for _, t := range b.transitions {
		states[t.from] = true
		states[t.to] = true
		toMap, ok := transitions[t.from]
		if !ok {
			toMap = map[string]string{}
			transitions[t.from] = toMap
		}
		if fn, ok := toMap[t.to]; ok {
			return nil, fmt.Errorf("Multiple defined transitions from %q to %q: %q, %q", t.from, t.to, fn, t.fn)
		}
		toMap[t.to] = t.fn
		if _, ok := b.funcs[t.fn]; !ok {
			return nil, fmt.Errorf("Function %q not defined.", t.fn)
		}
	}

	// Get the previous state (if any) from the file.
	cachedState := b.initialState
	contents, err := ioutil.ReadFile(file)
	if err == nil {
		cachedState = string(contents)
	} else if !os.IsNotExist(err) {
		return nil, fmt.Errorf("Unable to read file for persistentStateMachine: %s", err)
	}
	if _, ok := states[b.initialState]; !ok {
		return nil, fmt.Errorf("Initial state %q is not defined!", b.initialState)
	}
	// A state file left behind by a different or older machine definition
	// would otherwise produce a machine which can never transition. This
	// check must run before the states set is pruned below.
	if _, ok := states[cachedState]; !ok {
		return nil, fmt.Errorf("Unknown state %q in file %q", cachedState, file)
	}

	// Every state must have transitions defined to and from it, even if
	// they are just self-transitions.
	for state := range states {
		if _, ok := transitions[state]; !ok {
			return nil, fmt.Errorf("No transitions defined from state %q", state)
		}
	}
	// Remove every state which is the destination of some transition; what
	// remains are the states with no incoming transition. Only the initial
	// state is allowed to be in that set.
	for _, toMap := range transitions {
		for to := range toMap {
			delete(states, to)
		}
	}
	for s := range states {
		if s != b.initialState {
			return nil, fmt.Errorf("No transitions defined to state %q", s)
		}
	}

	// Create and return the StateMachine.
	sm := &StateMachine{
		current:     cachedState,
		funcs:       b.funcs,
		transitions: transitions,
		file:        file,
	}
	// Write initial state back to file, in case it wasn't there before.
	if err := ioutil.WriteFile(file, []byte(sm.Current()), os.ModePerm); err != nil {
		return nil, err
	}
	return sm, nil
}
// StateMachine is a simple state machine implementation which persists its
// current state to a file. Instances are created by Builder.Build.
type StateMachine struct {
// current is the name of the state the machine is in.
current string
// funcs maps transition function names to the functions themselves.
funcs map[string]TransitionFn
// transitions maps a source state to its allowed destination states, and
// each destination to the name of the function which performs that
// transition.
transitions map[string]map[string]string
// file is the path the current state is written to.
file string
// mtx is held while current is read or updated.
mtx sync.RWMutex
}
// Current reports the name of the state the machine is in. The value is read
// under the machine's read lock.
func (sm *StateMachine) Current() string {
	sm.mtx.RLock()
	state := sm.current
	sm.mtx.RUnlock()
	return state
}
// Transition attempts to move the machine from its current state to dest.
//
// It looks up the function registered for that transition and runs it while
// holding the machine's write lock. If the function succeeds, the current
// state becomes dest and is written to the state file.
//
// An error is returned, and the state is left unchanged, if there is no such
// transition, if the transition's function is not defined, or if the function
// itself fails. If only the final file write fails, that error is returned
// but the in-memory state has already advanced to dest.
func (sm *StateMachine) Transition(dest string) error {
	sm.mtx.Lock()
	defer sm.mtx.Unlock()

	src := sm.current
	outgoing, known := sm.transitions[src]
	if !known {
		return fmt.Errorf("No transitions defined from state %q", src)
	}
	name, allowed := outgoing[dest]
	if !allowed {
		return fmt.Errorf("No transition defined from state %q to state %q", src, dest)
	}
	run, defined := sm.funcs[name]
	if !defined {
		return fmt.Errorf("Undefined transition function %q", name)
	}

	if err := run(); err != nil {
		return fmt.Errorf("Failed to transition from %q to %q: %s", src, dest, err)
	}

	// Update the in-memory state first, then persist it.
	sm.current = dest
	return ioutil.WriteFile(sm.file, []byte(dest), os.ModePerm)
}
// GetTransitionName returns the name of the transition function which would
// run when moving from the current state to dest. It does not run the
// function or change the state. An error is returned if the current state
// has no outgoing transitions or none of them leads to dest.
func (sm *StateMachine) GetTransitionName(dest string) (string, error) {
	sm.mtx.RLock()
	defer sm.mtx.RUnlock()

	src := sm.current
	outgoing, known := sm.transitions[src]
	if !known {
		return "", fmt.Errorf("No transitions defined from state %q", src)
	}
	if name, allowed := outgoing[dest]; allowed {
		return name, nil
	}
	return "", fmt.Errorf("No transition defined from state %q to state %q", src, dest)
}