Coverage for src / graphable / graph.py: 96%

536 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-16 21:32 +0000

1from __future__ import annotations 

2 

3from graphlib import CycleError, TopologicalSorter 

4from hashlib import blake2b 

5from logging import getLogger 

6from pathlib import Path 

7from typing import Any, Callable, Iterator 

8 

9from .enums import Direction, Engine 

10from .errors import GraphConsistencyError, GraphCycleError 

11from .graphable import Graphable 

12 

13logger = getLogger(__name__) 

14 

15 

16class Graph[T: Graphable[Any]]: 

17 """ 

18 Represents a graph of Graphable nodes. 

19 """ 

20 

21 def __init__(self, initial: set[T] | list[T] | None = None, discover: bool = False): 

22 """ 

23 Initialize a Graph. 

24 

25 Args: 

26 initial (set[T] | list[T] | None): An optional set of initial nodes. 

27 discover (bool): If True, automatically expand the graph to include all 

28 reachable ancestors and descendants of the initial nodes. 

29 

30 Raises: 

31 GraphCycleError: If the initial set of nodes contains a cycle. 

32 """ 

33 self._nodes: set[T] = set() 

34 self._topological_order: list[T] | None = None 

35 self._parallel_topological_order: list[set[T]] | None = None 

36 self._checksum: str | None = None 

37 

38 if initial: 

39 for node in initial: 

40 self.add_node(node) 

41 

42 if discover: 

43 self.discover() 

44 

45 self.check_consistency() 

46 self.check_cycles() 

47 

48 def clone(self, include_edges: bool = False) -> Graph[T]: 

49 """ 

50 Create a copy of this graph. 

51 

52 Args: 

53 include_edges: If True, the new graph will have the same edges as this one. 

54 If False, the new graph will contain copies of all nodes but with 

55 no edges between them. 

56 

57 Returns: 

58 Graph[T]: A new Graph instance. 

59 """ 

60 import copy 

61 

62 logger.debug(f"Cloning graph (include_edges={include_edges}).") 

63 node_map: dict[T, T] = {} 

64 for node in self._nodes: 

65 new_node = copy.copy(node) 

66 # Reset internal edge tracking 

67 new_node._dependents = {} 

68 new_node._depends_on = {} 

69 # Manually clone tags to avoid shared state 

70 new_node._tags = set(node.tags) 

71 node_map[node] = new_node 

72 

73 new_graph = Graph(set(node_map.values())) 

74 

75 if include_edges: 

76 for u in self._nodes: 

77 for v, attrs in self.neighbors(u, Direction.DOWN): 

78 new_graph.add_edge(node_map[u], node_map[v], **attrs) 

79 

80 return new_graph 

81 

82 def neighbors( 

83 self, node: T, direction: Direction = Direction.DOWN 

84 ) -> Iterator[tuple[T, dict[str, Any]]]: 

85 """ 

86 Iterate over neighbors within this graph. 

87 

88 Args: 

89 node (T): The source node. 

90 direction: Direction.DOWN for dependents, Direction.UP for dependencies. 

91 

92 Yields: 

93 tuple[T, dict[str, Any]]: A (neighbor_node, edge_attributes) tuple. 

94 """ 

95 neighbors = node.dependents if direction == Direction.DOWN else node.depends_on 

96 for neighbor in neighbors: 

97 if neighbor in self._nodes: 

98 attrs = ( 

99 node.edge_attributes(neighbor) 

100 if direction == Direction.DOWN 

101 else neighbor.edge_attributes(node) 

102 ) 

103 yield neighbor, attrs 

104 

105 def internal_dependents(self, node: T) -> Iterator[tuple[T, dict[str, Any]]]: 

106 """Alias for neighbors(node, Direction.DOWN).""" 

107 return self.neighbors(node, Direction.DOWN) 

108 

109 def internal_depends_on(self, node: T) -> Iterator[tuple[T, dict[str, Any]]]: 

110 """Alias for neighbors(node, Direction.UP).""" 

111 return self.neighbors(node, Direction.UP) 

112 

113 def discover(self) -> None: 

114 """ 

115 Traverse the entire connectivity of the current nodes and add any 

116 reachable ancestors or descendants that are not yet members. 

117 """ 

118 logger.debug(f"Discovering reachable nodes from {len(self._nodes)} base nodes.") 

119 

120 new_nodes: set[T] = set() 

121 for node in list(self._nodes): 

122 new_nodes.update( 

123 self._traverse(node, direction=Direction.UP, limit_to_graph=False) 

124 ) 

125 new_nodes.update( 

126 self._traverse(node, direction=Direction.DOWN, limit_to_graph=False) 

127 ) 

128 

129 for node in new_nodes: 

130 self.add_node(node) 

131 

132 def _invalidate_cache(self) -> None: 

133 """Clear all cached calculations for this graph.""" 

134 logger.debug("Invalidating graph cache.") 

135 self._topological_order = None 

136 self._parallel_topological_order = None 

137 self._checksum = None 

138 

139 def __contains__(self, item: object) -> bool: 

140 """ 

141 Check if a node or its reference is in the graph. 

142 

143 Args: 

144 item (object): Either a Graphable node or a reference object. 

145 

146 Returns: 

147 bool: True if present, False otherwise. 

148 """ 

149 if isinstance(item, Graphable): 

150 return item in self._nodes 

151 

152 return any(node.reference == item for node in self._nodes) 

153 

154 def __getitem__(self, reference: Any) -> T: 

155 """ 

156 Get a node by its reference. 

157 

158 Args: 

159 reference (Any): The reference object to search for. 

160 

161 Returns: 

162 T: The Graphable node. 

163 

164 Raises: 

165 KeyError: If no node with the given reference exists. 

166 """ 

167 for node in self._nodes: 

168 if node.reference == reference: 

169 return node 

170 raise KeyError(f"No node found with reference: {reference}") 

171 

172 def __iter__(self): 

173 """ 

174 Iterate over nodes in topological order. 

175 """ 

176 return iter(self.topological_order()) 

177 

178 def __len__(self) -> int: 

179 """ 

180 Get the number of nodes in the graph. 

181 """ 

