// Package pdag allows defining a DAG of processing functions.
//
// The processing can be triggered at any node of the graph
// and will follow the directed edges of the graph. Before
// a function in a node is executed all its parents
// have to finish execution. Thus the graph defines the order
// in which the functions are executed and which functions are
// executed serially or in parallel.
//
// A single shared object (of type interface{}) is passed
// to all functions. It's the responsibility of the functions
// to coordinate synchronized access to the object.
//
// If any function returns an error, no further functions are
// called and the first error is returned by the Trigger function.
//
// For example
//
//     root := NewNodeWithParents(a)
//     NewNodeWithParents(d, root.Child(b), root.Child(c))
//
//     state := map[string]string{}
//     root.Trigger(ctx, state)
//
// defines a diamond-shaped DAG. Execution starts at the root
// and after function 'a', functions 'b' and 'c' will be executed
// in parallel. Once both have completed, function 'd' will be
// called. All will be passed the value of 'state'.
//
// An instance of a processing DAG is thread-safe. Data consistency has
// to be ensured in the shared state.
package pdag

import (
	"context"
	"sync"

	"github.com/google/uuid"

	"go.skia.org/infra/go/sklog"
)

// ProcessFn is the signature of the function run by a Node. It receives the
// context given to Trigger and the single shared state object. A non-nil
// return value is recorded as the Trigger call's error.
type ProcessFn func(context.Context, interface{}) error

// Node is a single vertex of a processing DAG: the function it runs plus the
// edges to its children. The zero value is not usable (its maps are nil);
// create nodes with NewNodeWithParents or Child.
type Node struct {
	id, name string           // unique id (a UUID) and debug display name
	children map[string]*Node // child nodes keyed by their id
	procFn   ProcessFn        // function executed when the node runs
	mutex    sync.Mutex       // guards inputMap
	inputMap map[string]int   // per Trigger call id: inputs still expected
	verbose  bool             // debugging only
}

// NoOp is a ProcessFn that ignores both its context and its state and always
// returns nil.
func NoOp(context.Context, interface{}) error { return nil }

// NewNodeWithParents creates a Node that runs fn and registers it as a child
// of every node in parents. With no parents the new node is a root. A freshly
// generated UUID serves as both the node's id and its initial name.
//
// NOTE: the parents' children maps are modified without taking any lock, so
// the graph should be fully built before it is triggered concurrently.
func NewNodeWithParents(fn ProcessFn, parents ...*Node) *Node {
	nodeID := uuid.New().String()
	node := &Node{
		id:       nodeID,
		name:     nodeID,
		procFn:   fn,
		children: make(map[string]*Node),
		inputMap: make(map[string]int),
	}

	for i := range parents {
		parents[i].children[nodeID] = node
	}
	return node
}

// Child creates a new node that runs fn, makes n its only parent, and returns
// the new node. It is shorthand for NewNodeWithParents(fn, n).
func (n *Node) Child(fn ProcessFn) *Node {
	child := NewNodeWithParents(fn, n)
	return child
}

// Trigger runs the function of n and of every node reachable from n. Each
// node starts only after all of its parents inside that sub-graph have
// finished, and independent branches run in parallel. Trigger blocks until
// every reachable node has been visited.
//
// If a function returns an error (or ctx is done), functions that have not
// started yet are skipped and the first recorded error is returned.
// Otherwise Trigger returns nil.
//
// Trigger can be called on any node in the graph; only that node and its
// descendants are executed.
func (n *Node) Trigger(ctx context.Context, state interface{}) error {
	// errCh has capacity 1 so exactly one error (the first) is kept.
	msg := &call{
		id:    uuid.New().String(),
		state: state,
		errCh: make(chan error, 1),
	}

	// Tell every reachable node how many inputs to wait for and learn how
	// many nodes will be visited in total.
	total := n.addInput(msg.id)
	msg.wg.Add(total)

	if n.verbose {
		n.dump(msg.id, "")
		sklog.Infof("Number of nodes to call: %d\n", total)
	}

	n.process(ctx, msg)
	msg.wg.Wait()

	// Every node has finished, so errCh can no longer change.
	select {
	case err := <-msg.errCh:
		return err
	default:
		return nil
	}
}

