Coverage for src / graphable / graph.py: 96%
536 statements
« prev ^ index » next — coverage.py v7.13.3, created at 2026-02-16 21:32 +0000
1from __future__ import annotations
3from graphlib import CycleError, TopologicalSorter
4from hashlib import blake2b
5from logging import getLogger
6from pathlib import Path
7from typing import Any, Callable, Iterator
9from .enums import Direction, Engine
10from .errors import GraphConsistencyError, GraphCycleError
11from .graphable import Graphable
13logger = getLogger(__name__)
16class Graph[T: Graphable[Any]]:
17 """
18 Represents a graph of Graphable nodes.
19 """
21 def __init__(self, initial: set[T] | list[T] | None = None, discover: bool = False):
22 """
23 Initialize a Graph.
25 Args:
26 initial (set[T] | list[T] | None): An optional set of initial nodes.
27 discover (bool): If True, automatically expand the graph to include all
28 reachable ancestors and descendants of the initial nodes.
30 Raises:
31 GraphCycleError: If the initial set of nodes contains a cycle.
32 """
33 self._nodes: set[T] = set()
34 self._topological_order: list[T] | None = None
35 self._parallel_topological_order: list[set[T]] | None = None
36 self._checksum: str | None = None
38 if initial:
39 for node in initial:
40 self.add_node(node)
42 if discover:
43 self.discover()
45 self.check_consistency()
46 self.check_cycles()
48 def clone(self, include_edges: bool = False) -> Graph[T]:
49 """
50 Create a copy of this graph.
52 Args:
53 include_edges: If True, the new graph will have the same edges as this one.
54 If False, the new graph will contain copies of all nodes but with
55 no edges between them.
57 Returns:
58 Graph[T]: A new Graph instance.
59 """
60 import copy
62 logger.debug(f"Cloning graph (include_edges={include_edges}).")
63 node_map: dict[T, T] = {}
64 for node in self._nodes:
65 new_node = copy.copy(node)
66 # Reset internal edge tracking
67 new_node._dependents = {}
68 new_node._depends_on = {}
69 # Manually clone tags to avoid shared state
70 new_node._tags = set(node.tags)
71 node_map[node] = new_node
73 new_graph = Graph(set(node_map.values()))
75 if include_edges:
76 for u in self._nodes:
77 for v, attrs in self.neighbors(u, Direction.DOWN):
78 new_graph.add_edge(node_map[u], node_map[v], **attrs)
80 return new_graph
82 def neighbors(
83 self, node: T, direction: Direction = Direction.DOWN
84 ) -> Iterator[tuple[T, dict[str, Any]]]:
85 """
86 Iterate over neighbors within this graph.
88 Args:
89 node (T): The source node.
90 direction: Direction.DOWN for dependents, Direction.UP for dependencies.
92 Yields:
93 tuple[T, dict[str, Any]]: A (neighbor_node, edge_attributes) tuple.
94 """
95 neighbors = node.dependents if direction == Direction.DOWN else node.depends_on
96 for neighbor in neighbors:
97 if neighbor in self._nodes:
98 attrs = (
99 node.edge_attributes(neighbor)
100 if direction == Direction.DOWN
101 else neighbor.edge_attributes(node)
102 )
103 yield neighbor, attrs
105 def internal_dependents(self, node: T) -> Iterator[tuple[T, dict[str, Any]]]:
106 """Alias for neighbors(node, Direction.DOWN)."""
107 return self.neighbors(node, Direction.DOWN)
109 def internal_depends_on(self, node: T) -> Iterator[tuple[T, dict[str, Any]]]:
110 """Alias for neighbors(node, Direction.UP)."""
111 return self.neighbors(node, Direction.UP)
113 def discover(self) -> None:
114 """
115 Traverse the entire connectivity of the current nodes and add any
116 reachable ancestors or descendants that are not yet members.
117 """
118 logger.debug(f"Discovering reachable nodes from {len(self._nodes)} base nodes.")
120 new_nodes: set[T] = set()
121 for node in list(self._nodes):
122 new_nodes.update(
123 self._traverse(node, direction=Direction.UP, limit_to_graph=False)
124 )
125 new_nodes.update(
126 self._traverse(node, direction=Direction.DOWN, limit_to_graph=False)
127 )
129 for node in new_nodes:
130 self.add_node(node)
132 def _invalidate_cache(self) -> None:
133 """Clear all cached calculations for this graph."""
134 logger.debug("Invalidating graph cache.")
135 self._topological_order = None
136 self._parallel_topological_order = None
137 self._checksum = None
139 def __contains__(self, item: object) -> bool:
140 """
141 Check if a node or its reference is in the graph.
143 Args:
144 item (object): Either a Graphable node or a reference object.
146 Returns:
147 bool: True if present, False otherwise.
148 """
149 if isinstance(item, Graphable):
150 return item in self._nodes
152 return any(node.reference == item for node in self._nodes)
154 def __getitem__(self, reference: Any) -> T:
155 """
156 Get a node by its reference.
158 Args:
159 reference (Any): The reference object to search for.
161 Returns:
162 T: The Graphable node.
164 Raises:
165 KeyError: If no node with the given reference exists.
166 """
167 for node in self._nodes:
168 if node.reference == reference:
169 return node
170 raise KeyError(f"No node found with reference: {reference}")
172 def __iter__(self):
173 """
174 Iterate over nodes in topological order.
175 """
176 return iter(self.topological_order())
178 def __len__(self) -> int:
179 """
180 Get the number of nodes in the graph.
181 """
182 return len(self._nodes)
184 def is_equal_to(self, other: object) -> bool:
185 """
186 Check if this graph is equal to another graph.
187 Equality is defined as having the same checksum (structural and metadata-wise).
189 Args:
190 other: The other object to compare with.
192 Returns:
193 bool: True if equal, False otherwise.
194 """
195 if not isinstance(other, Graph):
196 return False
198 return self.checksum() == other.checksum()
    def checksum(self) -> str:
        """
        Calculate a deterministic BLAKE2b checksum of the graph.
        The checksum accounts for all member nodes (references, tags, duration, status)
        and edges (including attributes) between them. External nodes are excluded.

        Returns:
            str: The hexadecimal digest of the graph.
        """
        # Cached; _invalidate_cache() clears this whenever the graph mutates.
        if self._checksum is not None:
            return self._checksum

        # 1. Sort nodes by reference to ensure deterministic iteration
        sorted_nodes = sorted(self._nodes, key=lambda n: str(n.reference))

        hasher = blake2b()

        for node in sorted_nodes:
            # 2. Add node reference, duration, and status
            hasher.update(str(node.reference).encode())
            hasher.update(f":duration:{node.duration}".encode())
            hasher.update(f":status:{node.status}".encode())

            # 3. Add sorted tags (sorted so tag insertion order cannot change the digest)
            for tag in sorted(node.tags):
                hasher.update(f":tag:{tag}".encode())

            # 4. Add sorted dependents (edges) with attributes - Only those in the graph
            internal_dependents = sorted(
                [d for d in node.dependents if d in self._nodes],
                key=lambda n: str(n.reference),
            )
            for dep in internal_dependents:
                hasher.update(f":edge:{dep.reference}".encode())
                # Add edge attributes deterministically (keys sorted)
                attrs = node.edge_attributes(dep)
                for key in sorted(attrs.keys()):
                    hasher.update(f":attr:{key}:{attrs[key]}".encode())

        self._checksum = hasher.hexdigest()
        return self._checksum