182 return len(self._nodes) 

183 

184 def is_equal_to(self, other: object) -> bool: 

185 """ 

186 Check if this graph is equal to another graph. 

187 Equality is defined as having the same checksum (structural and metadata-wise). 

188 

189 Args: 

190 other: The other object to compare with. 

191 

192 Returns: 

193 bool: True if equal, False otherwise. 

194 """ 

195 if not isinstance(other, Graph): 

196 return False 

197 

198 return self.checksum() == other.checksum() 

199 

    def checksum(self) -> str:
        """
        Calculate a deterministic BLAKE2b checksum of the graph.

        The checksum accounts for all member nodes (references, tags, duration,
        status) and edges (including attributes) between them. External nodes
        are excluded. The result is cached until the graph is mutated
        (see _invalidate_cache).

        Returns:
            str: The hexadecimal digest of the graph.
        """
        if self._checksum is not None:
            return self._checksum

        # 1. Sort nodes by reference to ensure deterministic iteration
        sorted_nodes = sorted(self._nodes, key=lambda n: str(n.reference))

        hasher = blake2b()

        for node in sorted_nodes:
            # 2. Add node reference, duration, and status
            hasher.update(str(node.reference).encode())
            hasher.update(f":duration:{node.duration}".encode())
            hasher.update(f":status:{node.status}".encode())

            # 3. Add sorted tags
            for tag in sorted(node.tags):
                hasher.update(f":tag:{tag}".encode())

            # 4. Add sorted dependents (edges) with attributes - Only those in the graph
            internal_dependents = sorted(
                [d for d in node.dependents if d in self._nodes],
                key=lambda n: str(n.reference),
            )
            for dep in internal_dependents:
                hasher.update(f":edge:{dep.reference}".encode())
                # Add edge attributes deterministically (sorted by key)
                attrs = node.edge_attributes(dep)
                for key in sorted(attrs.keys()):
                    hasher.update(f":attr:{key}:{attrs[key]}".encode())

        self._checksum = hasher.hexdigest()
        return self._checksum

241 

242 def validate_checksum(self, expected: str) -> bool: 

243 """ 

244 Validate the graph against an expected checksum. 

245 

246 Args: 

247 expected (str): The expected BLAKE2b hexadecimal digest. 

248 

249 Returns: 

250 bool: True if the checksums match, False otherwise. 

251 """ 

252 return self.checksum() == expected 

253 

254 def write_checksum(self, path: Path | str) -> None: 

255 """ 

256 Write the graph's current checksum to a file. 

257 

258 Args: 

259 path: Path to the output checksum file. 

260 """ 

261 p = Path(path) 

262 digest = self.checksum() 

263 logger.info(f"Writing checksum to: {p}") 

264 with open(p, "w+") as f: 

265 f.write(digest) 

266 

267 @staticmethod 

268 def read_checksum(path: Path | str) -> str: 

269 """ 

270 Read a checksum from a file. 

271 

272 Args: 

273 path: Path to the checksum file. 

274 

275 Returns: 

276 str: The checksum string. 

277 """ 

278 p = Path(path) 

279 logger.debug(f"Reading checksum from: {p}") 

280 with open(p, "r") as f: 

281 return f.read().strip() 

282 

    @classmethod
    def read(cls, path: Path | str, **kwargs: Any) -> Graph[Any]:
        """
        Read a graph from a file, automatically detecting the format.

        The parser is selected by file extension via the PARSERS registry. If
        the file embeds a checksum, the parsed graph is validated against it.

        Args:
            path: Path to the input file.
            **kwargs: Extra arguments forwarded to the selected parser.

        Returns:
            Graph[Any]: The parsed graph.

        Raises:
            ValueError: If the extension has no registered parser, or the
                embedded checksum does not match the parsed graph.
        """
        from .parsers.utils import extract_checksum
        from .registry import PARSERS

        p = Path(path)
        ext = p.suffix.lower()
        parser = PARSERS.get(ext)
        if not parser:
            raise ValueError(f"Unsupported extension for reading: {ext}")

        g = parser(p, **kwargs)

        # Validation only happens when the file actually embeds a checksum.
        if embedded := extract_checksum(p):
            if not g.validate_checksum(embedded):
                raise ValueError(f"Checksum validation failed for {p}")
        return g

301 

302 def write( 

303 self, 

304 path: Path | str, 

305 transitive_reduction: bool = False, 

306 embed_checksum: bool = False, 

307 engine: Engine | str | None = None, 

308 **kwargs: Any, 

309 ) -> None: 

310 """ 

311 Write the graph to a file, automatically detecting the format. 

312 

313 Args: 

314 path: Path to the output file. 

315 transitive_reduction: If True, perform transitive reduction before writing. 

316 embed_checksum: If True, embed a BLAKE2b checksum in the output. 

317 engine: The rendering engine to use for images (.svg, .png). 

318 If None, it will be auto-detected. 

319 **kwargs: Additional arguments passed to the specific exporter. 

320 """ 

321 from .registry import EXPORTERS 

322 

323 p = Path(path) 

324 ext = p.suffix.lower() 

325 

326 # Handle images specifically to allow engine selection/auto-detection 

327 if ext in (".svg", ".png"): 

328 from .views.utils import get_image_exporter 

329 

330 exporter = get_image_exporter(engine) 

331 else: 

332 exporter = EXPORTERS.get(ext) 

333 

334 if not exporter: 

335 raise ValueError(f"Unsupported extension: {ext}") 

336 

337 return self.export(exporter, p, transitive_reduction, embed_checksum, **kwargs) 

338 

339 def parallelized_topological_order(self) -> list[set[T]]: 

340 """ 

341 Get the nodes in topological order, grouped into sets that can be processed in parallel. 

342 Only nodes that are members of this graph are included. 

343 

344 Returns: 

345 list[set[T]]: A list of sets of member nodes that have no unmet dependencies. 

346 """ 

347 if self._parallel_topological_order is None: 

348 logger.debug("Calculating parallel topological order.") 

349 self._parallel_topological_order = [] 

