| // Package pdag allows defining a DAG of processing functions. |
| // |
| // The processing can be triggered at any node of the graph |
| // and will follow the directed edges of the graph. Before |
| // a function in a node is executed all its parents |
| // have to finish execution. Thus the graph defines the order |
| // in which the functions are executed and which functions are |
| // executed serially or in parallel. |
| // |
| // A single shared object (of type interface{}) is passed |
| // to all functions. It's the responsibility of the functions |
| // to coordinate synchronized access to the object. |
| // |
| // If an error occurs in any function, all processing ceases |
| // and the error is returned by the Trigger function. |
| // |
| // For example |
| // |
| //	root := NewNodeWithParents(a) |
| //	NewNodeWithParents(d, root.Child(b), root.Child(c)) |
| // |
| //	state := map[string]string{} |
| //	err := root.Trigger(ctx, state) |
| // |
| // defines a diamond-shaped DAG. Execution starts at the root |
| // and after function 'a', functions 'b' and 'c' will be executed |
| // in parallel. Once both have completed, function 'd' will be |
| // called. All will be passed the value of 'state'. |
| // |
| // An instance of a processing DAG is thread-safe. Data consistency has |
| // to be ensured in the shared state. |
| package pdag |
| |
| import ( |
| "context" |
| "sync" |
| |
| "github.com/google/uuid" |
| |
| "go.skia.org/infra/go/sklog" |
| ) |
| |
// ProcessFn is the signature of the function attached to a node of the DAG.
// It receives the context handed to Trigger and the shared state object, and
// returns a non-nil error to report a failure.
type ProcessFn func(ctx context.Context, state interface{}) error
| |
| // Node of the Dag. |
| type Node struct { |
| id string |
| name string |
| children map[string]*Node |
| procFn ProcessFn |
| mutex sync.Mutex |
| inputMap map[string]int |
| verbose bool // debugging only |
| } |
| |
// NoOp is a processing function that ignores its arguments and always
// succeeds.
func NoOp(context.Context, interface{}) error {
	var err error
	return err
}
| |
| // NewNodeWithParents creates a new Node in the processing DAG. It takes the function |
| // to be executed in this node and an optional list of parent nodes. |
| func NewNodeWithParents(fn ProcessFn, parents ...*Node) *Node { |
| // Create a new node with a unique id. |
| id := uuid.New() |
| node := &Node{ |
| id: id.String(), |
| name: id.String(), |
| children: map[string]*Node{}, |
| procFn: fn, |
| inputMap: map[string]int{}, |
| } |
| |
| // Link the children and parents. |
| for _, parent := range parents { |
| parent.children[node.id] = node |
| } |
| |
| return node |
| } |
| |
| // Child is a shorthand function that creates a child |
| // node of an existing node. |
| func (n *Node) Child(fn ProcessFn) *Node { |
| return NewNodeWithParents(fn, n) |
| } |
| |
| // Trigger starts execution at the current node and |
| // executes all functions that are descendents of this node. |
| // It blocks until all nodes have been executed. If any |
| // of the functions returns an error, execution ceases |
| // and the error is returned. |
| // Note: Trigger can be called on any node in the graph |
| // and will only call the descendants of that node. |
| func (n *Node) Trigger(ctx context.Context, state interface{}) error { |
| // Create a call message. |
| msg := call{ |
| id: uuid.New().String(), |
| state: state, |
| errCh: make(chan error, 1), |
| } |
| |
| // Mark all nodes with the number of inputs they should expect. |
| nodesCalled := n.addInput(msg.id) |
| msg.wg.Add(nodesCalled) |
| |
| if n.verbose { |
| n.dump(msg.id, "") |
| sklog.Infof("Number of nodes to call: %d\n", nodesCalled) |
| } |
| |
| // Trigger the execution and wait for all nodes to be visited. |
| n.process(ctx, &msg) |
| msg.wg.Wait() |
| |
| if msg.hasErr() { |
| return <-msg.errCh |
| } |
| |
| return nil |
| } |
| |
| // setName assigns a name to the Node. It's purely used |
| // for debugging purposes. Internally a unique id is used. |
| // it returns Node so it can easily be chained. |
| func (n *Node) setName(name string) *Node { |
| n.name = name |
| return n |
| } |
| |
| // dump outputs the input connections of this node and its |
| // descendants. Only useful for debugging. |
| func (n *Node) dump(msgID, indent string) { |
| sklog.Infof("Node %s : %d\n", n.name, n.inputMap[msgID]) |
| for _, child := range n.children { |
| child.dump(msgID, indent+" ") |
| } |
| } |
| |
| // addInput records the number of inputs each node |
| // has to expect and records them in inputMap and returns the |
| // number of descendants of this node (including the node itself). |
| func (n *Node) addInput(msgID string) int { |
| descendants := 0 |
| n.mutex.Lock() |
| if _, ok := n.inputMap[msgID]; !ok { |
| descendants = 1 |
| } |
| n.inputMap[msgID] += 1 |
| n.mutex.Unlock() |
| |
| // If we have visited this node before that means we have |
| // visited its children and we can stop now. |
| if descendants == 0 { |
| return descendants |
| } |
| |
| for _, child := range n.children { |
| descendants += child.addInput(msgID) |
| } |
| return descendants |
| } |
| |
| // process is the core processing function of this node that |
| // processes 'call' messages. When all inputs of a call are |
| // received it will trigger the function and pass the call |
| // message to the children of this node. |
| func (n *Node) process(ctx context.Context, msg *call) { |
| // Check if the we have all inputs for this node. |
| n.mutex.Lock() |
| remaining := n.inputMap[msg.id] |
| if remaining == 1 { |
| delete(n.inputMap, msg.id) |
| } else { |
| defer n.mutex.Unlock() |
| n.inputMap[msg.id]-- |
| return |
| } |
| n.mutex.Unlock() |
| |
| // We have all inputs now call the function for this node |
| // and afterwards feed it to all the children. |
| go func(msg *call) { |
| // If the context was cancelled or there was an error, skip the function call. |
| if err := ctx.Err(); err != nil { |
| msg.setErr(err) |
| } |
| if !msg.hasErr() { |
| if err := n.procFn(ctx, msg.state); err != nil { |
| msg.setErr(err) |
| } |
| } |
| msg.wg.Done() |
| |
| // Feed into all the children asynchronously. |
| for _, child := range n.children { |
| go func(child *Node) { |
| child.process(ctx, msg) |
| }(child) |
| } |
| }(msg) |
| } |
| |
// call is the message handed from node to node while the DAG is being
// executed. One call is created per Trigger invocation.
type call struct {
	// id identifies the call; nodes key their input bookkeeping by it.
	id string
	// state is the shared object passed to every processing function.
	state interface{}
	// wg is incremented by the number of nodes to visit and decremented
	// by each node once it is done.
	wg sync.WaitGroup
	// errCh holds the error recorded for this call, if any.
	errCh chan error
}
| |
| // hasErr returns true if a error has been set on this call. |
| func (c *call) hasErr() bool { |
| return len(c.errCh) > 0 |
| } |
| |
| // setErr sets the error on this call. If an error has already been |
| // set, it logs an error message. |
| func (c *call) setErr(err error) { |
| // If the error channel is not ready to receive it means an error |
| // has already been set. |
| select { |
| case c.errCh <- err: |
| default: |
| sklog.Errorf("Error channel already set on call. Error: %s", err) |
| } |
| } |