242 def validate_checksum(self, expected: str) -> bool:
243 """
244 Validate the graph against an expected checksum.
246 Args:
247 expected (str): The expected BLAKE2b hexadecimal digest.
249 Returns:
250 bool: True if the checksums match, False otherwise.
251 """
252 return self.checksum() == expected
254 def write_checksum(self, path: Path | str) -> None:
255 """
256 Write the graph's current checksum to a file.
258 Args:
259 path: Path to the output checksum file.
260 """
261 p = Path(path)
262 digest = self.checksum()
263 logger.info(f"Writing checksum to: {p}")
264 with open(p, "w+") as f:
265 f.write(digest)
267 @staticmethod
268 def read_checksum(path: Path | str) -> str:
269 """
270 Read a checksum from a file.
272 Args:
273 path: Path to the checksum file.
275 Returns:
276 str: The checksum string.
277 """
278 p = Path(path)
279 logger.debug(f"Reading checksum from: {p}")
280 with open(p, "r") as f:
281 return f.read().strip()
283 @classmethod
284 def read(cls, path: Path | str, **kwargs: Any) -> Graph[Any]:
285 """Read a graph from a file, automatically detecting the format."""
286 from .parsers.utils import extract_checksum
287 from .registry import PARSERS
289 p = Path(path)
290 ext = p.suffix.lower()
291 parser = PARSERS.get(ext)
292 if not parser:
293 raise ValueError(f"Unsupported extension for reading: {ext}")
295 g = parser(p, **kwargs)
297 if embedded := extract_checksum(p):
298 if not g.validate_checksum(embedded):
299 raise ValueError(f"Checksum validation failed for {p}")
300 return g
302 def write(
303 self,
304 path: Path | str,
305 transitive_reduction: bool = False,
306 embed_checksum: bool = False,
307 engine: Engine | str | None = None,
308 **kwargs: Any,
309 ) -> None:
310 """
311 Write the graph to a file, automatically detecting the format.
313 Args:
314 path: Path to the output file.
315 transitive_reduction: If True, perform transitive reduction before writing.
316 embed_checksum: If True, embed a BLAKE2b checksum in the output.
317 engine: The rendering engine to use for images (.svg, .png).
318 If None, it will be auto-detected.
319 **kwargs: Additional arguments passed to the specific exporter.
320 """
321 from .registry import EXPORTERS
323 p = Path(path)
324 ext = p.suffix.lower()
326 # Handle images specifically to allow engine selection/auto-detection
327 if ext in (".svg", ".png"):
328 from .views.utils import get_image_exporter
330 exporter = get_image_exporter(engine)
331 else:
332 exporter = EXPORTERS.get(ext)
334 if not exporter:
335 raise ValueError(f"Unsupported extension: {ext}")
337 return self.export(exporter, p, transitive_reduction, embed_checksum, **kwargs)
    def parallelized_topological_order(self) -> list[set[T]]:
        """
        Get the nodes in topological order, grouped into sets that can be processed in parallel.
        Only nodes that are members of this graph are included.

        Returns:
            list[set[T]]: A list of sets of member nodes that have no unmet dependencies.
        """
        # Cached; _invalidate_cache() clears this whenever the graph mutates.
        if self._parallel_topological_order is None:
            logger.debug("Calculating parallel topological order.")
            self._parallel_topological_order = []
            # depends_on may reference nodes outside this graph; graphlib
            # accepts predecessors that appear only as values.
            sorter = TopologicalSorter({node: node.depends_on for node in self._nodes})
            sorter.prepare()
            while sorter.is_active():
                ready = sorter.get_ready()
                if not ready:
                    # Defensive: no progress possible (not expected for a DAG)
                    break
                # Filter to only include nodes that are actually in this graph
                filtered_ready = {node for node in ready if node in self._nodes}
                if filtered_ready:
                    self._parallel_topological_order.append(filtered_ready)
                # Mark ALL ready nodes done (members or not) so successors unlock
                sorter.done(*ready)

        return self._parallel_topological_order