350 sorter = TopologicalSorter({node: node.depends_on for node in self._nodes}) 

351 sorter.prepare() 

352 while sorter.is_active(): 

353 ready = sorter.get_ready() 

354 if not ready: 

355 break 

356 # Filter to only include nodes that are actually in this graph 

357 filtered_ready = {node for node in ready if node in self._nodes} 

358 if filtered_ready: 

359 self._parallel_topological_order.append(filtered_ready) 

360 sorter.done(*ready) 

361 

362 return self._parallel_topological_order 

363 

364 def subgraph_between(self, source: T, target: T) -> Graph[T]: 

365 """ 

366 Create a new graph containing all nodes and edges on all paths between source and target. 

367 

368 Args: 

369 source (T): The starting node. 

370 target (T): The ending node. 

371 

372 Returns: 

373 Graph[T]: A new Graph instance. 

374 """ 

375 if source not in self._nodes or target not in self._nodes: 

376 raise KeyError("Both source and target must be in the graph.") 

377 

378 # Nodes between U and V are nodes that are descendants of U AND ancestors of V 

379 descendants = {source} | set(self.descendants(source)) 

380 ancestors = {target} | set(self.ancestors(target)) 

381 between = descendants & ancestors 

382 

383 return Graph(between) 

384 

    def diff_graph(self, other: Graph[T]) -> Graph[T]:
        """
        Create a visualization-friendly diff graph.

        Merges nodes and edges from both graphs and tags/colors them by diff
        status (computed by self.diff(other)):
        - Nodes in both: grey/default
        - Added nodes: green
        - Removed nodes: red
        - Modified nodes/edges: yellow/orange

        Args:
            other (Graph[T]): The graph representing the "new" state.

        Returns:
            Graph[T]: A merged graph with diff metadata.
        """
        import copy

        merged_nodes_map: dict[Any, T] = {}
        diff_info = self.diff(other)

        def get_or_create(node: T, status: str) -> T:
            # Return the merged copy for this reference, creating an edge-free,
            # color-tagged copy on first sight of the reference.
            ref = node.reference
            if ref not in merged_nodes_map:
                new_node = copy.copy(node)
                new_node._dependents = {}
                new_node._depends_on = {}
                new_node.add_tag(f"diff:{status}")
                # Add visual hints
                color = {"added": "green", "removed": "red", "modified": "orange"}.get(
                    status, "grey"
                )
                new_node.add_tag(f"color:{color}")
                merged_nodes_map[ref] = new_node
            return merged_nodes_map[ref]

        # Add all nodes from both
        for node in self._nodes:
            status = (
                "removed"
                if node.reference in diff_info["removed_nodes"]
                else "unchanged"
            )
            if node.reference in diff_info["modified_nodes"]:
                status = "modified"
            get_or_create(node, status)

        for node in other._nodes:
            status = (
                "added" if node.reference in diff_info["added_nodes"] else "unchanged"
            )
            if node.reference in diff_info["modified_nodes"]:
                status = "modified"
            get_or_create(node, status)

        new_graph = Graph(set(merged_nodes_map.values()))

        # Add edges from self (original)
        for u in self._nodes:
            for v in u.dependents:
                if v not in self._nodes:
                    continue
                edge = (u.reference, v.reference)
                if edge in diff_info["removed_edges"]:
                    new_graph.add_edge(
                        merged_nodes_map[u.reference],
                        merged_nodes_map[v.reference],
                        diff_status="removed",
                        color="red",
                    )
                elif edge in diff_info["modified_edges"]:
                    # NOTE(review): modified edges carry the OLD attributes
                    # from self — presumably intentional for visualization;
                    # confirm whether the new attributes should win instead.
                    new_graph.add_edge(
                        merged_nodes_map[u.reference],
                        merged_nodes_map[v.reference],
                        **u.edge_attributes(v),
                        diff_status="modified",
                        color="orange",
                    )
                else:
                    new_graph.add_edge(
                        merged_nodes_map[u.reference],
                        merged_nodes_map[v.reference],
                        **u.edge_attributes(v),
                    )

        # Add edges from other (new) — only the edges reported as added.
        for u in other._nodes:
            for v in u.dependents:
                if v not in other._nodes:
                    continue
                edge = (u.reference, v.reference)
                if edge in diff_info["added_edges"]:
                    new_graph.add_edge(
                        merged_nodes_map[u.reference],
                        merged_nodes_map[v.reference],
                        **u.edge_attributes(v),
                        diff_status="added",
                        color="green",
                    )

        return new_graph

481 

482 def transitive_closure(self) -> Graph[T]: 

483 """ 

484 Compute the transitive closure of this graph. 

485 An edge (u, v) exists in the transitive closure if there is a path from u to v. 

486 

487 Returns: 

488 Graph[T]: A new Graph instance representing the transitive closure. 

489 """ 

490 import copy 

491 

492 logger.debug("Calculating transitive closure.") 

493 node_map = {node: copy.copy(node) for node in self._nodes} 

494 for n in node_map.values(): 

495 n._dependents = {} 

496 n._depends_on = {} 

497 

498 new_graph = Graph(set(node_map.values())) 

499 for u in self._nodes: 

500 for v in self.descendants(u): 

501 new_graph.add_edge(node_map[u], node_map[v]) 

502 

503 return new_graph 

504 

505 def suggest_cycle_breaks(self) -> list[tuple[T, T]]: 

506 """ 

507 Identify a minimal set of edges to remove to make the graph a Directed Acyclic Graph (DAG). 

508 Uses a greedy heuristic. 

509 

510 Returns: 

511 list[tuple[T, T]]: A list of (source, target) tuples representing suggested edges to remove. 

512 """ 

513 logger.debug("Suggesting cycle breaks.") 

514 # Simple heuristic: 

515 # 1. Take all nodes. 

516 # 2. Try to order them such that we maximize forward edges. 

517 # A simple way is to use the order they were added or any arbitrary order 

518 # and see which edges go 'backwards'. 

519 

520 nodes = list(self._nodes) 

521 # We can try to be slightly smarter by using a DFS and finding back-edges 

522 back_edges = [] 

