|
3 | 3 | from .entities import * |
4 | 4 | from typing import Optional |
5 | 5 | from falkordb import FalkorDB, Path, Node, QueryResult |
| 6 | +from falkordb.asyncio import FalkorDB as AsyncFalkorDB |
6 | 7 |
|
7 | 8 | # Configure the logger |
8 | 9 | import logging |
@@ -627,3 +628,144 @@ def unreachable_entities(self, lbl: Optional[str], rel: Optional[str]) -> list[d |
627 | 628 |
|
628 | 629 | return unreachables |
629 | 630 |
|
| 631 | + |
| 632 | +# --------------------------------------------------------------------------- |
| 633 | +# Async helpers and read-only async graph wrapper |
| 634 | +# --------------------------------------------------------------------------- |
| 635 | + |
def _async_db() -> AsyncFalkorDB:
    """Open an async FalkorDB connection configured from the environment.

    Host and port default to localhost:6379; username and password
    default to None (unauthenticated) when the variables are unset.
    """
    config = {
        'host': os.getenv('FALKORDB_HOST', 'localhost'),
        'port': int(os.getenv('FALKORDB_PORT', 6379)),
        'username': os.getenv('FALKORDB_USERNAME'),
        'password': os.getenv('FALKORDB_PASSWORD'),
    }
    return AsyncFalkorDB(**config)
| 644 | + |
| 645 | + |
async def async_graph_exists(name: str) -> bool:
    """Return True if a graph called *name* exists on the server.

    Opens a fresh async connection and always closes it, even if the
    listing call raises.
    """
    conn = _async_db()
    try:
        return name in await conn.list_graphs()
    finally:
        await conn.aclose()
| 653 | + |
| 654 | + |
async def async_get_repos() -> list[str]:
    """List processed repositories (async version).

    Filters out the auxiliary per-repo graphs (the '_git' history graph
    and the '_schema' graph) so only primary repository graphs remain.

    Returns:
        Graph names of processed repositories.
    """
    conn = _async_db()
    try:
        graphs = await conn.list_graphs()
        # str.endswith accepts a tuple of suffixes — one call covers both.
        return [g for g in graphs if not g.endswith(('_git', '_schema'))]
    finally:
        await conn.aclose()
| 663 | + |
| 664 | + |
class AsyncGraphQuery:
    """Read-only async wrapper for endpoint use.

    Uses falkordb.asyncio under the hood. No index creation or backlog —
    indexes already exist from the sync Graph used during analysis.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.db = _async_db()
        self.g = self.db.select_graph(name)

    async def graph_exists(self) -> bool:
        """Check if this graph exists, reusing the current connection."""
        graphs = await self.db.list_graphs()
        return self.name in graphs

    async def _query(self, q: str, params: Optional[dict] = None):
        """Run a single query against the selected graph and return its
        QueryResult."""
        return await self.g.query(q, params)

    async def get_sub_graph(self, limit: int) -> dict:
        """Return up to `limit` rows of the graph as {'nodes': [...], 'edges': [...]}.

        NOTE: a node may appear multiple times in 'nodes' (once per matched
        row); callers are expected to de-duplicate by id if they need a set.
        """
        q = """MATCH (src)
               OPTIONAL MATCH (src)-[e]->(dest)
               RETURN src, e, dest
               LIMIT $limit"""

        sub_graph = {'nodes': [], 'edges': []}
        result_set = (await self._query(q, {'limit': limit})).result_set
        for src, e, dest in result_set:
            sub_graph['nodes'].append(encode_node(src))
            # dest/e are None for nodes with no outgoing edge (OPTIONAL MATCH).
            if e is not None:
                sub_graph['edges'].append(encode_edge(e))
                sub_graph['nodes'].append(encode_node(dest))
        return sub_graph

    async def get_neighbors(self, node_ids: list[int], rel: Optional[str] = None, lbl: Optional[str] = None) -> dict:
        """Fetch outgoing neighbors of `node_ids`, optionally restricted to
        relationship type `rel` and destination label `lbl`.

        Args:
            node_ids: internal node ids to expand from.
            rel: optional relationship-type filter.
            lbl: optional destination-label filter.

        Returns:
            {'nodes': [...], 'edges': [...]} — empty lists on query failure
            (best-effort; the error is logged).

        Raises:
            ValueError: if node_ids contains non-ints, or rel/lbl are not
                identifier-like (they are spliced into the query text).
        """
        if not all(isinstance(node_id, int) for node_id in node_ids):
            raise ValueError("node_ids must be an integer list")

        # rel/lbl are interpolated into the query string — Cypher cannot
        # parameterize labels or relationship types. Restrict them to
        # alphanumeric/underscore tokens to prevent query injection.
        for token in (rel, lbl):
            if token and not all(c.isalnum() or c == '_' for c in token):
                raise ValueError(f"invalid label/relationship filter: {token!r}")

        rel_query = f":{rel}" if rel else ""
        lbl_query = f":{lbl}" if lbl else ""

        query = f"""
            MATCH (n)-[e{rel_query}]->(dest{lbl_query})
            WHERE ID(n) IN $node_ids
            RETURN e, dest
        """

        neighbors = {'nodes': [], 'edges': []}
        try:
            result_set = (await self._query(query, {'node_ids': node_ids})).result_set
            for edge, destination_node in result_set:
                neighbors['nodes'].append(encode_node(destination_node))
                neighbors['edges'].append(encode_edge(edge))
            return neighbors
        except Exception as e:
            # Best-effort endpoint helper: log and return empty rather than
            # surfacing DB errors to the caller. Lazy %-args avoid formatting
            # the message when the log level filters it out.
            logging.error("Error fetching neighbors for node %s: %s", node_ids, e)
            return {'nodes': [], 'edges': []}

    async def prefix_search(self, prefix: str) -> list:
        """Full-text prefix search over 'Searchable' nodes (max 10 results)."""
        # The trailing '*' makes the fulltext engine treat the term as a prefix.
        search_prefix = f"{prefix}*"
        query = """
            CALL db.idx.fulltext.queryNodes('Searchable', $prefix)
            YIELD node
            WITH node
            RETURN node
            LIMIT 10
        """
        result_set = (await self._query(query, {'prefix': search_prefix})).result_set
        return [encode_node(row[0]) for row in result_set]

    async def find_paths(self, src: int, dest: int) -> list:
        """Find all CALLS paths from node id `src` to node id `dest`.

        Returns a list of paths, each an interleaved [node, edge, node, ...]
        list of encoded elements.
        """
        q = """MATCH (src), (dest)
               WHERE ID(src) = $src_id AND ID(dest) = $dest_id
               WITH src, dest
               MATCH p = (src)-[:CALLS*]->(dest)
               RETURN p
            """
        result_set = (await self._query(q, {'src_id': src, 'dest_id': dest})).result_set
        paths = []
        for row in result_set:
            path = []
            p = row[0]
            nodes = p.nodes()
            edges = p.edges()
            # A path of k edges has k+1 nodes; zip pairs each node with its
            # outgoing edge, then the final node is appended separately.
            for n, e in zip(nodes, edges):
                path.append(encode_node(n))
                path.append(encode_edge(e))
            path.append(encode_node(nodes[-1]))
            paths.append(path)
        return paths

    async def stats(self) -> dict:
        """Return {'node_count': int, 'edge_count': int} for the graph."""
        q = "MATCH (n) RETURN count(n)"
        node_count = (await self._query(q)).result_set[0][0]

        q = "MATCH ()-[e]->() RETURN count(e)"
        edge_count = (await self._query(q)).result_set[0][0]

        return {'node_count': node_count, 'edge_count': edge_count}

    async def close(self) -> None:
        """Close the underlying async connection."""
        await self.db.aclose()
| 771 | + |
0 commit comments