Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions graph/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ type nodeInfo struct {
dependencies int // Number of dependencies (predecessor count)
isFinish bool // Whether this is the finish node
hasConditions bool // Whether outgoing edges carry conditions
loopDependencies int // Number of incoming loop edges
loopEdgeSources map[string]bool // Precomputed: which parents connect via loop edges
}

// Executor represents a compiled graph ready for execution. It is safe for
Expand All @@ -27,9 +29,21 @@ func NewExecutor(g *Graph) *Executor {
// Build predecessors map for deterministic state aggregation
predecessors := make(map[string][]string, len(g.nodes))
dependencyCounts := make(map[string]int)
loopDependencyCounts := make(map[string]int)
loopEdgeSources := make(map[string]map[string]bool)

for from, edges := range g.edges {
for _, edge := range edges {
predecessors[edge.to] = append(predecessors[edge.to], from)
if edge.edgeType == EdgeTypeLoop {
loopDependencyCounts[edge.to]++
// Track which parents connect via loop edges
if loopEdgeSources[edge.to] == nil {
loopEdgeSources[edge.to] = make(map[string]bool)
}
loopEdgeSources[edge.to][from] = true
continue
}
dependencyCounts[edge.to]++
}
}
Expand Down Expand Up @@ -59,6 +73,8 @@ func NewExecutor(g *Graph) *Executor {
dependencies: dependencyCounts[nodeName],
isFinish: nodeName == g.finishPoint,
hasConditions: hasConditions,
loopDependencies: loopDependencyCounts[nodeName],
loopEdgeSources: loopEdgeSources[nodeName],
}
nodeInfos[nodeName] = node
}
Expand Down
108 changes: 91 additions & 17 deletions graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ func WithMiddleware(ms ...Middleware) Option {
// EdgeCondition is a function that determines if an edge should be followed based on the current state.
type EdgeCondition func(ctx context.Context, state State) bool

// EdgeType defines the type of an edge in the graph.
type EdgeType int

const (
// EdgeTypeNormal is a regular edge (default).
EdgeTypeNormal EdgeType = iota
// EdgeTypeLoop is a back edge that forms a loop - allows revisiting nodes.
EdgeTypeLoop
// EdgeTypeExit is an exit edge from a loop - required for loop nodes.
EdgeTypeExit
)

// EdgeOption configures an edge before it is added to the graph.
type EdgeOption func(*conditionalEdge)

Expand All @@ -36,10 +48,18 @@ func WithEdgeCondition(condition EdgeCondition) EdgeOption {
}
}

// conditionalEdge represents an edge with an optional condition.
// WithEdgeType sets the type of the edge (Normal, Loop, or Exit).
func WithEdgeType(edgeType EdgeType) EdgeOption {
return func(edge *conditionalEdge) {
edge.edgeType = edgeType
}
}

// conditionalEdge represents an edge with an optional condition and type.
type conditionalEdge struct {
to string
condition EdgeCondition // nil means always follow this edge
edgeType EdgeType // type of edge (normal, loop, or exit)
}

// Graph represents a directed graph of processing nodes. Cycles are allowed.
Expand Down Expand Up @@ -161,10 +181,53 @@ func (g *Graph) validateStructure() error {
return fmt.Errorf("graph: node '%s' has mixed conditional and unconditional edges", from)
}
}

// Validate loop/exit edge rules
if err := g.validateLoopEdges(); err != nil {
return err
}

return nil
}

// validateLoopEdges ensures that nodes with loop edges also have exit edges,
// and that loop/exit edges have proper conditions.
func (g *Graph) validateLoopEdges() error {
for from, edges := range g.edges {
var loopEdges []conditionalEdge
var exitEdges []conditionalEdge

for _, edge := range edges {
switch edge.edgeType {
case EdgeTypeLoop:
loopEdges = append(loopEdges, edge)
case EdgeTypeExit:
exitEdges = append(exitEdges, edge)
}
}

// If node has loop edges, it must have exit edges
if len(loopEdges) > 0 && len(exitEdges) == 0 {
return fmt.Errorf("graph: node '%s' has loop edges but no exit edges", from)
}

// All loop and exit edges must have conditions
for _, edge := range loopEdges {
if edge.condition == nil {
return fmt.Errorf("graph: loop edge from '%s' to '%s' must have a condition", from, edge.to)
}
}
for _, edge := range exitEdges {
if edge.condition == nil {
return fmt.Errorf("graph: exit edge from '%s' to '%s' must have a condition", from, edge.to)
}
}
}
return nil
}

// ensureReachable verifies that the finish node can be reached from the entry node.
// Loop edges are skipped during reachability check since they don't advance toward the finish.
func (g *Graph) ensureReachable() error {
if g.entryPoint == g.finishPoint {
return nil
Expand All @@ -182,13 +245,18 @@ func (g *Graph) ensureReachable() error {
return nil
}
for _, edge := range g.edges[node] {
// Skip loop edges during reachability check
if edge.edgeType == EdgeTypeLoop {
continue
}
queue = append(queue, edge.to)
}
}
return fmt.Errorf("graph: finish node not reachable: %s", g.finishPoint)
}