523 visited = set() 

524 stack = set() 

525 

526 def dfs(u): 

527 visited.add(u) 

528 stack.add(u) 

529 for v in u.dependents: 

530 if v not in self._nodes: 

531 continue 

532 if v in stack: 

533 back_edges.append((u, v)) 

534 elif v not in visited: 

535 dfs(v) 

536 stack.remove(u) 

537 

538 for node in nodes: 

539 if node not in visited: 

540 dfs(node) 

541 

542 return back_edges 

543 

544 def parallelized_topological_order_filtered( 

545 self, fn: Callable[[T], bool] 

546 ) -> list[set[T]]: 

547 """ 

548 Get a filtered list of nodes in parallelized topological order. 

549 

550 Args: 

551 fn (Callable[[T], bool]): The predicate function. 

552 

553 Returns: 

554 list[set[T]]: Filtered sets of nodes for parallel processing. 

555 """ 

556 result = [] 

557 for group in self.parallelized_topological_order(): 

558 filtered_group = {node for node in group if fn(node)} 

559 if filtered_group: 

560 result.append(filtered_group) 

561 return result 

562 

563 def parallelized_topological_order_tagged(self, tag: str) -> list[set[T]]: 

564 """ 

565 Get a list of nodes with a specific tag in parallelized topological order. 

566 

567 Args: 

568 tag (str): The tag to filter by. 

569 

570 Returns: 

571 list[set[T]]: Tagged sets of nodes for parallel processing. 

572 """ 

573 return self.parallelized_topological_order_filtered(lambda n: n.is_tagged(tag)) 

574 

575 def __eq__(self, other: object) -> bool: 

576 """ 

577 Compare two graphs for equality. 

578 """ 

579 return self.is_equal_to(other) 

580 

581 def __hash__(self) -> int: 

582 """ 

583 Graphs are hashable by identity to allow them to be used in WeakSets 

584 (e.g., as observers of Graphable nodes). 

585 """ 

586 return id(self) 

587 

588 def check_cycles(self) -> None: 

589 """ 

590 Check for cycles in the graph. 

591 

592 Raises: 

593 GraphCycleError: If a cycle is detected. 

594 """ 

595 try: 

596 sorter = TopologicalSorter({node: node.depends_on for node in self._nodes}) 

597 sorter.prepare() 

598 except CycleError as e: 

599 # graphlib.CycleError args: (message, cycle_tuple) 

600 cycle = list(e.args[1]) if len(e.args) > 1 else None 

601 raise GraphCycleError(f"Cycle detected in graph: {e}", cycle=cycle) from e 

602 

603 def check_consistency(self) -> None: 

604 """ 

605 Check for consistency between depends_on and dependents for all nodes in the graph. 

606 

607 Raises: 

608 GraphConsistencyError: If an inconsistency is detected. 

609 """ 

610 for node in self._nodes: 

611 self._check_node_consistency(node) 

612 

613 def _check_node_consistency(self, node: T) -> None: 

614 """ 

615 Check for consistency between depends_on and dependents for a single node. 

616 

617 Args: 

618 node (T): The node to check. 

619 

620 Raises: 

621 GraphConsistencyError: If an inconsistency is detected. 

622 """ 

623 # Check dependencies: if node depends on X, X must have node as dependent 

624 for dep in node.depends_on: 

625 if node not in dep.dependents: 

626 raise GraphConsistencyError( 

627 f"Inconsistency: Node '{node.reference}' depends on '{dep.reference}', " 

628 f"but '{dep.reference}' does not list '{node.reference}' as a dependent." 

629 ) 

630 # Check dependents: if node has dependent Y, Y must depend on node 

631 for sub in node.dependents: 

632 if node not in sub.depends_on: 

633 raise GraphConsistencyError( 

634 f"Inconsistency: Node '{node.reference}' has dependent '{sub.reference}', " 

635 f"but '{sub.reference}' does not depend on '{node.reference}'." 

636 ) 

637 

638 def add_edge(self, node: T, dependent: T, **attributes: Any) -> None: 

639 """ 

640 Add a directed edge from node to dependent. 

641 Also adds the nodes to the graph if they are not already present. 

642 

643 Args: 

644 node (T): The source node (dependency). 

645 dependent (T): The target node (dependent). 

646 **attributes: Edge attributes (e.g., weight, label). 

647 

648 Raises: 

649 GraphCycleError: If adding the edge would create a cycle. 

650 """ 

651 if node == dependent: 

652 raise GraphCycleError( 

653 f"Self-loop detected: node '{node.reference}' cannot depend on itself.", 

654 cycle=[node, node], 

655 ) 

656 

657 # Check if adding this edge creates a cycle. 

658 # A cycle is created if there is already a path from 'dependent' to 'node'. 

659 if path := dependent.find_path(node): 

660 cycle = path + [dependent] 

661 raise GraphCycleError( 

662 f"Adding edge '{node.reference}' -> '{dependent.reference}' would create a cycle.", 

663 cycle=cycle, 

664 ) 

665 

666 self.add_node(node) 

667 self.add_node(dependent) 

668 

669 node._add_dependent(dependent, **attributes) 

670 dependent._add_depends_on(node, **attributes) 

671 logger.debug( 

672 f"Added edge: {node.reference} -> {dependent.reference} with attributes {attributes}" 

673 ) 

674 

675 # Invalidate cache 

676 self._invalidate_cache() 

677 

    def add_node(self, node: T) -> bool:
        """
        Add a node to the graph.

        Args:
            node (T): The node to add.

        Returns:
            bool: True if the node was added (was not already present), False otherwise.

        Raises:
            GraphCycleError: If the node is part of an existing cycle.
            GraphConsistencyError: If the node's edges are inconsistent.
        """
        if node in self._nodes:
            return False

        # If the node is already part of a cycle (linked externally), adding it
        # would violate the DAG invariant this graph enforces.
        if cycle := node.find_path(node):
            raise GraphCycleError(
                f"Node '{node.reference}' is part of an existing cycle.", cycle=cycle
            )

        self._check_node_consistency(node)
        self._nodes.add(node)
        # Register as an observer — presumably so node mutations can reach
        # this graph (cf. the WeakSet note on __hash__); confirm in Graphable.
        node._register_observer(self)
        logger.debug(f"Added node: {node.reference}")

        self._invalidate_cache()

        return True