364 def subgraph_between(self, source: T, target: T) -> Graph[T]:
365 """
366 Create a new graph containing all nodes and edges on all paths between source and target.
368 Args:
369 source (T): The starting node.
370 target (T): The ending node.
372 Returns:
373 Graph[T]: A new Graph instance.
374 """
375 if source not in self._nodes or target not in self._nodes:
376 raise KeyError("Both source and target must be in the graph.")
378 # Nodes between U and V are nodes that are descendants of U AND ancestors of V
379 descendants = {source} | set(self.descendants(source))
380 ancestors = {target} | set(self.ancestors(target))
381 between = descendants & ancestors
383 return Graph(between)
    def diff_graph(self, other: Graph[T]) -> Graph[T]:
        """
        Create a visualization-friendly diff graph.
        - Nodes in both: grey/default
        - Added nodes: green
        - Removed nodes: red
        - Modified nodes/edges: yellow/orange

        Returns:
            Graph[T]: A merged graph with diff metadata.
        """
        import copy

        merged_nodes_map: dict[Any, T] = {}
        diff_info = self.diff(other)

        def get_or_create(node: T, status: str) -> T:
            # Merge nodes by reference. First caller wins: a node present in
            # both graphs keeps the status computed on the first pass.
            ref = node.reference
            if ref not in merged_nodes_map:
                new_node = copy.copy(node)
                # Reset edge bookkeeping on the copy; edges are rebuilt below
                new_node._dependents = {}
                new_node._depends_on = {}
                new_node.add_tag(f"diff:{status}")
                # Add visual hints
                color = {"added": "green", "removed": "red", "modified": "orange"}.get(
                    status, "grey"
                )
                new_node.add_tag(f"color:{color}")
                merged_nodes_map[ref] = new_node
            return merged_nodes_map[ref]

        # Add all nodes from both
        for node in self._nodes:
            status = (
                "removed"
                if node.reference in diff_info["removed_nodes"]
                else "unchanged"
            )
            if node.reference in diff_info["modified_nodes"]:
                status = "modified"
            get_or_create(node, status)

        for node in other._nodes:
            status = (
                "added" if node.reference in diff_info["added_nodes"] else "unchanged"
            )
            if node.reference in diff_info["modified_nodes"]:
                status = "modified"
            get_or_create(node, status)

        new_graph = Graph(set(merged_nodes_map.values()))

        # Add edges from self (original): removed/modified edges get diff markers,
        # unchanged edges are copied verbatim.
        for u in self._nodes:
            for v in u.dependents:
                if v not in self._nodes:
                    continue
                edge = (u.reference, v.reference)
                if edge in diff_info["removed_edges"]:
                    new_graph.add_edge(
                        merged_nodes_map[u.reference],
                        merged_nodes_map[v.reference],
                        diff_status="removed",
                        color="red",
                    )
                elif edge in diff_info["modified_edges"]:
                    new_graph.add_edge(
                        merged_nodes_map[u.reference],
                        merged_nodes_map[v.reference],
                        **u.edge_attributes(v),
                        diff_status="modified",
                        color="orange",
                    )
                else:
                    new_graph.add_edge(
                        merged_nodes_map[u.reference],
                        merged_nodes_map[v.reference],
                        **u.edge_attributes(v),
                    )

        # Add edges from other (new): only edges absent from self need adding here
        for u in other._nodes:
            for v in u.dependents:
                if v not in other._nodes:
                    continue
                edge = (u.reference, v.reference)
                if edge in diff_info["added_edges"]:
                    new_graph.add_edge(
                        merged_nodes_map[u.reference],
                        merged_nodes_map[v.reference],
                        **u.edge_attributes(v),
                        diff_status="added",
                        color="green",
                    )

        return new_graph
482 def transitive_closure(self) -> Graph[T]:
483 """
484 Compute the transitive closure of this graph.
485 An edge (u, v) exists in the transitive closure if there is a path from u to v.
487 Returns:
488 Graph[T]: A new Graph instance representing the transitive closure.
489 """
490 import copy
492 logger.debug("Calculating transitive closure.")
493 node_map = {node: copy.copy(node) for node in self._nodes}
494 for n in node_map.values():
495 n._dependents = {}
496 n._depends_on = {}
498 new_graph = Graph(set(node_map.values()))
499 for u in self._nodes:
500 for v in self.descendants(u):
501 new_graph.add_edge(node_map[u], node_map[v])
503 return new_graph
    def suggest_cycle_breaks(self) -> list[tuple[T, T]]:
        """
        Identify a minimal set of edges to remove to make the graph a Directed Acyclic Graph (DAG).
        Uses a greedy heuristic.

        Returns:
            list[tuple[T, T]]: A list of (source, target) tuples representing suggested edges to remove.
        """
        logger.debug("Suggesting cycle breaks.")
        # Simple heuristic:
        # 1. Take all nodes.
        # 2. Try to order them such that we maximize forward edges.
        # A simple way is to use the order they were added or any arbitrary order
        # and see which edges go 'backwards'.

        nodes = list(self._nodes)
        # We can try to be slightly smarter by using a DFS and finding back-edges
        back_edges = []
        visited = set()
        stack = set()  # nodes on the current DFS path

        def dfs(u):
            # NOTE(review): recursive DFS — could hit the interpreter recursion
            # limit on very deep graphs; consider an explicit stack if needed.
            visited.add(u)
            stack.add(u)
            for v in u.dependents:
                if v not in self._nodes:
                    continue
                if v in stack:
                    # Edge back into the active DFS path => closes a cycle
                    back_edges.append((u, v))
                elif v not in visited:
                    dfs(v)
            stack.remove(u)

        for node in nodes:
            if node not in visited:
                dfs(node)

        return back_edges
