Skip to content

Commit 0c64e8d

Browse files
authored
Merge pull request #611 from FalkorDB/async-endpoints
Migrate FastAPI endpoints from sync to async
2 parents d4daf00 + 24eb501 commit 0c64e8d

9 files changed

Lines changed: 521 additions & 104 deletions

File tree

api/auto_complete.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
1-
from .graph import Graph
1+
from .graph import Graph, AsyncGraphQuery
2+
23

34
def prefix_search(repo: str, prefix: str) -> list:
    """Return all entities in *repo* whose name starts with *prefix*.

    Thin synchronous wrapper: opens a Graph for the repository and
    delegates to its prefix_search.

    Note: the return annotation was `str`, contradicting both the
    docstring ("a list of all entities") and the async twin
    `async_prefix_search -> list`; corrected to `list`.
    """
    g = Graph(repo)
    return g.prefix_search(prefix)
8+
9+
10+
async def async_prefix_search(repo: str, prefix: str) -> list:
    """Prefix-search *repo* without blocking the event loop.

    Opens a short-lived AsyncGraphQuery connection, delegates to its
    prefix_search, and always releases the connection afterwards.
    """
    query = AsyncGraphQuery(repo)
    try:
        matches = await query.prefix_search(prefix)
    finally:
        await query.close()
    return matches

api/git_utils/git_graph.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import logging
33
from falkordb import FalkorDB, Node
4+
from falkordb.asyncio import FalkorDB as AsyncFalkorDB
45
from typing import List, Optional
56