709 

710 def remove_edge(self, node: T, dependent: T) -> None: 

711 """ 

712 Remove a directed edge from node to dependent. 

713 

714 Args: 

715 node (T): The source node. 

716 dependent (T): The target node. 

717 """ 

718 if node in self._nodes and dependent in self._nodes: 

719 node._remove_dependent(dependent) 

720 dependent._remove_depends_on(node) 

721 logger.debug(f"Removed edge: {node.reference} -> {dependent.reference}") 

722 

723 self._invalidate_cache() 

724 

725 def remove_node(self, node: T) -> None: 

726 """ 

727 Remove a node and all its connected edges from the graph. 

728 

729 Args: 

730 node (T): The node to remove. 

731 """ 

732 if node in self._nodes: 

733 # Remove from all nodes it depends on 

734 for dep in list(node.depends_on): 

735 dep._remove_dependent(node) 

736 

737 # Remove from all nodes that depend on it 

738 for sub in list(node.dependents): 

739 sub._remove_depends_on(node) 

740 

741 self._nodes.remove(node) 

742 node._unregister_observer(self) 

743 logger.debug(f"Removed node: {node.reference}") 

744 

745 self._invalidate_cache() 

746 

747 def ancestors(self, node: T) -> Iterator[T]: 

748 """ 

749 Get an iterator for all nodes that the given node depends on, recursively. 

750 

751 Args: 

752 node (T): The starting node. 

753 

754 Yields: 

755 T: The next ancestor node. 

756 """ 

757 return self._traverse(node, direction=Direction.UP, include_start=False) 

758 

759 def descendants(self, node: T) -> Iterator[T]: 

760 """ 

761 Get an iterator for all nodes that depend on the given node, recursively. 

762 

763 Args: 

764 node (T): The starting node. 

765 

766 Yields: 

767 T: The next descendant node. 

768 """ 

769 return self._traverse(node, direction=Direction.DOWN, include_start=False) 

770 

771 def bfs( 

772 self, 

773 start_node: T, 

774 direction: Direction = Direction.DOWN, 

775 limit_to_graph: bool = True, 

776 ) -> Iterator[T]: 

777 """ 

778 Perform a breadth-first search (BFS) starting from the given node. 

779 

780 Args: 

781 start_node (T): The node to start from. 

782 direction: Direction.UP for dependencies, Direction.DOWN for dependents. 

783 limit_to_graph: If True, only return nodes that are members of this graph. 

784 

785 Yields: 

786 T: Each reached node in breadth-first order. 

787 """ 

788 from collections import deque 

789 

790 if limit_to_graph and start_node not in self._nodes: 

791 return 

792 

793 visited: set[T] = {start_node} 

794 queue: deque[T] = deque([start_node]) 

795 

796 yield start_node 

797 

798 while queue: 

799 current = queue.popleft() 

800 neighbors = ( 

801 current.dependents 

802 if direction == Direction.DOWN 

803 else current.depends_on 

804 ) 

805 for neighbor in neighbors: 

806 if neighbor not in visited: 

807 if limit_to_graph and neighbor not in self._nodes: 

808 continue 

809 visited.add(neighbor) 

810 yield neighbor 

811 queue.append(neighbor) 

812 

813 def dfs( 

814 self, 

815 start_node: T, 

816 direction: Direction = Direction.DOWN, 

817 limit_to_graph: bool = True, 

818 ) -> Iterator[T]: 

819 """ 

820 Perform a depth-first search (DFS) starting from the given node. 

821 

822 Args: 

823 start_node (T): The node to start from. 

824 direction: Direction.UP for dependencies, Direction.DOWN for dependents. 

825 limit_to_graph: If True, only return nodes that are members of this graph. 

826 

827 Yields: 

828 T: Each reached node in depth-first order. 

829 """ 

830 return self._traverse( 

831 start_node, 

832 direction=direction, 

833 limit_to_graph=limit_to_graph, 

834 include_start=True, 

835 ) 

836 

837 def _traverse( 

838 self, 

839 start_node: T, 

840 direction: Direction = Direction.DOWN, 

841 limit_to_graph: bool = True, 

842 include_start: bool = False, 

843 ) -> Iterator[T]: 

844 """ 

845 Generic depth-first traversal utility. 

846 

847 Args: 

848 start_node (T): Node to start from. 

849 direction: Direction.UP (depends_on) or Direction.DOWN (dependents). 

850 limit_to_graph: If True, only return nodes that are members of this graph. 

851 include_start: If True, yield the start_node first. 

852 

853 Yields: 

854 T: Each reached node. 

855 """ 

856 visited: set[T] = {start_node} 

857 

858 if include_start: 

859 if not limit_to_graph or start_node in self._nodes: 

860 yield start_node 

861 

862 def discover(current: T) -> Iterator[T]: 

863 neighbors = ( 

864 current.dependents 

865 if direction == Direction.DOWN 

866 else current.depends_on 

867 ) 

868 for neighbor in neighbors: 

869 if neighbor not in visited: 

870 if limit_to_graph and neighbor not in self._nodes: 

871 continue 

872 visited.add(neighbor) 

873 yield neighbor 

874 yield from discover(neighbor) 

875 

876 yield from discover(start_node) 

877 

878 @property 

879 def sinks(self) -> list[T]: 

880 """ 

881 Get all sink nodes (nodes with no dependents). 

882 

883 Returns: 

884 list[T]: A list of sink nodes. 

885 """ 

886 return [node for node in self._nodes if 0 == len(node.dependents)] 

887 

888 @property 

889 def sources(self) -> list[T]: 

890 """ 

891 Get all source nodes (nodes with no dependencies). 

892 

893 Returns: 

894 list[T]: A list of source nodes. 

895 """ 

896 return [node for node in self._nodes if 0 == len(node.depends_on)] 

897 

898 @staticmethod 