// setName sets a human-readable label on the node and returns the node so
// calls can be chained. The label only appears in debug output (dump); the
// node is identified internally by its unique id.
func (n *Node) setName(label string) *Node {
	n.name = label
	return n
}

// dump logs, for debugging, how many inputs n and each of its descendants
// still expect for the Trigger call identified by msgID. indent is printed in
// front of the line for n, and each child level adds five more spaces so the
// output mirrors the shape of the DAG. A node reachable through several
// parents is printed once per path.
func (n *Node) dump(msgID, indent string) {
	// inputMap is also written by concurrent Trigger calls, so read it under
	// the node's lock.
	n.mutex.Lock()
	inputs := n.inputMap[msgID]
	n.mutex.Unlock()

	sklog.Infof("%sNode %s : %d\n", indent, n.name, inputs)
	for _, child := range n.children {
		child.dump(msgID, indent+"     ")
	}
}

// addInput walks the sub-graph rooted at n and, for the call msgID, bumps each
// node's inputMap entry once per incoming edge from within that sub-graph (the
// root gets one input for the trigger itself). It returns the number of
// distinct nodes reached, n included.
func (n *Node) addInput(msgID string) int {
	n.mutex.Lock()
	_, seen := n.inputMap[msgID]
	n.inputMap[msgID]++
	n.mutex.Unlock()

	// A node seen before has already had its children counted; only the
	// extra incoming edge was recorded above.
	if seen {
		return 0
	}

	total := 1
	for _, child := range n.children {
		total += child.addInput(msgID)
	}
	return total
}

// process delivers one input of the Trigger call msg to n. The initial Trigger
// and every parent that finishes each deliver one input. Only the delivery that
// matches the last expected input schedules n's function; earlier deliveries
// just decrement the count and return.
//
// The function runs on its own goroutine. Afterwards the node is marked done
// in msg.wg and msg is forwarded to every child. The function is skipped if an
// error has already been recorded on msg. It is also skipped if ctx is done,
// in which case ctx.Err() becomes the call's error.
func (n *Node) process(ctx context.Context, msg *call) {
	n.mutex.Lock()
	remaining := n.inputMap[msg.id]
	if remaining != 1 {
		// Still waiting for other parents: record this input and stop.
		n.inputMap[msg.id] = remaining - 1
		n.mutex.Unlock()
		return
	}
	// Last input arrived; this call no longer needs an entry for n.
	delete(n.inputMap, msg.id)
	n.mutex.Unlock()

	go func() {
		// Check the recorded error before ctx: once an error (including a
		// cancellation) is recorded, later nodes are skipped silently. This
		// avoids re-reporting ctx.Err() and setErr's "already set" error log
		// for every remaining node.
		if !msg.hasErr() {
			err := ctx.Err()
			if err == nil {
				err = n.procFn(ctx, msg.state)
			}
			if err != nil {
				msg.setErr(err)
			}
		}
		// Marking done before fanning out is safe: each child still holds its
		// own count in msg.wg until it has run.
		msg.wg.Done()

		// process only locks briefly and then spawns a goroutine; the fan-out
		// stays asynchronous.
		for _, child := range n.children {
			go child.process(ctx, msg)
		}
	}()
}

// call is the message type that is passed between the nodes
// of the DAG. A single *call is created per Trigger invocation and is
// shared by every node that invocation visits.
type call struct {
	id    string         // unique per Trigger; key into each Node's inputMap
	state interface{}    // shared object passed to every ProcessFn
	errCh chan error     // created with capacity 1 by Trigger; holds the first error set
	wg    sync.WaitGroup // counts visited nodes that have not finished; Trigger waits on it
}

// hasErr reports whether an error has been recorded on this call, i.e.
// whether errCh's buffer currently holds a value.
func (c *call) hasErr() bool {
	return len(c.errCh) != 0
}

// setErr records err on the call unless an error is already stored. errCh is
// buffered (capacity 1, set by Trigger), so the non-blocking send succeeds
// only for the first error. Any later error is dropped and logged instead.
func (c *call) setErr(err error) {
	stored := false
	select {
	case c.errCh <- err:
		stored = true
	default:
	}

	if !stored {
		sklog.Errorf("Error channel already set on call. Error: %s", err)
	}
}