544 def parallelized_topological_order_filtered(
545 self, fn: Callable[[T], bool]
546 ) -> list[set[T]]:
547 """
548 Get a filtered list of nodes in parallelized topological order.
550 Args:
551 fn (Callable[[T], bool]): The predicate function.
553 Returns:
554 list[set[T]]: Filtered sets of nodes for parallel processing.
555 """
556 result = []
557 for group in self.parallelized_topological_order():
558 filtered_group = {node for node in group if fn(node)}
559 if filtered_group:
560 result.append(filtered_group)
561 return result
563 def parallelized_topological_order_tagged(self, tag: str) -> list[set[T]]:
564 """
565 Get a list of nodes with a specific tag in parallelized topological order.
567 Args:
568 tag (str): The tag to filter by.
570 Returns:
571 list[set[T]]: Tagged sets of nodes for parallel processing.
572 """
573 return self.parallelized_topological_order_filtered(lambda n: n.is_tagged(tag))
575 def __eq__(self, other: object) -> bool:
576 """
577 Compare two graphs for equality.
578 """
579 return self.is_equal_to(other)
581 def __hash__(self) -> int:
582 """
583 Graphs are hashable by identity to allow them to be used in WeakSets
584 (e.g., as observers of Graphable nodes).
585 """
586 return id(self)
588 def check_cycles(self) -> None:
589 """
590 Check for cycles in the graph.
592 Raises:
593 GraphCycleError: If a cycle is detected.
594 """
595 try:
596 sorter = TopologicalSorter({node: node.depends_on for node in self._nodes})
597 sorter.prepare()
598 except CycleError as e:
599 # graphlib.CycleError args: (message, cycle_tuple)
600 cycle = list(e.args[1]) if len(e.args) > 1 else None
601 raise GraphCycleError(f"Cycle detected in graph: {e}", cycle=cycle) from e
603 def check_consistency(self) -> None:
604 """
605 Check for consistency between depends_on and dependents for all nodes in the graph.
607 Raises:
608 GraphConsistencyError: If an inconsistency is detected.
609 """
610 for node in self._nodes:
611 self._check_node_consistency(node)
613 def _check_node_consistency(self, node: T) -> None:
614 """
615 Check for consistency between depends_on and dependents for a single node.
617 Args:
618 node (T): The node to check.
620 Raises:
621 GraphConsistencyError: If an inconsistency is detected.
622 """
623 # Check dependencies: if node depends on X, X must have node as dependent
624 for dep in node.depends_on:
625 if node not in dep.dependents:
626 raise GraphConsistencyError(
627 f"Inconsistency: Node '{node.reference}' depends on '{dep.reference}', "
628 f"but '{dep.reference}' does not list '{node.reference}' as a dependent."
629 )
630 # Check dependents: if node has dependent Y, Y must depend on node
631 for sub in node.dependents:
632 if node not in sub.depends_on:
633 raise GraphConsistencyError(
634 f"Inconsistency: Node '{node.reference}' has dependent '{sub.reference}', "
635 f"but '{sub.reference}' does not depend on '{node.reference}'."
636 )
638 def add_edge(self, node: T, dependent: T, **attributes: Any) -> None:
639 """
640 Add a directed edge from node to dependent.
641 Also adds the nodes to the graph if they are not already present.
643 Args:
644 node (T): The source node (dependency).
645 dependent (T): The target node (dependent).
646 **attributes: Edge attributes (e.g., weight, label).
648 Raises:
649 GraphCycleError: If adding the edge would create a cycle.
650 """
651 if node == dependent:
652 raise GraphCycleError(
653 f"Self-loop detected: node '{node.reference}' cannot depend on itself.",
654 cycle=[node, node],
655 )
657 # Check if adding this edge creates a cycle.
658 # A cycle is created if there is already a path from 'dependent' to 'node'.
659 if path := dependent.find_path(node):
660 cycle = path + [dependent]
661 raise GraphCycleError(
662 f"Adding edge '{node.reference}' -> '{dependent.reference}' would create a cycle.",
663 cycle=cycle,
664 )
666 self.add_node(node)
667 self.add_node(dependent)
669 node._add_dependent(dependent, **attributes)
670 dependent._add_depends_on(node, **attributes)
671 logger.debug(
672 f"Added edge: {node.reference} -> {dependent.reference} with attributes {attributes}"
673 )
675 # Invalidate cache
676 self._invalidate_cache()
678 def add_node(self, node: T) -> bool:
679 """
680 Add a node to the graph.
682 Args:
683 node (T): The node to add.
685 Returns:
686 bool: True if the node was added (was not already present), False otherwise.
688 Raises:
689 GraphCycleError: If the node is part of an existing cycle.
690 """
691 if node in self._nodes:
692 return False
694 # If the node is already part of a cycle (linked externally), adding it might be invalid
695 # if we want to enforce DAG.
696 if cycle := node.find_path(node):
697 raise GraphCycleError(
698 f"Node '{node.reference}' is part of an existing cycle.", cycle=cycle
699 )
701 self._check_node_consistency(node)
702 self._nodes.add(node)
703 node._register_observer(self)
704 logger.debug(f"Added node: {node.reference}")
706 self._invalidate_cache()
708 return True
710 def remove_edge(self, node: T, dependent: T) -> None:
711 """
712 Remove a directed edge from node to dependent.
714 Args:
715 node (T): The source node.
716 dependent (T): The target node.
717 """
718 if node in self._nodes and dependent in self._nodes:
719 node._remove_dependent(dependent)
720 dependent._remove_depends_on(node)
721 logger.debug(f"Removed edge: {node.reference} -> {dependent.reference}")
723 self._invalidate_cache()
725 def remove_node(self, node: T) -> None:
726 """
727 Remove a node and all its connected edges from the graph.
729 Args:
730 node (T): The node to remove.
731 """
732 if node in self._nodes:
733 # Remove from all nodes it depends on
734 for dep in list(node.depends_on):
735 dep._remove_dependent(node)
737 # Remove from all nodes that depend on it
738 for sub in list(node.dependents):
739 sub._remove_depends_on(node)
741 self._nodes.remove(node)
742 node._unregister_observer(self)
743 logger.debug(f"Removed node: {node.reference}")
745 self._invalidate_cache()
747 def ancestors(self, node: T) -> Iterator[T]:
748 """
749 Get an iterator for all nodes that the given node depends on, recursively.
751 Args:
752 node (T): The starting node.
754 Yields:
755 T: The next ancestor node.
756 """
757 return self._traverse(node, direction=Direction.UP, include_start=False)
759 def descendants(self, node: T) -> Iterator[T]:
760 """
761 Get an iterator for all nodes that depend on the given node, recursively.
763 Args:
764 node (T): The starting node.
766 Yields:
767 T: The next descendant node.
768 """
769 return self._traverse(node, direction=Direction.DOWN, include_start=False)
771 def bfs(
772 self,
773 start_node: T,
774 direction: Direction = Direction.DOWN,
775 limit_to_graph: bool = True,
776 ) -> Iterator[T]:
777 """
778 Perform a breadth-first search (BFS) starting from the given node.
780 Args:
781 start_node (T): The node to start from.
782 direction: Direction.UP for dependencies, Direction.DOWN for dependents.
783 limit_to_graph: If True, only return nodes that are members of this graph.
785 Yields:
786 T: Each reached node in breadth-first order.
787 """
788 from collections import deque
790 if limit_to_graph and start_node not in self._nodes:
791 return
793 visited: set[T] = {start_node}
794 queue: deque[T] = deque([start_node])
796 yield start_node
798 while queue:
799 current = queue.popleft()
800 neighbors = (
801 current.dependents
802 if direction == Direction.DOWN
803 else current.depends_on
804 )
805 for neighbor in neighbors:
806 if neighbor not in visited:
807 if limit_to_graph and neighbor not in self._nodes:
808 continue
809 visited.add(neighbor)
810 yield neighbor
811 queue.append(neighbor)
813 def dfs(
814 self,
815 start_node: T,
816 direction: Direction = Direction.DOWN,
817 limit_to_graph: bool = True,
818 ) -> Iterator[T]:
819 """
820 Perform a depth-first search (DFS) starting from the given node.
822 Args:
823 start_node (T): The node to start from.
824 direction: Direction.UP for dependencies, Direction.DOWN for dependents.
825 limit_to_graph: If True, only return nodes that are members of this graph.
827 Yields:
828 T: Each reached node in depth-first order.
829 """
830 return self._traverse(
831 start_node,
832 direction=direction,
833 limit_to_graph=limit_to_graph,
834 include_start=True,
835 )
    def _traverse(
        self,
        start_node: T,
        direction: Direction = Direction.DOWN,
        limit_to_graph: bool = True,
        include_start: bool = False,
    ) -> Iterator[T]:
        """
        Generic depth-first traversal utility.

        Args:
            start_node (T): Node to start from.
            direction: Direction.UP (depends_on) or Direction.DOWN (dependents).
            limit_to_graph: If True, only return nodes that are members of this graph.
            include_start: If True, yield the start_node first.

        Yields:
            T: Each reached node.
        """
        # The start node is always marked visited, even when it is not yielded.
        visited: set[T] = {start_node}

        if include_start:
            if not limit_to_graph or start_node in self._nodes:
                yield start_node

        def discover(current: T) -> Iterator[T]:
            # Recursive generator: pre-order DFS along the chosen direction.
            # NOTE(review): recursion depth is bounded by the longest simple
            # path; very deep graphs could hit the recursion limit.
            neighbors = (
                current.dependents
                if direction == Direction.DOWN
                else current.depends_on
            )
            for neighbor in neighbors:
                if neighbor not in visited:
                    if limit_to_graph and neighbor not in self._nodes:
                        # Non-members block traversal (not traversed through)
                        continue
                    visited.add(neighbor)
                    yield neighbor
                    yield from discover(neighbor)

        yield from discover(start_node)