899 def parse( 

900 parser_fnc: Callable[..., Graph[Any]], source: str | Path, **kwargs: Any 

901 ) -> Graph[Any]: 

902 """ 

903 Parse a graph from a source using a parser function. 

904 

905 Args: 

906 parser_fnc: The parser function to use (e.g., load_graph_json). 

907 source: The source to parse (string or path). 

908 **kwargs: Additional arguments passed to the parser function. 

909 

910 Returns: 

911 Graph: A new Graph instance. 

912 """ 

913 return parser_fnc(source, **kwargs) 

914 

    @classmethod
    def from_csv(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
        """
        Create a Graph from a CSV edge list.

        Args:
            source: The source to parse (string or path).
            **kwargs: Additional arguments passed to the CSV parser.

        Returns:
            Graph[Any]: The parsed graph.
        """
        from .parsers.csv import load_graph_csv

        return cls.parse(load_graph_csv, source, **kwargs)

921 

    @classmethod
    def from_graphml(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
        """
        Create a Graph from a GraphML file or string.

        Args:
            source: The source to parse (string or path).
            **kwargs: Additional arguments passed to the GraphML parser.

        Returns:
            Graph[Any]: The parsed graph.
        """
        from .parsers.graphml import load_graph_graphml

        return cls.parse(load_graph_graphml, source, **kwargs)

928 

    @classmethod
    def from_json(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
        """
        Create a Graph from a JSON file or string.

        Args:
            source: The source to parse (string or path).
            **kwargs: Additional arguments passed to the JSON parser.

        Returns:
            Graph[Any]: The parsed graph.
        """
        from .parsers.json import load_graph_json

        return cls.parse(load_graph_json, source, **kwargs)

935 

    @classmethod
    def from_toml(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
        """
        Create a Graph from a TOML file or string.

        Args:
            source: The source to parse (string or path).
            **kwargs: Additional arguments passed to the TOML parser.

        Returns:
            Graph[Any]: The parsed graph.
        """
        from .parsers.toml import load_graph_toml

        return cls.parse(load_graph_toml, source, **kwargs)

942 

    @classmethod
    def from_yaml(cls, source: str | Path, **kwargs: Any) -> Graph[Any]:
        """
        Create a Graph from a YAML file or string.

        Args:
            source: The source to parse (string or path).
            **kwargs: Additional arguments passed to the YAML parser.

        Returns:
            Graph[Any]: The parsed graph.
        """
        from .parsers.yaml import load_graph_yaml

        return cls.parse(load_graph_yaml, source, **kwargs)

949 

950 def subgraph_filtered(self, fn: Callable[[T], bool]) -> Graph[T]: 

951 """ 

952 Create a new subgraph containing only nodes that satisfy the predicate. 

953 

954 Args: 

955 fn (Callable[[T], bool]): The predicate function. 

956 

957 Returns: 

958 Graph[T]: A new Graph containing the filtered nodes. 

959 """ 

960 logger.debug("Creating filtered subgraph.") 

961 return Graph([node for node in self._nodes if fn(node)], discover=True) 

962 

963 def subgraph_tagged(self, tag: str) -> Graph[T]: 

964 """ 

965 Create a new subgraph containing only nodes with the specified tag. 

966 

967 Args: 

968 tag (str): The tag to filter by. 

969 

970 Returns: 

971 Graph[T]: A new Graph containing the tagged nodes. 

972 """ 

973 logger.debug(f"Creating subgraph for tag: {tag}") 

974 return Graph( 

975 [node for node in self._nodes if node.is_tagged(tag)], discover=True 

976 ) 

977 

978 def upstream_of(self, node: T) -> Graph[T]: 

979 """ 

980 Create a new graph containing the given node and all its ancestors. 

981 

982 Args: 

983 node (T): The node to start from. 

984 

985 Returns: 

986 Graph[T]: A new Graph instance. 

987 """ 

988 if node not in self._nodes: 

989 raise KeyError(f"Node '{node.reference}' not found in graph.") 

990 

991 nodes = {node} | set(self.ancestors(node)) 

992 return Graph(nodes) 

993 

994 def downstream_of(self, node: T) -> Graph[T]: 

995 """ 

996 Create a new graph containing the given node and all its descendants. 

997 

998 Args: 

999 node (T): The node to start from. 

1000 

1001 Returns: 

1002 Graph[T]: A new Graph instance. 

1003 """ 

1004 if node not in self._nodes: 

1005 raise KeyError(f"Node '{node.reference}' not found in graph.") 

1006 

1007 nodes = {node} | set(self.descendants(node)) 

1008 return Graph(nodes) 

1009 

    def cpm_analysis(self) -> dict[T, dict[str, float]]:
        """
        Perform Critical Path Method (CPM) analysis on the graph.

        Assumes all nodes have a 'duration' attribute.

        Returns:
            dict[T, dict[str, float]]: A dictionary mapping each node to its CPM values:
                - 'ES': Earliest Start
                - 'EF': Earliest Finish
                - 'LS': Latest Start
                - 'LF': Latest Finish
                - 'slack': Total Slack (LF - EF)
        """
        logger.debug("Starting CPM analysis.")
        topo_order = self.topological_order()
        if not topo_order:
            return {}

        analysis: dict[T, dict[str, float]] = {
            node: {"ES": 0.0, "EF": 0.0, "LS": 0.0, "LF": 0.0, "slack": 0.0}
            for node in topo_order
        }

        # 1. Forward Pass (ES, EF): a node can start only once its
        # latest-finishing member dependency completes.
        for node in topo_order:
            max_ef = 0.0
            for dep in node.depends_on:
                if dep in analysis:
                    max_ef = max(max_ef, analysis[dep]["EF"])
            analysis[node]["ES"] = max_ef
            analysis[node]["EF"] = max_ef + node.duration

        # 2. Backward Pass (LF, LS): the project end is the maximum EF over
        # all nodes; terminal nodes may finish there, all others must finish
        # before the earliest LS among their member dependents.
        max_total_ef = max(analysis[node]["EF"] for node in topo_order)

        for node in reversed(topo_order):
            if not node.dependents or all(d not in analysis for d in node.dependents):
                min_ls = max_total_ef
            else:
                min_ls = min(
                    analysis[dep]["LS"] for dep in node.dependents if dep in analysis
                )

            analysis[node]["LF"] = min_ls
            analysis[node]["LS"] = min_ls - node.duration
            analysis[node]["slack"] = analysis[node]["LF"] - analysis[node]["EF"]

        return analysis

1058 

1059 def critical_path(self) -> list[T]: 

1060 """ 

1061 Identify the nodes on the critical path (slack == 0). 

1062 

1063 Returns: 

1064 list[T]: A list of nodes on the critical path, in topological order. 

1065 """ 

1066 analysis = self.cpm_analysis() 

1067 return [ 

1068 node 

1069 for node in self.topological_order() 

1070 if abs(analysis[node]["slack"]) < 1e-9 

1071 ] 

1072 

1073 def longest_path(self) -> list[T]: 

1074 """ 

1075 Find the longest path in the graph based on node durations. 

1076 In a DAG, this is equivalent to the critical path chain. 

1077 

1078 Returns: 

1079 list[T]: The nodes forming the longest path. 

1080 """ 

1081 # This is a bit more complex than just critical_path() if there are multiple critical paths. 

1082 # But for dependency graphs, any path where slack == 0 is "a" longest path. 

1083 # To get a specific chain: 

1084 analysis = self.cpm_analysis() 

1085 cp_nodes = { 

1086 node for node, vals in analysis.items() if abs(vals["slack"]) < 1e-9 

1087 } 

1088 

1089 if not cp_nodes: 

1090 return [] 

1091 

1092 # Find a source on critical path 

1093 current = None 

1094 for node in self.sources: 

1095 if node in cp_nodes: 

1096 current = node 

1097 break 

1098 

1099 if current is None: 

1100 # Fallback: just take the first CP node in topo order 

1101 current = sorted( 

1102 list(cp_nodes), key=lambda n: self.topological_order().index(n) 

1103 )[0] 

1104 

1105 path = [current] 

1106 while True: 

1107 next_node = None 

1108 # Find a dependent that is also on critical path and continues the timing 

1109 for dep in current.dependents: 

1110 if ( 

1111 dep in cp_nodes 

1112 and abs(analysis[dep]["ES"] - analysis[current]["EF"]) < 1e-9 

1113 ): 

1114 next_node = dep 

1115 break 

1116 if next_node: 

1117 path.append(next_node) 

1118 current = next_node 

1119 else: 

1120 break 

1121 return path 

1122 

1123 def all_paths(self, source: T, target: T) -> list[list[T]]: 

1124 """ 

1125 Find all possible paths between two nodes. 

1126 

1127 Args: 

1128 source (T): Starting node. 

1129 target (T): Ending node. 

1130 

1131 Returns: 

1132 list[list[T]]: A list of all paths, where each path is a list of nodes. 

1133 """ 

1134 

1135 def find_all_paths(current: T, goal: T, path: list[T]) -> list[list[T]]: 

1136 path = path + [current] 

1137 if current == goal: 

1138 return [path] 

1139 paths = [] 

1140 for neighbor in current.dependents: 

1141 if neighbor in self._nodes: 

1142 new_paths = find_all_paths(neighbor, goal, path) 

1143 for p in new_paths: 

1144 paths.append(p) 

1145 return paths 

1146 

1147 return find_all_paths(source, target, []) 

1148 

1149 def diff(self, other: Graph[T]) -> dict[str, Any]: 

1150 """ 

1151 Compare this graph with another graph. 

1152 

1153 Returns: 

1154 dict[str, Any]: A dictionary containing differences: 

1155 - 'added_nodes': references of nodes in other but not in self. 

1156 - 'removed_nodes': references of nodes in self but not in other. 

1157 - 'modified_nodes': references of nodes in both but with different properties. 

1158 - 'added_edges': (u, v) tuples of edges in other but not in self. 

1159 - 'removed_edges': (u, v) tuples of edges in self but not in other. 

1160 - 'modified_edges': (u, v) tuples of edges in both but with different attributes. 

1161 """ 

1162 self_refs = {node.reference for node in self._nodes} 

1163 other_refs = {node.reference for node in other._nodes} 

1164 

1165 added_nodes = other_refs - self_refs 

1166 removed_nodes = self_refs - other_refs 

1167 

1168 modified_nodes = set() 

1169 for ref in self_refs & other_refs: 

1170 n1 = self[ref] 

1171 n2 = other[ref] 

1172 if ( 

1173 n1.tags != n2.tags 

1174 or n1.duration != n2.duration 

1175 or n1.status != n2.status 

1176 ): 

1177 modified_nodes.add(ref) 

1178 

1179 def get_edges(g: Graph[T]): 

1180 edges = {} 

1181 for u in g._nodes: 

1182 for v in u.dependents: 

1183 if v in g._nodes: 

1184 edges[(u.reference, v.reference)] = u.edge_attributes(v) 

1185 return edges 

1186 

1187 self_edges = get_edges(self) 

1188 other_edges = get_edges(other) 

1189 

1190 self_edge_set = set(self_edges.keys()) 

1191 other_edge_set = set(other_edges.keys()) 

1192 

1193 added_edges = other_edge_set - self_edge_set 

1194 removed_edges = self_edge_set - other_edge_set 

1195 modified_edges = set() 

1196 

1197 for edge in self_edge_set & other_edge_set: 

1198 if self_edges[edge] != other_edges[edge]: 

1199 modified_edges.add(edge) 

1200 

1201 return { 

1202 "added_nodes": added_nodes, 

1203 "removed_nodes": removed_nodes, 

1204 "modified_nodes": modified_nodes, 

1205 "added_edges": added_edges, 

1206 "removed_edges": removed_edges, 

1207 "modified_edges": modified_edges, 

1208 } 

1209 

1210 def topological_order(self) -> list[T]: 

1211 """ 

1212 Get the nodes in topological order. 

1213 Only nodes that are members of this graph are included. 

1214 

1215 Returns: 

1216 list[T]: A list of member nodes sorted topologically. 

1217 """ 

1218 if self._topological_order is None: 

1219 logger.debug("Calculating topological order.") 

1220 sorter = TopologicalSorter({node: node.depends_on for node in self._nodes}) 

1221 # Filter the static order to only include nodes that are in this graph 

1222 self._topological_order = [ 

1223 node for node in sorter.static_order() if node in self._nodes 

1224 ] 

1225 

1226 return self._topological_order 

1227 

1228 def topological_order_filtered(self, fn: Callable[[T], bool]) -> list[T]: 

1229 """ 

1230 Get a filtered list of nodes in topological order. 

1231 

1232 Args: 

1233 fn (Callable[[T], bool]): The predicate function. 

1234 

1235 Returns: 

1236 list[T]: Filtered topologically sorted nodes. 

1237 """ 

1238 return [node for node in self.topological_order() if fn(node)] 

1239 

1240 def topological_order_tagged(self, tag: str) -> list[T]: 

1241 """ 

1242 Get a list of nodes with a specific tag in topological order. 

1243 

1244 Args: 

1245 tag (str): The tag to filter by. 

1246 

1247 Returns: 

1248 list[T]: Tagged topologically sorted nodes. 

1249 """ 

1250 return [node for node in self.topological_order() if node.is_tagged(tag)] 

1251 

    def to_networkx(self):
        """
        Convert this graph to a networkx.DiGraph.
        Requires 'networkx' to be installed.

        Returns:
            networkx.DiGraph: The converted directed graph.
        """
        # Imported lazily so networkx stays an optional dependency: the
        # import only happens when this method is actually called.
        from .views.networkx import to_networkx

        return to_networkx(self)

1263 

    def transitive_reduction(self) -> Graph[T]:
        """
        Compute the transitive reduction of this DAG.
        A transitive reduction of a directed acyclic graph G is a graph G' with the same nodes
        and the same reachability as G, but with as few edges as possible.

        Returns:
            Graph[T]: A new Graph instance containing the same nodes (cloned) but with redundant edges removed.
        """
        import copy

        logger.debug("Calculating transitive reduction.")

        # 1. Clone nodes without edges to avoid modifying the original graph.
        # copy.copy is shallow, so mutable internals are reset explicitly below.
        node_map: dict[T, T] = {}
        for node in self._nodes:
            new_node = copy.copy(node)
            # Reset internal edge tracking
            # NOTE(review): this reaches into Graphable's private attributes
            # (_dependents / _depends_on / _tags); keep in sync with
            # Graphable's implementation if its internals change.
            new_node._dependents = {}
            new_node._depends_on = {}
            # Manually clone tags to avoid shared state
            new_node._tags = set(node.tags)
            node_map[node] = new_node

        # 2. Identify redundant edges.
        # An edge (u, v) is redundant if there exists a path from u to v of length > 1.
        redundant_edges: set[tuple[T, T]] = set()
        for u in self._nodes:
            for v in u.dependents:
                # Check if v is reachable from u through any other neighbor w.
                if any(w.find_path(v) for w in u.dependents if w != v):
                    redundant_edges.add((u, v))

        # 3. Construct the new graph with non-redundant edges.
        new_graph = Graph(set(node_map.values()))
        for u in self._nodes:
            for v in u.dependents:
                if (u, v) not in redundant_edges:
                    # Preserve edge attributes
                    attrs = u.edge_attributes(v)
                    new_graph.add_edge(node_map[u], node_map[v], **attrs)

        logger.info(
            f"Transitive reduction complete. Removed {len(redundant_edges)} redundant edges."
        )
        return new_graph

1310 

1311 def render( 

1312 self, 

1313 view_fnc: Callable[..., str], 

1314 transitive_reduction: bool = False, 

1315 **kwargs: Any, 

1316 ) -> str: 

1317 """ 

1318 Render the graph using a view function. 

1319 

1320 Args: 

1321 view_fnc: The view function to use (e.g., create_topology_mermaid_mmd). 

1322 transitive_reduction: If True, render the transitive reduction of the graph. 

1323 **kwargs: Additional arguments passed to the view function. 

1324 

1325 Returns: 

1326 str: The rendered representation. 

1327 """ 

1328 target = self.transitive_reduction() if transitive_reduction else self 

1329 return view_fnc(target, **kwargs) 

1330 

1331 def export( 

1332 self, 

1333 export_fnc: Callable[..., None], 

1334 output: Path | str, 

1335 transitive_reduction: bool = False, 

1336 embed_checksum: bool = False, 

1337 **kwargs: Any, 

1338 ) -> None: 

1339 """ 

1340 Export the graph using an export function. 

1341 

1342 Args: 

1343 export_fnc: The export function to use (e.g., export_topology_graphviz_svg). 

1344 output: The output file path. 

1345 transitive_reduction: If True, export the transitive reduction of the graph. 

1346 embed_checksum: If True, embed the graph's checksum as a comment at the top. 

1347 **kwargs: Additional arguments passed to the export function. 

1348 """ 

1349 from pathlib import Path 

1350 

1351 from .registry import CREATOR_MAP 

1352 from .views.utils import wrap_with_checksum 

1353 

1354 p = Path(output) 

1355 target = self.transitive_reduction() if transitive_reduction else self 

1356 

1357 if not embed_checksum: 

1358 return export_fnc(target, p, **kwargs) 

1359 

1360 # To embed checksum, we need to capture the output string first. 

1361 create_fnc = CREATOR_MAP.get(export_fnc) 

1362 

1363 if not create_fnc: 

1364 # Fallback: export normally if we can't find a string-generating version 

1365 export_name = getattr(export_fnc, "__name__", str(export_fnc)) 

1366 logger.warning( 

1367 f"Could not find string-generating version of {export_name}. Exporting normally without checksum embedding." 

1368 ) 

1369 return export_fnc(target, p, **kwargs) 

1370 

1371 content = create_fnc(target, **kwargs) 

1372 checksum = target.checksum() 

1373 wrapped = wrap_with_checksum(content, checksum, p.suffix) 

1374 

1375 with open(p, "w+") as f: 

1376 f.write(wrapped)