// ensureAcyclic verifies that the graph does not contain directed cycles.
// ensureAcyclic verifies that the graph does not contain directed cycles,
// unless the cycle includes at least one edge marked as EdgeTypeLoop.
func (g *Graph) ensureAcyclic() error {
const (
stateUnvisited = iota
Expand All @@ -198,26 +266,32 @@ func (g *Graph) ensureAcyclic() error {
states := make(map[string]int, len(g.nodes))
stack := make([]string, 0, len(g.nodes))

var visit func(string) error
visit = func(node string) error {
var visit func(string, bool) error
visit = func(node string, hasLoopEdge bool) error {
states[node] = stateVisiting
stack = append(stack, node)

for _, edge := range g.edges[node] {
next := edge.to
nextHasLoop := hasLoopEdge || edge.edgeType == EdgeTypeLoop
Copy link

Copilot AI Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hasLoopEdge tracking logic is incorrect. hasLoopEdge should only be true if there's a loop edge in the path from the root to the current node, not just anywhere in the traversal. The current implementation propagates hasLoopEdge along all paths once any loop edge is encountered, which means even paths that don't contain a loop edge will be marked as having one.

This could allow cycles that don't actually contain a loop edge to pass validation. The fix should track whether the current cycle path contains a loop edge, not whether any path from the root contained one.

Copilot uses AI. Check for mistakes.
switch states[next] {
case stateVisiting:
cycleStart := 0
for i, name := range stack {
if name == next {
cycleStart = i
break
// Found a back edge (cycle)
// Allow it only if the cycle includes a loop edge
if !nextHasLoop {
cycleStart := 0
for i, name := range stack {
if name == next {
cycleStart = i
break
}
}
cycle := append(append([]string{}, stack[cycleStart:]...), next)
Comment on lines +282 to +291
Copy link

Copilot AI Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cycle detection logic is flawed. When a back edge is found (line 278), the code checks !nextHasLoop to determine if the cycle is invalid. However, nextHasLoop includes hasLoopEdge from the parameter, which tracks whether ANY loop edge was seen on the path from the DFS root to the current node.

The correct check should only validate whether the cycle itself (from next back to next) contains a loop edge. Consider this graph: A -> B (loop edge), B -> C, C -> A. When visiting C->A, nextHasLoop would be true because B->C had hasLoopEdge=true, even though the cycle C->A doesn't contain a loop edge.

Suggested change
// Allow it only if the cycle includes a loop edge
if !nextHasLoop {
cycleStart := 0
for i, name := range stack {
if name == next {
cycleStart = i
break
}
}
cycle := append(append([]string{}, stack[cycleStart:]...), next)
// Allow it only if the cycle itself includes a loop edge
cycleStart := 0
for i, name := range stack {
if name == next {
cycleStart = i
break
}
}
cycle := append(append([]string{}, stack[cycleStart:]...), next)
hasLoop := false
// Check each edge in the cycle for EdgeTypeLoop
for i := 0; i < len(cycle)-1; i++ {
from := cycle[i]
to := cycle[i+1]
for _, e := range g.edges[from] {
if e.to == to && e.edgeType == EdgeTypeLoop {
hasLoop = true
break
}
}
if hasLoop {
break
}
}
if !hasLoop {

Copilot uses AI. Check for mistakes.
return fmt.Errorf("graph: cycle detected but edge not marked as EdgeTypeLoop (cycle: %s)", strings.Join(cycle, " -> "))
}
cycle := append(append([]string{}, stack[cycleStart:]...), next)
return fmt.Errorf("graph: cycles are not supported (cycle: %s)", strings.Join(cycle, " -> "))
// Loop edge is allowed - continue checking other edges
case stateUnvisited:
if err := visit(next); err != nil {
if err := visit(next, nextHasLoop); err != nil {
return err
}
}
Expand All @@ -230,7 +304,7 @@ func (g *Graph) ensureAcyclic() error {

for name := range g.nodes {
if states[name] == stateUnvisited {
if err := visit(name); err != nil {
if err := visit(name, false); err != nil {
return err
}
}
Expand All @@ -250,13 +324,13 @@ func (g *Graph) Compile() (*Executor, error) {
if err := g.ensureAcyclic(); err != nil {
return nil, err
}
// Check reachability
if err := g.ensureReachable(); err != nil {
return nil, err
}
// Final structural validations
if err := g.validateStructure(); err != nil {
return nil, err
}
// Check reachability once structure is validated
if err := g.ensureReachable(); err != nil {
return nil, err
}
return NewExecutor(g), nil
}
Loading
Loading