67
from pygit2 import Commit
@@ -176,3 +177,32 @@ def get_child_transitions(self, child: str, parent: str) -> tuple[list[str], lis
176177

177178
return (res[0][0], res[0][1])
178179

180+
181+
class AsyncGitGraph:
    """Async read-only git graph for endpoint use."""

    def __init__(self, name: str):
        # Connection settings come from the environment, matching the
        # sync configuration used elsewhere in the project.
        host = os.getenv('FALKORDB_HOST', 'localhost')
        port = int(os.getenv('FALKORDB_PORT', 6379))
        username = os.getenv('FALKORDB_USERNAME', None)
        password = os.getenv('FALKORDB_PASSWORD', None)

        self.db = AsyncFalkorDB(
            host=host,
            port=port,
            username=username,
            password=password,
        )
        self.g = self.db.select_graph(name)

    def _commit_from_node(self, node: Node) -> dict:
        """Project a Commit node onto the plain dict shape the API returns."""
        props = node.properties
        # Only these four commit properties are exposed to callers.
        return {key: props[key] for key in ('hash', 'date', 'author', 'message')}

    async def list_commits(self) -> List[dict]:
        """Return every commit as a dict, ordered by the node's `date` property."""
        q = "MATCH (c:Commit) RETURN c ORDER BY c.date"
        result = await self.g.query(q)
        return [self._commit_from_node(row[0]) for row in result.result_set]

    async def close(self) -> None:
        """Release the underlying async FalkorDB connection."""
        await self.db.aclose()
208+

api/graph.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .entities import *
44
from typing import Optional
55
from falkordb import FalkorDB, Path, Node, QueryResult
6+
from falkordb.asyncio import FalkorDB as AsyncFalkorDB
67

78
# Configure the logger
89
import logging
@@ -627,3 +628,144 @@ def unreachable_entities(self, lbl: Optional[str], rel: Optional[str]) -> list[d
627628

628629
return unreachables
629630

631+
632+
# ---------------------------------------------------------------------------
633+
# Async helpers and read-only async graph wrapper
634+
# ---------------------------------------------------------------------------
635+
636+
def _async_db() -> AsyncFalkorDB:
    """Create an async FalkorDB connection using environment config."""
    # Same environment variables the sync Graph connection uses.
    settings = {
        'host': os.getenv('FALKORDB_HOST', 'localhost'),
        'port': int(os.getenv('FALKORDB_PORT', 6379)),
        'username': os.getenv('FALKORDB_USERNAME', None),
        'password': os.getenv('FALKORDB_PASSWORD', None),
    }
    return AsyncFalkorDB(**settings)
644+
645+
646+
async def async_graph_exists(name: str) -> bool:
    """Return True when a graph called *name* exists on the server.

    Uses a short-lived async connection that is closed in all cases.
    """
    db = _async_db()
    try:
        return name in await db.list_graphs()
    finally:
        await db.aclose()
653+
654+
655+
async def async_get_repos() -> list[str]:
    """List processed repositories (async version).

    Filters out the internal companion graphs, which carry a
    '_git' or '_schema' suffix.
    """
    db = _async_db()
    try:
        all_graphs = await db.list_graphs()
        return [name for name in all_graphs
                if not name.endswith(('_git', '_schema'))]
    finally:
        await db.aclose()
663+
664+
665+
class AsyncGraphQuery:
    """Read-only async wrapper for endpoint use.

    Backed by falkordb.asyncio. Performs no index creation and keeps no
    write backlog — the indexes were already built by the sync Graph
    during analysis, so this class only issues read queries.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.db = _async_db()
        self.g = self.db.select_graph(name)

    async def graph_exists(self) -> bool:
        """Check if this graph exists, reusing the current connection."""
        return self.name in await self.db.list_graphs()

    async def _query(self, q: str, params: Optional[dict] = None):
        """Run *q* against the wrapped graph and return the raw QueryResult."""
        return await self.g.query(q, params)

    async def get_sub_graph(self, limit: int) -> dict:
        """Fetch up to *limit* rows of nodes plus their outgoing edges."""
        q = """MATCH (src)
               OPTIONAL MATCH (src)-[e]->(dest)
               RETURN src, e, dest
               LIMIT $limit"""

        sub_graph = {'nodes': [], 'edges': []}
        rows = (await self._query(q, {'limit': limit})).result_set
        for src, e, dest in rows:
            sub_graph['nodes'].append(encode_node(src))
            # OPTIONAL MATCH may yield no edge; only then is dest skipped too.
            if e is not None:
                sub_graph['edges'].append(encode_edge(e))
                sub_graph['nodes'].append(encode_node(dest))
        return sub_graph

    async def get_neighbors(self, node_ids: list[int], rel: Optional[str] = None, lbl: Optional[str] = None) -> dict:
        """Return encoded outgoing neighbors (optionally filtered by
        relationship type *rel* and destination label *lbl*) of *node_ids*.

        Raises ValueError for non-integer ids; query failures are logged
        and reported as an empty result rather than propagated.
        """
        if any(not isinstance(nid, int) for nid in node_ids):
            raise ValueError("node_ids must be an integer list")

        # Empty/None filters collapse to no label constraint at all.
        rel_query = f":{rel}" if rel else ""
        lbl_query = f":{lbl}" if lbl else ""

        query = f"""
            MATCH (n)-[e{rel_query}]->(dest{lbl_query})
            WHERE ID(n) IN $node_ids
            RETURN e, dest
        """

        collected = {'nodes': [], 'edges': []}
        try:
            rows = (await self._query(query, {'node_ids': node_ids})).result_set
            for edge, destination_node in rows:
                collected['nodes'].append(encode_node(destination_node))
                collected['edges'].append(encode_edge(edge))
            return collected
        except Exception as e:
            logging.error(f"Error fetching neighbors for node {node_ids}: {e}")
            return {'nodes': [], 'edges': []}

    async def prefix_search(self, prefix: str) -> list:
        """Full-text prefix search over the 'Searchable' index (top 10)."""
        query = """
            CALL db.idx.fulltext.queryNodes('Searchable', $prefix)
            YIELD node
            WITH node
            RETURN node
            LIMIT 10
        """
        # Trailing '*' turns the term into a prefix query for the FT index.
        rows = (await self._query(query, {'prefix': f"{prefix}*"})).result_set
        return [encode_node(row[0]) for row in rows]

    async def find_paths(self, src: int, dest: int) -> list:
        """Return all CALLS-paths from node id *src* to node id *dest*.

        Each path is a flat [node, edge, node, edge, ..., node] list of
        encoded elements.
        """
        q = """MATCH (src), (dest)
               WHERE ID(src) = $src_id AND ID(dest) = $dest_id
               WITH src, dest
               MATCH p = (src)-[:CALLS*]->(dest)
               RETURN p
        """
        rows = (await self._query(q, {'src_id': src, 'dest_id': dest})).result_set

        paths = []
        for (p,) in rows:
            nodes = p.nodes()
            edges = p.edges()
            # Interleave node/edge pairs, then close with the final node.
            encoded = [item
                       for n, e in zip(nodes, edges)
                       for item in (encode_node(n), encode_edge(e))]
            encoded.append(encode_node(nodes[-1]))
            paths.append(encoded)
        return paths

    async def stats(self) -> dict:
        """Return total node and edge counts for the graph."""
        node_count = (await self._query("MATCH (n) RETURN count(n)")).result_set[0][0]
        edge_count = (await self._query("MATCH ()-[e]->() RETURN count(e)")).result_set[0][0]
        return {'node_count': node_count, 'edge_count': edge_count}

    async def close(self) -> None:
        """Release the underlying async FalkorDB connection."""
        await self.db.aclose()
771+

0 commit comments

Comments
 (0)