# Source code for graphable.graph

from __future__ import annotations

from graphlib import CycleError, TopologicalSorter
from hashlib import blake2b
from logging import getLogger
from pathlib import Path
from typing import Any, Callable, Iterator

from .enums import Direction, Engine
from .errors import GraphConsistencyError, GraphCycleError
from .graphable import Graphable

logger = getLogger(__name__)


[docs] class Graph[T: Graphable[Any]]: """ Represents a graph of Graphable nodes. """
[docs] def __init__(self, initial: set[T] | list[T] | None = None, discover: bool = False): """ Initialize a Graph. Args: initial (set[T] | list[T] | None): An optional set of initial nodes. discover (bool): If True, automatically expand the graph to include all reachable ancestors and descendants of the initial nodes. Raises: GraphCycleError: If the initial set of nodes contains a cycle. """ self._nodes: set[T] = set() self._topological_order: list[T] | None = None self._parallel_topological_order: list[set[T]] | None = None self._checksum: str | None = None if initial: for node in initial: self.add_node(node) if discover: self.discover() self.check_consistency() self.check_cycles()
[docs] def clone(self, include_edges: bool = False) -> Graph[T]: """ Create a copy of this graph. Args: include_edges: If True, the new graph will have the same edges as this one. If False, the new graph will contain copies of all nodes but with no edges between them. Returns: Graph[T]: A new Graph instance. """ import copy logger.debug(f"Cloning graph (include_edges={include_edges}).") node_map: dict[T, T] = {} for node in self._nodes: new_node = copy.copy(node) # Reset internal edge tracking new_node._dependents = {} new_node._depends_on = {} # Manually clone tags to avoid shared state new_node._tags = set(node.tags) node_map[node] = new_node new_graph = Graph(set(node_map.values())) if include_edges: for u in self._nodes: for v, attrs in self.neighbors(u, Direction.DOWN): new_graph.add_edge(node_map[u], node_map[v], **attrs) return new_graph
[docs] def neighbors( self, node: T, direction: Direction = Direction.DOWN ) -> Iterator[tuple[T, dict[str, Any]]]: """ Iterate over neighbors within this graph. Args: node (T): The source node. direction: Direction.DOWN for dependents, Direction.UP for dependencies. Yields: tuple[T, dict[str, Any]]: A (neighbor_node, edge_attributes) tuple. """ neighbors = node.dependents if direction == Direction.DOWN else node.depends_on for neighbor in neighbors: if neighbor in self._nodes: attrs = ( node.edge_attributes(neighbor) if direction == Direction.DOWN else neighbor.edge_attributes(node) ) yield neighbor, attrs
[docs] def internal_dependents(self, node: T) -> Iterator[tuple[T, dict[str, Any]]]: """Alias for neighbors(node, Direction.DOWN).""" return self.neighbors(node, Direction.DOWN)
[docs] def internal_depends_on(self, node: T) -> Iterator[tuple[T, dict[str, Any]]]: """Alias for neighbors(node, Direction.UP).""" return self.neighbors(node, Direction.UP)
[docs] def discover(self) -> None: """ Traverse the entire connectivity of the current nodes and add any reachable ancestors or descendants that are not yet members. """ logger.debug(f"Discovering reachable nodes from {len(self._nodes)} base nodes.") new_nodes: set[T] = set() for node in list(self._nodes): new_nodes.update( self._traverse(node, direction=Direction.UP, limit_to_graph=False) ) new_nodes.update( self._traverse(node, direction=Direction.DOWN, limit_to_graph=False) ) for node in new_nodes: self.add_node(node)
def _invalidate_cache(self) -> None: """Clear all cached calculations for this graph.""" logger.debug("Invalidating graph cache.") self._topological_order = None self._parallel_topological_order = None self._checksum = None def __contains__(self, item: object) -> bool: """ Check if a node or its reference is in the graph. Args: item (object): Either a Graphable node or a reference object. Returns: bool: True if present, False otherwise. """ if isinstance(item, Graphable): return item in self._nodes return any(node.reference == item for node in self._nodes) def __getitem__(self, reference: Any) -> T: """ Get a node by its reference. Args: reference (Any): The reference object to search for. Returns: T: The Graphable node. Raises: KeyError: If no node with the given reference exists. """ for node in self._nodes: if node.reference == reference: return node raise KeyError(f"No node found with reference: {reference}") def __iter__(self): """ Iterate over nodes in topological order. """ return iter(self.topological_order()) def __len__(self) -> int: """ Get the number of nodes in the graph. """ return len(self._nodes)
[docs] def is_equal_to(self, other: object) -> bool: """ Check if this graph is equal to another graph. Equality is defined as having the same checksum (structural and metadata-wise). Args: other: The other object to compare with. Returns: bool: True if equal, False otherwise. """ if not isinstance(other, Graph): return False return self.checksum() == other.checksum()
[docs] def checksum(self) -> str: """ Calculate a deterministic BLAKE2b checksum of the graph. The checksum accounts for all member nodes (references, tags, duration, status) and edges (including attributes) between them. External nodes are excluded. Returns: str: The hexadecimal digest of the graph. """ if self._checksum is not None: return self._checksum # 1. Sort nodes by reference to ensure deterministic iteration sorted_nodes = sorted(self._nodes, key=lambda n: str(n.reference)) hasher = blake2b() for node in sorted_nodes: # 2. Add node reference, duration, and status hasher.update(str(node.reference).encode()) hasher.update(f":duration:{node.duration}".encode()) hasher.update(f":status:{node.status}".encode()) # 3. Add sorted tags for tag in sorted(node.tags): hasher.update(f":tag:{tag}".encode()) # 4. Add sorted dependents (edges) with attributes - Only those in the graph internal_dependents = sorted( [d for d in node.dependents if d in self._nodes], key=lambda n: str(n.reference), ) for dep in internal_dependents: hasher.update(f":edge:{dep.reference}".encode()) # Add edge attributes deterministically attrs = node.edge_attributes(dep) for key in sorted(attrs.keys()): hasher.update(f":attr:{key}:{attrs[key]}".encode()) self._checksum = hasher.hexdigest() return self._checksum
[docs] def validate_checksum(self, expected: str) -> bool: """ Validate the graph against an expected checksum. Args: expected (str): The expected BLAKE2b hexadecimal digest. Returns: bool: True if the checksums match, False otherwise. """ return self.checksum() == expected
[docs] def write_checksum(self, path: Path | str) -> None: """ Write the graph's current checksum to a file. Args: path: Path to the output checksum file. """ p = Path(path) digest = self.checksum() logger.info(f"Writing checksum to: {p}") with open(p, "w+") as f: f.write(digest)
[docs] @staticmethod def read_checksum(path: Path | str) -> str: """ Read a checksum from a file. Args: path: Path to the checksum file. Returns: str: The checksum string. """ p = Path(path) logger.debug(f"Reading checksum from: {p}") with open(p, "r") as f: return f.read().strip()
[docs] @classmethod def read(cls, path: Path | str, **kwargs: Any) -> Graph[Any]: """Read a graph from a file, automatically detecting the format.""" from .parsers.utils import extract_checksum from .registry import PARSERS p = Path(path) ext = p.suffix.lower() parser = PARSERS.get(ext) if not parser: raise ValueError(f"Unsupported extension for reading: {ext}") g = parser(p, **kwargs) if embedded := extract_checksum(p): if not g.validate_checksum(embedded): raise ValueError(f"Checksum validation failed for {p}") return g
[docs] def write( self, path: Path | str, transitive_reduction: bool = False, embed_checksum: bool = False, engine: Engine | str | None = None, **kwargs: Any, ) -> None: """ Write the graph to a file, automatically detecting the format. Args: path: Path to the output file. transitive_reduction: If True, perform transitive reduction before writing. embed_checksum: If True, embed a BLAKE2b checksum in the output. engine: The rendering engine to use for images (.svg, .png). If None, it will be auto-detected. **kwargs: Additional arguments passed to the specific exporter. """ from .registry import EXPORTERS p = Path(path) ext = p.suffix.lower() # Handle images specifically to allow engine selection/auto-detection if ext in (".svg", ".png"): from .views.utils import get_image_exporter exporter = get_image_exporter(engine) else: exporter = EXPORTERS.get(ext) if not exporter: raise ValueError(f"Unsupported extension: {ext}") return self.export(exporter, p, transitive_reduction, embed_checksum, **kwargs)
[docs] def parallelized_topological_order(self) -> list[set[T]]: """ Get the nodes in topological order, grouped into sets that can be processed in parallel. Only nodes that are members of this graph are included. Returns: list[set[T]]: A list of sets of member nodes that have no unmet dependencies. """ if self._parallel_topological_order is None: logger.debug("Calculating parallel topological order.") self._parallel_topological_order = [] sorter = TopologicalSorter({node: node.depends_on for node in self._nodes}) sorter.prepare() while sorter.is_active(): ready = sorter.get_ready() if not ready: break # Filter to only include nodes that are actually in this graph filtered_ready = {node for node in ready if node in self._nodes} if filtered_ready: self._parallel_topological_order.append(filtered_ready) sorter.done(*ready) return self._parallel_topological_order
[docs] def subgraph_between(self, source: T, target: T) -> Graph[T]: """ Create a new graph containing all nodes and edges on all paths between source and target. Args: source (T): The starting node. target (T): The ending node. Returns: Graph[T]: A new Graph instance. """ if source not in self._nodes or target not in self._nodes: raise KeyError("Both source and target must be in the graph.") # Nodes between U and V are nodes that are descendants of U AND ancestors of V descendants = {source} | set(self.descendants(source)) ancestors = {target} | set(self.ancestors(target)) between = descendants & ancestors return Graph(between)
[docs] def diff_graph(self, other: Graph[T]) -> Graph[T]: """ Create a visualization-friendly diff graph. - Nodes in both: grey/default - Added nodes: green - Removed nodes: red - Modified nodes/edges: yellow/orange Returns: Graph[T]: A merged graph with diff metadata. """ import copy merged_nodes_map: dict[Any, T] = {} diff_info = self.diff(other) def get_or_create(node: T, status: str) -> T: ref = node.reference if ref not in merged_nodes_map: new_node = copy.copy(node) new_node._dependents = {} new_node._depends_on = {} new_node.add_tag(f"diff:{status}") # Add visual hints color = {"added": "green", "removed": "red", "modified": "orange"}.get( status, "grey" ) new_node.add_tag(f"color:{color}") merged_nodes_map[ref] = new_node return merged_nodes_map[ref] # Add all nodes from both for node in self._nodes: status = ( "removed" if node.reference in diff_info["removed_nodes"] else "unchanged" ) if node.reference in diff_info["modified_nodes"]: status = "modified" get_or_create(node, status) for node in other._nodes: status = ( "added" if node.reference in diff_info["added_nodes"] else "unchanged" ) if node.reference in diff_info["modified_nodes"]: status = "modified" get_or_create(node, status) new_graph = Graph(set(merged_nodes_map.values())) # Add edges from self (original) for u in self._nodes: for v in u.dependents: if v not in self._nodes: continue edge = (u.reference, v.reference) if edge in diff_info["removed_edges"]: new_graph.add_edge( merged_nodes_map[u.reference], merged_nodes_map[v.reference], diff_status="removed", color="red", ) elif edge in diff_info["modified_edges"]: new_graph.add_edge( merged_nodes_map[u.reference], merged_nodes_map[v.reference], **u.edge_attributes(v), diff_status="modified", color="orange", ) else: new_graph.add_edge( merged_nodes_map[u.reference], merged_nodes_map[v.reference], **u.edge_attributes(v), ) # Add edges from other (new) for u in other._nodes: for v in u.dependents: if v not in other._nodes: continue edge = 
(u.reference, v.reference) if edge in diff_info["added_edges"]: new_graph.add_edge( merged_nodes_map[u.reference], merged_nodes_map[v.reference], **u.edge_attributes(v), diff_status="added", color="green", ) return new_graph
[docs] def transitive_closure(self) -> Graph[T]: """ Compute the transitive closure of this graph. An edge (u, v) exists in the transitive closure if there is a path from u to v. Returns: Graph[T]: A new Graph instance representing the transitive closure. """ import copy logger.debug("Calculating transitive closure.") node_map = {node: copy.copy(node) for node in self._nodes} for n in node_map.values(): n._dependents = {} n._depends_on = {} new_graph = Graph(set(node_map.values())) for u in self._nodes: for v in self.descendants(u): new_graph.add_edge(node_map[u], node_map[v]) return new_graph
[docs] def suggest_cycle_breaks(self) -> list[tuple[T, T]]: """ Identify a minimal set of edges to remove to make the graph a Directed Acyclic Graph (DAG). Uses a greedy heuristic. Returns: list[tuple[T, T]]: A list of (source, target) tuples representing suggested edges to remove. """ logger.debug("Suggesting cycle breaks.") # Simple heuristic: # 1. Take all nodes. # 2. Try to order them such that we maximize forward edges. # A simple way is to use the order they were added or any arbitrary order # and see which edges go 'backwards'. nodes = list(self._nodes) # We can try to be slightly smarter by using a DFS and finding back-edges back_edges = [] visited = set() stack = set() def dfs(u): visited.add(u) stack.add(u) for v in u.dependents: if v not in self._nodes: continue if v in stack: back_edges.append((u, v)) elif v not in visited: dfs(v) stack.remove(u) for node in nodes: if node not in visited: dfs(node) return back_edges
[docs] def parallelized_topological_order_filtered( self, fn: Callable[[T], bool] ) -> list[set[T]]: """ Get a filtered list of nodes in parallelized topological order. Args: fn (Callable[[T], bool]): The predicate function. Returns: list[set[T]]: Filtered sets of nodes for parallel processing. """ result = [] for group in self.parallelized_topological_order(): filtered_group = {node for node in group if fn(node)} if filtered_group: result.append(filtered_group) return result
[docs] def parallelized_topological_order_tagged(self, tag: str) -> list[set[T]]: """ Get a list of nodes with a specific tag in parallelized topological order. Args: tag (str): The tag to filter by. Returns: list[set[T]]: Tagged sets of nodes for parallel processing. """ return self.parallelized_topological_order_filtered(lambda n: n.is_tagged(tag))
def __eq__(self, other: object) -> bool: """ Compare two graphs for equality. """ return self.is_equal_to(other) def __hash__(self) -> int: """ Graphs are hashable by identity to allow them to be used in WeakSets (e.g., as observers of Graphable nodes). """ return id(self)
[docs] def check_cycles(self) -> None: """ Check for cycles in the graph. Raises: GraphCycleError: If a cycle is detected. """ try: sorter = TopologicalSorter({node: node.depends_on for node in self._nodes}) sorter.prepare() except CycleError as e: # graphlib.CycleError args: (message, cycle_tuple) cycle = list(e.args[1]) if len(e.args) > 1 else None raise GraphCycleError(f"Cycle detected in graph: {e}", cycle=cycle) from e
[docs] def check_consistency(self) -> None: """ Check for consistency between depends_on and dependents for all nodes in the graph. Raises: GraphConsistencyError: If an inconsistency is detected. """ for node in self._nodes: self._check_node_consistency(node)
def _check_node_consistency(self, node: T) -> None: """ Check for consistency between depends_on and dependents for a single node. Args: node (T): The node to check. Raises: GraphConsistencyError: If an inconsistency is detected. """ # Check dependencies: if node depends on X, X must have node as dependent for dep in node.depends_on: if node not in dep.dependents: raise GraphConsistencyError( f"Inconsistency: Node '{node.reference}' depends on '{dep.reference}', " f"but '{dep.reference}' does not list '{node.reference}' as a dependent." ) # Check dependents: if node has dependent Y, Y must depend on node for sub in node.dependents: if node not in sub.depends_on: raise GraphConsistencyError( f"Inconsistency: Node '{node.reference}' has dependent '{sub.reference}', " f"but '{sub.reference}' does not depend on '{node.reference}'." )
[docs] def add_edge(self, node: T, dependent: T, **attributes: Any) -> None: """ Add a directed edge from node to dependent. Also adds the nodes to the graph if they are not already present. Args: node (T): The source node (dependency). dependent (T): The target node (dependent). **attributes: Edge attributes (e.g., weight, label). Raises: GraphCycleError: If adding the edge would create a cycle. """ if node == dependent: raise GraphCycleError( f"Self-loop detected: node '{node.reference}' cannot depend on itself.", cycle=[node, node], ) # Check if adding this edge creates a cycle. # A cycle is created if there is already a path from 'dependent' to 'node'. if path := dependent.find_path(node): cycle = path + [dependent] raise GraphCycleError( f"Adding edge '{node.reference}' -> '{dependent.reference}' would create a cycle.", cycle=cycle, ) self.add_node(node) self.add_node(dependent) node._add_dependent(dependent, **attributes) dependent._add_depends_on(node, **attributes) logger.debug( f"Added edge: {node.reference} -> {dependent.reference} with attributes {attributes}" ) # Invalidate cache self._invalidate_cache()
[docs] def add_node(self, node: T) -> bool: """ Add a node to the graph. Args: node (T): The node to add. Returns: bool: True if the node was added (was not already present), False otherwise. Raises: GraphCycleError: If the node is part of an existing cycle. """ if node in self._nodes: return False # If the node is already part of a cycle (linked externally), adding it might be invalid # if we want to enforce DAG. if cycle := node.find_path(node): raise GraphCycleError( f"Node '{node.reference}' is part of an existing cycle.", cycle=cycle ) self._check_node_consistency(node) self._nodes.add(node) node._register_observer(self) logger.debug(f"Added node: {node.reference}") self._invalidate_cache() return True
[docs] def remove_edge(self, node: T, dependent: T) -> None: """ Remove a directed edge from node to dependent. Args: node (T): The source node. dependent (T): The target node. """ if node in self._nodes and dependent in self._nodes: node._remove_dependent(dependent) dependent._remove_depends_on(node) logger.debug(f"Removed edge: {node.reference} -> {dependent.reference}") self._invalidate_cache()
[docs] def remove_node(self, node: T) -> None: """ Remove a node and all its connected edges from the graph. Args: node (T): The node to remove. """ if node in self._nodes: # Remove from all nodes it depends on for dep in list(node.depends_on): dep._remove_dependent(node) # Remove from all nodes that depend on it for sub in list(node.dependents): sub._remove_depends_on(node) self._nodes.remove(node) node._unregister_observer(self) logger.debug(f"Removed node: {node.reference}") self._invalidate_cache()
[docs] def ancestors(self, node: T) -> Iterator[T]: """ Get an iterator for all nodes that the given node depends on, recursively. Args: node (T): The starting node. Yields: T: The next ancestor node. """ return self._traverse(node, direction=Direction.UP, include_start=False)
[docs] def descendants(self, node: T) -> Iterator[T]: """ Get an iterator for all nodes that depend on the given node, recursively. Args: node (T): The starting node. Yields: T: The next descendant node. """ return self._traverse(node, direction=Direction.DOWN, include_start=False)
[docs] def bfs( self, start_node: T, direction: Direction = Direction.DOWN, limit_to_graph: bool = True, ) -> Iterator[T]: """ Perform a breadth-first search (BFS) starting from the given node. Args: start_node (T): The node to start from. direction: Direction.UP for dependencies, Direction.DOWN for dependents. limit_to_graph: If True, only return nodes that are members of this graph. Yields: T: Each reached node in breadth-first order. """ from collections import deque if limit_to_graph and start_node not in self._nodes: return visited: set[T] = {start_node} queue: deque[T] = deque([start_node]) yield start_node while queue: current = queue.popleft() neighbors = ( current.dependents if direction == Direction.DOWN else current.depends_on ) for neighbor in neighbors: if neighbor not in visited: if limit_to_graph and neighbor not in self._nodes: continue visited.add(neighbor) yield neighbor queue.append(neighbor)
[docs] def dfs( self, start_node: T, direction: Direction = Direction.DOWN, limit_to_graph: bool = True, ) -> Iterator[T]: """ Perform a depth-first search (DFS) starting from the given node. Args: start_node (T): The node to start from. direction: Direction.UP for dependencies, Direction.DOWN for dependents. limit_to_graph: If True, only return nodes that are members of this graph. Yields: T: Each reached node in depth-first order. """ return self._traverse( start_node, direction=direction, limit_to_graph=limit_to_graph, include_start=True, )
def _traverse( self, start_node: T, direction: Direction = Direction.DOWN, limit_to_graph: bool = True, include_start: bool = False, ) -> Iterator[T]: """ Generic depth-first traversal utility. Args: start_node (T): Node to start from. direction: Direction.UP (depends_on) or Direction.DOWN (dependents). limit_to_graph: If True, only return nodes that are members of this graph. include_start: If True, yield the start_node first. Yields: T: Each reached node. """ visited: set[T] = {start_node} if include_start: if not limit_to_graph or start_node in self._nodes: yield start_node def discover(current: T) -> Iterator[T]: neighbors = ( current.dependents if direction == Direction.DOWN else current.depends_on ) for neighbor in neighbors: if neighbor not in visited: if limit_to_graph and neighbor not in self._nodes: continue visited.add(neighbor) yield neighbor yield from discover(neighbor) yield from discover(start_node) @property def sinks(self) -> list[T]: """ Get all sink nodes (nodes with no dependents). Returns: list[T]: A list of sink nodes. """ return [node for node in self._nodes if 0 == len(node.dependents)] @property def sources(self) -> list[T]: """ Get all source nodes (nodes with no dependencies). Returns: list[T]: A list of source nodes. """ return [node for node in self._nodes if 0 == len(node.depends_on)]
[docs] @staticmethod def parse( parser_fnc: Callable[..., Graph[Any]], source: str | Path, **kwargs: Any ) -> Graph[Any]: """ Parse a graph from a source using a parser function. Args: parser_fnc: The parser function to use (e.g., load_graph_json). source: The source to parse (string or path). **kwargs: Additional arguments passed to the parser function. Returns: Graph: A new Graph instance. """ return parser_fnc(source, **kwargs)
[docs] @classmethod def from_csv(cls, source: str | Path, **kwargs: Any) -> Graph[Any]: """Create a Graph from a CSV edge list.""" from .parsers.csv import load_graph_csv return cls.parse(load_graph_csv, source, **kwargs)
[docs] @classmethod def from_graphml(cls, source: str | Path, **kwargs: Any) -> Graph[Any]: """Create a Graph from a GraphML file or string.""" from .parsers.graphml import load_graph_graphml return cls.parse(load_graph_graphml, source, **kwargs)
[docs] @classmethod def from_json(cls, source: str | Path, **kwargs: Any) -> Graph[Any]: """Create a Graph from a JSON file or string.""" from .parsers.json import load_graph_json return cls.parse(load_graph_json, source, **kwargs)
[docs] @classmethod def from_toml(cls, source: str | Path, **kwargs: Any) -> Graph[Any]: """Create a Graph from a TOML file or string.""" from .parsers.toml import load_graph_toml return cls.parse(load_graph_toml, source, **kwargs)
[docs] @classmethod def from_yaml(cls, source: str | Path, **kwargs: Any) -> Graph[Any]: """Create a Graph from a YAML file or string.""" from .parsers.yaml import load_graph_yaml return cls.parse(load_graph_yaml, source, **kwargs)
[docs] def subgraph_filtered(self, fn: Callable[[T], bool]) -> Graph[T]: """ Create a new subgraph containing only nodes that satisfy the predicate. Args: fn (Callable[[T], bool]): The predicate function. Returns: Graph[T]: A new Graph containing the filtered nodes. """ logger.debug("Creating filtered subgraph.") return Graph([node for node in self._nodes if fn(node)], discover=True)
[docs] def subgraph_tagged(self, tag: str) -> Graph[T]: """ Create a new subgraph containing only nodes with the specified tag. Args: tag (str): The tag to filter by. Returns: Graph[T]: A new Graph containing the tagged nodes. """ logger.debug(f"Creating subgraph for tag: {tag}") return Graph( [node for node in self._nodes if node.is_tagged(tag)], discover=True )
[docs] def upstream_of(self, node: T) -> Graph[T]: """ Create a new graph containing the given node and all its ancestors. Args: node (T): The node to start from. Returns: Graph[T]: A new Graph instance. """ if node not in self._nodes: raise KeyError(f"Node '{node.reference}' not found in graph.") nodes = {node} | set(self.ancestors(node)) return Graph(nodes)
[docs] def downstream_of(self, node: T) -> Graph[T]: """ Create a new graph containing the given node and all its descendants. Args: node (T): The node to start from. Returns: Graph[T]: A new Graph instance. """ if node not in self._nodes: raise KeyError(f"Node '{node.reference}' not found in graph.") nodes = {node} | set(self.descendants(node)) return Graph(nodes)
[docs] def cpm_analysis(self) -> dict[T, dict[str, float]]: """ Perform Critical Path Method (CPM) analysis on the graph. Assumes all nodes have a 'duration' attribute. Returns: dict[T, dict[str, float]]: A dictionary mapping each node to its CPM values: - 'ES': Earliest Start - 'EF': Earliest Finish - 'LS': Latest Start - 'LF': Latest Finish - 'slack': Total Slack (LF - EF) """ logger.debug("Starting CPM analysis.") topo_order = self.topological_order() if not topo_order: return {} analysis: dict[T, dict[str, float]] = { node: {"ES": 0.0, "EF": 0.0, "LS": 0.0, "LF": 0.0, "slack": 0.0} for node in topo_order } # 1. Forward Pass (ES, EF) for node in topo_order: max_ef = 0.0 for dep in node.depends_on: if dep in analysis: max_ef = max(max_ef, analysis[dep]["EF"]) analysis[node]["ES"] = max_ef analysis[node]["EF"] = max_ef + node.duration # 2. Backward Pass (LF, LS) max_total_ef = max(analysis[node]["EF"] for node in topo_order) for node in reversed(topo_order): if not node.dependents or all(d not in analysis for d in node.dependents): min_ls = max_total_ef else: min_ls = min( analysis[dep]["LS"] for dep in node.dependents if dep in analysis ) analysis[node]["LF"] = min_ls analysis[node]["LS"] = min_ls - node.duration analysis[node]["slack"] = analysis[node]["LF"] - analysis[node]["EF"] return analysis
[docs] def critical_path(self) -> list[T]: """ Identify the nodes on the critical path (slack == 0). Returns: list[T]: A list of nodes on the critical path, in topological order. """ analysis = self.cpm_analysis() return [ node for node in self.topological_order() if abs(analysis[node]["slack"]) < 1e-9 ]
[docs] def longest_path(self) -> list[T]: """ Find the longest path in the graph based on node durations. In a DAG, this is equivalent to the critical path chain. Returns: list[T]: The nodes forming the longest path. """ # This is a bit more complex than just critical_path() if there are multiple critical paths. # But for dependency graphs, any path where slack == 0 is "a" longest path. # To get a specific chain: analysis = self.cpm_analysis() cp_nodes = { node for node, vals in analysis.items() if abs(vals["slack"]) < 1e-9 } if not cp_nodes: return [] # Find a source on critical path current = None for node in self.sources: if node in cp_nodes: current = node break if current is None: # Fallback: just take the first CP node in topo order current = sorted( list(cp_nodes), key=lambda n: self.topological_order().index(n) )[0] path = [current] while True: next_node = None # Find a dependent that is also on critical path and continues the timing for dep in current.dependents: if ( dep in cp_nodes and abs(analysis[dep]["ES"] - analysis[current]["EF"]) < 1e-9 ): next_node = dep break if next_node: path.append(next_node) current = next_node else: break return path
[docs] def all_paths(self, source: T, target: T) -> list[list[T]]: """ Find all possible paths between two nodes. Args: source (T): Starting node. target (T): Ending node. Returns: list[list[T]]: A list of all paths, where each path is a list of nodes. """ def find_all_paths(current: T, goal: T, path: list[T]) -> list[list[T]]: path = path + [current] if current == goal: return [path] paths = [] for neighbor in current.dependents: if neighbor in self._nodes: new_paths = find_all_paths(neighbor, goal, path) for p in new_paths: paths.append(p) return paths return find_all_paths(source, target, [])
[docs] def diff(self, other: Graph[T]) -> dict[str, Any]: """ Compare this graph with another graph. Returns: dict[str, Any]: A dictionary containing differences: - 'added_nodes': references of nodes in other but not in self. - 'removed_nodes': references of nodes in self but not in other. - 'modified_nodes': references of nodes in both but with different properties. - 'added_edges': (u, v) tuples of edges in other but not in self. - 'removed_edges': (u, v) tuples of edges in self but not in other. - 'modified_edges': (u, v) tuples of edges in both but with different attributes. """ self_refs = {node.reference for node in self._nodes} other_refs = {node.reference for node in other._nodes} added_nodes = other_refs - self_refs removed_nodes = self_refs - other_refs modified_nodes = set() for ref in self_refs & other_refs: n1 = self[ref] n2 = other[ref] if ( n1.tags != n2.tags or n1.duration != n2.duration or n1.status != n2.status ): modified_nodes.add(ref) def get_edges(g: Graph[T]): edges = {} for u in g._nodes: for v in u.dependents: if v in g._nodes: edges[(u.reference, v.reference)] = u.edge_attributes(v) return edges self_edges = get_edges(self) other_edges = get_edges(other) self_edge_set = set(self_edges.keys()) other_edge_set = set(other_edges.keys()) added_edges = other_edge_set - self_edge_set removed_edges = self_edge_set - other_edge_set modified_edges = set() for edge in self_edge_set & other_edge_set: if self_edges[edge] != other_edges[edge]: modified_edges.add(edge) return { "added_nodes": added_nodes, "removed_nodes": removed_nodes, "modified_nodes": modified_nodes, "added_edges": added_edges, "removed_edges": removed_edges, "modified_edges": modified_edges, }
[docs] def topological_order(self) -> list[T]: """ Get the nodes in topological order. Only nodes that are members of this graph are included. Returns: list[T]: A list of member nodes sorted topologically. """ if self._topological_order is None: logger.debug("Calculating topological order.") sorter = TopologicalSorter({node: node.depends_on for node in self._nodes}) # Filter the static order to only include nodes that are in this graph self._topological_order = [ node for node in sorter.static_order() if node in self._nodes ] return self._topological_order
[docs] def topological_order_filtered(self, fn: Callable[[T], bool]) -> list[T]: """ Get a filtered list of nodes in topological order. Args: fn (Callable[[T], bool]): The predicate function. Returns: list[T]: Filtered topologically sorted nodes. """ return [node for node in self.topological_order() if fn(node)]
[docs] def topological_order_tagged(self, tag: str) -> list[T]: """ Get a list of nodes with a specific tag in topological order. Args: tag (str): The tag to filter by. Returns: list[T]: Tagged topologically sorted nodes. """ return [node for node in self.topological_order() if node.is_tagged(tag)]
[docs] def to_networkx(self): """ Convert this graph to a networkx.DiGraph. Requires 'networkx' to be installed. Returns: networkx.DiGraph: The converted directed graph. """ from .views.networkx import to_networkx return to_networkx(self)
[docs] def transitive_reduction(self) -> Graph[T]: """ Compute the transitive reduction of this DAG. A transitive reduction of a directed acyclic graph G is a graph G' with the same nodes and the same reachability as G, but with as few edges as possible. Returns: Graph[T]: A new Graph instance containing the same nodes (cloned) but with redundant edges removed. """ import copy logger.debug("Calculating transitive reduction.") # 1. Clone nodes without edges to avoid modifying the original graph. node_map: dict[T, T] = {} for node in self._nodes: new_node = copy.copy(node) # Reset internal edge tracking new_node._dependents = {} new_node._depends_on = {} # Manually clone tags to avoid shared state new_node._tags = set(node.tags) node_map[node] = new_node # 2. Identify redundant edges. # An edge (u, v) is redundant if there exists a path from u to v of length > 1. redundant_edges: set[tuple[T, T]] = set() for u in self._nodes: for v in u.dependents: # Check if v is reachable from u through any other neighbor w. if any(w.find_path(v) for w in u.dependents if w != v): redundant_edges.add((u, v)) # 3. Construct the new graph with non-redundant edges. new_graph = Graph(set(node_map.values())) for u in self._nodes: for v in u.dependents: if (u, v) not in redundant_edges: # Preserve edge attributes attrs = u.edge_attributes(v) new_graph.add_edge(node_map[u], node_map[v], **attrs) logger.info( f"Transitive reduction complete. Removed {len(redundant_edges)} redundant edges." ) return new_graph
[docs] def render( self, view_fnc: Callable[..., str], transitive_reduction: bool = False, **kwargs: Any, ) -> str: """ Render the graph using a view function. Args: view_fnc: The view function to use (e.g., create_topology_mermaid_mmd). transitive_reduction: If True, render the transitive reduction of the graph. **kwargs: Additional arguments passed to the view function. Returns: str: The rendered representation. """ target = self.transitive_reduction() if transitive_reduction else self return view_fnc(target, **kwargs)
[docs] def export( self, export_fnc: Callable[..., None], output: Path | str, transitive_reduction: bool = False, embed_checksum: bool = False, **kwargs: Any, ) -> None: """ Export the graph using an export function. Args: export_fnc: The export function to use (e.g., export_topology_graphviz_svg). output: The output file path. transitive_reduction: If True, export the transitive reduction of the graph. embed_checksum: If True, embed the graph's checksum as a comment at the top. **kwargs: Additional arguments passed to the export function. """ from pathlib import Path from .registry import CREATOR_MAP from .views.utils import wrap_with_checksum p = Path(output) target = self.transitive_reduction() if transitive_reduction else self if not embed_checksum: return export_fnc(target, p, **kwargs) # To embed checksum, we need to capture the output string first. create_fnc = CREATOR_MAP.get(export_fnc) if not create_fnc: # Fallback: export normally if we can't find a string-generating version export_name = getattr(export_fnc, "__name__", str(export_fnc)) logger.warning( f"Could not find string-generating version of {export_name}. Exporting normally without checksum embedding." ) return export_fnc(target, p, **kwargs) content = create_fnc(target, **kwargs) checksum = target.checksum() wrapped = wrap_with_checksum(content, checksum, p.suffix) with open(p, "w+") as f: f.write(wrapped)