878 @property
879 def sinks(self) -> list[T]:
880 """
881 Get all sink nodes (nodes with no dependents).
883 Returns:
884 list[T]: A list of sink nodes.
885 """
886 return [node for node in self._nodes if 0 == len(node.dependents)]
888 @property
889 def sources(self) -> list[T]:
890 """
891 Get all source nodes (nodes with no dependencies).
893 Returns:
894 list[T]: A list of source nodes.
895 """
896 return [node for node in self._nodes if 0 == len(node.depends_on)]
898 @staticmethod
899 def parse(
900 parser_fnc: Callable[..., Graph[Any]], source: str | Path, **kwargs: Any
901 ) -> Graph[Any]:
902 """
903 Parse a graph from a source using a parser function.
905 Args:
906 parser_fnc: The parser function to use (e.g., load_graph_json).
907 source: The source to parse (string or path).
908 **kwargs: Additional arguments passed to the parser function.
910 Returns:
911 Graph: A new Graph instance.
912 """
913 return parser_fnc(source, **kwargs)
915 @classmethod
916 def from_csv(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
917 """Create a Graph from a CSV edge list."""
918 from .parsers.csv import load_graph_csv
920 return cls.parse(load_graph_csv, source, **kwargs)
922 @classmethod
923 def from_graphml(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
924 """Create a Graph from a GraphML file or string."""
925 from .parsers.graphml import load_graph_graphml
927 return cls.parse(load_graph_graphml, source, **kwargs)
929 @classmethod
930 def from_json(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
931 """Create a Graph from a JSON file or string."""
932 from .parsers.json import load_graph_json
934 return cls.parse(load_graph_json, source, **kwargs)
936 @classmethod
937 def from_toml(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
938 """Create a Graph from a TOML file or string."""
939 from .parsers.toml import load_graph_toml
941 return cls.parse(load_graph_toml, source, **kwargs)
943 @classmethod
944 def from_yaml(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
945 """Create a Graph from a YAML file or string."""
946 from .parsers.yaml import load_graph_yaml
948 return cls.parse(load_graph_yaml, source, **kwargs)
950 def subgraph_filtered(self, fn: Callable[[T], bool]) -> Graph[T]:
951 """
952 Create a new subgraph containing only nodes that satisfy the predicate.
954 Args:
955 fn (Callable[[T], bool]): The predicate function.
957 Returns:
958 Graph[T]: A new Graph containing the filtered nodes.
959 """
960 logger.debug("Creating filtered subgraph.")
961 return Graph([node for node in self._nodes if fn(node)], discover=True)
963 def subgraph_tagged(self, tag: str) -> Graph[T]:
964 """
965 Create a new subgraph containing only nodes with the specified tag.
967 Args:
968 tag (str): The tag to filter by.
970 Returns:
971 Graph[T]: A new Graph containing the tagged nodes.
972 """
973 logger.debug(f"Creating subgraph for tag: {tag}")
974 return Graph(
975 [node for node in self._nodes if node.is_tagged(tag)], discover=True
976 )
978 def upstream_of(self, node: T) -> Graph[T]:
979 """
980 Create a new graph containing the given node and all its ancestors.
982 Args:
983 node (T): The node to start from.
985 Returns:
986 Graph[T]: A new Graph instance.
987 """
988 if node not in self._nodes:
989 raise KeyError(f"Node '{node.reference}' not found in graph.")
991 nodes = {node} | set(self.ancestors(node))
992 return Graph(nodes)
994 def downstream_of(self, node: T) -> Graph[T]:
995 """
996 Create a new graph containing the given node and all its descendants.
998 Args:
999 node (T): The node to start from.
1001 Returns:
1002 Graph[T]: A new Graph instance.
1003 """
1004 if node not in self._nodes:
1005 raise KeyError(f"Node '{node.reference}' not found in graph.")
1007 nodes = {node} | set(self.descendants(node))
1008 return Graph(nodes)
    def cpm_analysis(self) -> dict[T, dict[str, float]]:
        """
        Perform Critical Path Method (CPM) analysis on the graph.
        Assumes all nodes have a 'duration' attribute.

        Returns:
            dict[T, dict[str, float]]: A dictionary mapping each node to its CPM values:
                - 'ES': Earliest Start
                - 'EF': Earliest Finish
                - 'LS': Latest Start
                - 'LF': Latest Finish
                - 'slack': Total Slack (LF - EF)
        """
        logger.debug("Starting CPM analysis.")
        topo_order = self.topological_order()
        if not topo_order:
            # Empty graph: nothing to analyze
            return {}

        analysis: dict[T, dict[str, float]] = {
            node: {"ES": 0.0, "EF": 0.0, "LS": 0.0, "LF": 0.0, "slack": 0.0}
            for node in topo_order
        }

        # 1. Forward Pass (ES, EF): a node starts when its slowest dependency finishes
        for node in topo_order:
            max_ef = 0.0
            for dep in node.depends_on:
                # Dependencies outside this graph are ignored
                if dep in analysis:
                    max_ef = max(max_ef, analysis[dep]["EF"])
            analysis[node]["ES"] = max_ef
            analysis[node]["EF"] = max_ef + node.duration

        # 2. Backward Pass (LF, LS), walking in reverse topological order
        max_total_ef = max(analysis[node]["EF"] for node in topo_order)

        for node in reversed(topo_order):
            if not node.dependents or all(d not in analysis for d in node.dependents):
                # Terminal node (no in-graph dependents): pin LF to the project end
                min_ls = max_total_ef
            else:
                min_ls = min(
                    analysis[dep]["LS"] for dep in node.dependents if dep in analysis
                )

            analysis[node]["LF"] = min_ls
            analysis[node]["LS"] = min_ls - node.duration
            analysis[node]["slack"] = analysis[node]["LF"] - analysis[node]["EF"]

        return analysis
1059 def critical_path(self) -> list[T]:
1060 """
1061 Identify the nodes on the critical path (slack == 0).
1063 Returns:
1064 list[T]: A list of nodes on the critical path, in topological order.
1065 """
1066 analysis = self.cpm_analysis()
1067 return [
1068 node
1069 for node in self.topological_order()
1070 if abs(analysis[node]["slack"]) < 1e-9
1071 ]
1073 def longest_path(self) -> list[T]:
1074 """
1075 Find the longest path in the graph based on node durations.
1076 In a DAG, this is equivalent to the critical path chain.
1078 Returns:
1079 list[T]: The nodes forming the longest path.
1080 """
1081 # This is a bit more complex than just critical_path() if there are multiple critical paths.
1082 # But for dependency graphs, any path where slack == 0 is "a" longest path.
1083 # To get a specific chain:
1084 analysis = self.cpm_analysis()
1085 cp_nodes = {
1086 node for node, vals in analysis.items() if abs(vals["slack"]) < 1e-9
1087 }
1089 if not cp_nodes:
1090 return []
1092 # Find a source on critical path
1093 current = None
1094 for node in self.sources:
1095 if node in cp_nodes:
1096 current = node
1097 break
1099 if current is None:
1100 # Fallback: just take the first CP node in topo order
1101 current = sorted(
1102 list(cp_nodes), key=lambda n: self.topological_order().index(n)
1103 )[0]
1105 path = [current]
1106 while True:
1107 next_node = None
1108 # Find a dependent that is also on critical path and continues the timing
1109 for dep in current.dependents:
1110 if (
1111 dep in cp_nodes
1112 and abs(analysis[dep]["ES"] - analysis[current]["EF"]) < 1e-9
1113 ):
1114 next_node = dep
1115 break
1116 if next_node:
1117 path.append(next_node)
1118 current = next_node
1119 else:
1120 break
1121 return path
1123 def all_paths(self, source: T, target: T) -> list[list[T]]:
1124 """
1125 Find all possible paths between two nodes.
1127 Args:
1128 source (T): Starting node.
1129 target (T): Ending node.
1131 Returns:
1132 list[list[T]]: A list of all paths, where each path is a list of nodes.
1133 """
1135 def find_all_paths(current: T, goal: T, path: list[T]) -> list[list[T]]:
1136 path = path + [current]
1137 if current == goal:
1138 return [path]
1139 paths = []
1140 for neighbor in current.dependents:
1141 if neighbor in self._nodes:
1142 new_paths = find_all_paths(neighbor, goal, path)
1143 for p in new_paths:
1144 paths.append(p)
1145 return paths
1147 return find_all_paths(source, target, [])
1149 def diff(self, other: Graph[T]) -> dict[str, Any]:
1150 """
1151 Compare this graph with another graph.
1153 Returns:
1154 dict[str, Any]: A dictionary containing differences:
1155 - 'added_nodes': references of nodes in other but not in self.
1156 - 'removed_nodes': references of nodes in self but not in other.
1157 - 'modified_nodes': references of nodes in both but with different properties.
1158 - 'added_edges': (u, v) tuples of edges in other but not in self.
1159 - 'removed_edges': (u, v) tuples of edges in self but not in other.
1160 - 'modified_edges': (u, v) tuples of edges in both but with different attributes.
1161 """
1162 self_refs = {node.reference for node in self._nodes}
1163 other_refs = {node.reference for node in other._nodes}
1165 added_nodes = other_refs - self_refs
1166 removed_nodes = self_refs - other_refs
1168 modified_nodes = set()
1169 for ref in self_refs & other_refs:
1170 n1 = self[ref]
1171 n2 = other[ref]
1172 if (
1173 n1.tags != n2.tags
1174 or n1.duration != n2.duration
1175 or n1.status != n2.status
1176 ):
1177 modified_nodes.add(ref)
1179 def get_edges(g: Graph[T]):
1180 edges = {}
1181 for u in g._nodes:
1182 for v in u.dependents:
1183 if v in g._nodes:
1184 edges[(u.reference, v.reference)] = u.edge_attributes(v)
1185 return edges
1187 self_edges = get_edges(self)
1188 other_edges = get_edges(other)
1190 self_edge_set = set(self_edges.keys())
1191 other_edge_set = set(other_edges.keys())
1193 added_edges = other_edge_set - self_edge_set
1194 removed_edges = self_edge_set - other_edge_set
1195 modified_edges = set()
1197 for edge in self_edge_set & other_edge_set:
1198 if self_edges[edge] != other_edges[edge]:
1199 modified_edges.add(edge)
1201 return {
1202 "added_nodes": added_nodes,
1203 "removed_nodes": removed_nodes,
1204 "modified_nodes": modified_nodes,
1205 "added_edges": added_edges,
1206 "removed_edges": removed_edges,
1207 "modified_edges": modified_edges,
1208 }
1210 def topological_order(self) -> list[T]:
1211 """
1212 Get the nodes in topological order.
1213 Only nodes that are members of this graph are included.
1215 Returns:
1216 list[T]: A list of member nodes sorted topologically.
1217 """
1218 if self._topological_order is None:
1219 logger.debug("Calculating topological order.")
1220 sorter = TopologicalSorter({node: node.depends_on for node in self._nodes})
1221 # Filter the static order to only include nodes that are in this graph
1222 self._topological_order = [
1223 node for node in sorter.static_order() if node in self._nodes
1224 ]
1226 return self._topological_order
1228 def topological_order_filtered(self, fn: Callable[[T], bool]) -> list[T]:
1229 """
1230 Get a filtered list of nodes in topological order.
1232 Args:
1233 fn (Callable[[T], bool]): The predicate function.
1235 Returns:
1236 list[T]: Filtered topologically sorted nodes.
1237 """
1238 return [node for node in self.topological_order() if fn(node)]
1240 def topological_order_tagged(self, tag: str) -> list[T]:
1241 """
1242 Get a list of nodes with a specific tag in topological order.
1244 Args:
1245 tag (str): The tag to filter by.
1247 Returns:
1248 list[T]: Tagged topologically sorted nodes.
1249 """
1250 return [node for node in self.topological_order() if node.is_tagged(tag)]
1252 def to_networkx(self):
1253 """
1254 Convert this graph to a networkx.DiGraph.
1255 Requires 'networkx' to be installed.
1257 Returns:
1258 networkx.DiGraph: The converted directed graph.
1259 """
1260 from .views.networkx import to_networkx
1262 return to_networkx(self)
1264 def transitive_reduction(self) -> Graph[T]:
1265 """
1266 Compute the transitive reduction of this DAG.
1267 A transitive reduction of a directed acyclic graph G is a graph G' with the same nodes
1268 and the same reachability as G, but with as few edges as possible.
1270 Returns:
1271 Graph[T]: A new Graph instance containing the same nodes (cloned) but with redundant edges removed.
1272 """
1273 import copy
1275 logger.debug("Calculating transitive reduction.")
1277 # 1. Clone nodes without edges to avoid modifying the original graph.
1278 node_map: dict[T, T] = {}
1279 for node in self._nodes:
1280 new_node = copy.copy(node)
1281 # Reset internal edge tracking
1282 new_node._dependents = {}
1283 new_node._depends_on = {}
1284 # Manually clone tags to avoid shared state
1285 new_node._tags = set(node.tags)
1286 node_map[node] = new_node
1288 # 2. Identify redundant edges.
1289 # An edge (u, v) is redundant if there exists a path from u to v of length > 1.
1290 redundant_edges: set[tuple[T, T]] = set()
1291 for u in self._nodes:
1292 for v in u.dependents:
1293 # Check if v is reachable from u through any other neighbor w.
1294 if any(w.find_path(v) for w in u.dependents if w != v):
1295 redundant_edges.add((u, v))
1297 # 3. Construct the new graph with non-redundant edges.
1298 new_graph = Graph(set(node_map.values()))
1299 for u in self._nodes:
1300 for v in u.dependents:
1301 if (u, v) not in redundant_edges:
1302 # Preserve edge attributes
1303 attrs = u.edge_attributes(v)
1304 new_graph.add_edge(node_map[u], node_map[v], **attrs)
1306 logger.info(
1307 f"Transitive reduction complete. Removed {len(redundant_edges)} redundant edges."
1308 )
1309 return new_graph
1311 def render(
1312 self,
1313 view_fnc: Callable[..., str],
1314 transitive_reduction: bool = False,
1315 **kwargs: Any,
1316 ) -> str:
1317 """
1318 Render the graph using a view function.
1320 Args:
1321 view_fnc: The view function to use (e.g., create_topology_mermaid_mmd).
1322 transitive_reduction: If True, render the transitive reduction of the graph.
1323 **kwargs: Additional arguments passed to the view function.
1325 Returns:
1326 str: The rendered representation.
1327 """
1328 target = self.transitive_reduction() if transitive_reduction else self
1329 return view_fnc(target, **kwargs)
1331 def export(
1332 self,
1333 export_fnc: Callable[..., None],
1334 output: Path | str,
1335 transitive_reduction: bool = False,
1336 embed_checksum: bool = False,
1337 **kwargs: Any,
1338 ) -> None:
1339 """
1340 Export the graph using an export function.
1342 Args:
1343 export_fnc: The export function to use (e.g., export_topology_graphviz_svg).
1344 output: The output file path.
1345 transitive_reduction: If True, export the transitive reduction of the graph.
1346 embed_checksum: If True, embed the graph's checksum as a comment at the top.
1347 **kwargs: Additional arguments passed to the export function.
1348 """
1349 from pathlib import Path
1351 from .registry import CREATOR_MAP
1352 from .views.utils import wrap_with_checksum
1354 p = Path(output)
1355 target = self.transitive_reduction() if transitive_reduction else self
1357 if not embed_checksum:
1358 return export_fnc(target, p, **kwargs)
1360 # To embed checksum, we need to capture the output string first.
1361 create_fnc = CREATOR_MAP.get(export_fnc)
1363 if not create_fnc:
1364 # Fallback: export normally if we can't find a string-generating version
1365 export_name = getattr(export_fnc, "__name__", str(export_fnc))
1366 logger.warning(
1367 f"Could not find string-generating version of {export_name}. Exporting normally without checksum embedding."
1368 )
1369 return export_fnc(target, p, **kwargs)
1371 content = create_fnc(target, **kwargs)
1372 checksum = target.checksum()
1373 wrapped = wrap_with_checksum(content, checksum, p.suffix)
1375 with open(p, "w+") as f:
1376 f.write(wrapped)