Skip to content

Commit a1c75ee

Browse files
authored
Merge pull request #29 from structured-world/feat/#28-schema-helpers-traverse
feat(client): add get_labels, get_edge_types, traverse
2 parents 20c9ec0 + ae37106 commit a1c75ee

4 files changed

Lines changed: 524 additions & 1 deletion

File tree

coordinode/coordinode/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
AsyncCoordinodeClient,
2323
CoordinodeClient,
2424
EdgeResult,
25+
EdgeTypeInfo,
26+
LabelInfo,
2527
NodeResult,
28+
PropertyDefinitionInfo,
29+
TraverseResult,
2630
VectorResult,
2731
)
2832

@@ -36,4 +40,8 @@
3640
"NodeResult",
3741
"EdgeResult",
3842
"VectorResult",
43+
"LabelInfo",
44+
"EdgeTypeInfo",
45+
"PropertyDefinitionInfo",
46+
"TraverseResult",
3947
]

coordinode/coordinode/client.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,54 @@ def __repr__(self) -> str:
8585
return f"VectorResult(distance={self.distance:.4f}, node={self.node})"
8686

8787

88+
class PropertyDefinitionInfo:
89+
"""A property definition from the schema (name, type, required, unique)."""
90+
91+
def __init__(self, proto_def: Any) -> None:
92+
self.name: str = proto_def.name
93+
self.type: int = proto_def.type
94+
self.required: bool = proto_def.required
95+
self.unique: bool = proto_def.unique
96+
97+
def __repr__(self) -> str:
98+
return f"PropertyDefinitionInfo(name={self.name!r}, type={self.type}, required={self.required}, unique={self.unique})"
99+
100+
101+
class LabelInfo:
102+
"""A node label returned from the schema registry."""
103+
104+
def __init__(self, proto_label: Any) -> None:
105+
self.name: str = proto_label.name
106+
self.version: int = proto_label.version
107+
self.properties: list[PropertyDefinitionInfo] = [PropertyDefinitionInfo(p) for p in proto_label.properties]
108+
109+
def __repr__(self) -> str:
110+
return f"LabelInfo(name={self.name!r}, version={self.version}, properties={len(self.properties)})"
111+
112+
113+
class EdgeTypeInfo:
114+
"""An edge type returned from the schema registry."""
115+
116+
def __init__(self, proto_edge_type: Any) -> None:
117+
self.name: str = proto_edge_type.name
118+
self.version: int = proto_edge_type.version
119+
self.properties: list[PropertyDefinitionInfo] = [PropertyDefinitionInfo(p) for p in proto_edge_type.properties]
120+
121+
def __repr__(self) -> str:
122+
return f"EdgeTypeInfo(name={self.name!r}, version={self.version}, properties={len(self.properties)})"
123+
124+
125+
class TraverseResult:
126+
"""Result of a graph traversal: reached nodes and traversed edges."""
127+
128+
def __init__(self, proto_response: Any) -> None:
129+
self.nodes: list[NodeResult] = [NodeResult(n) for n in proto_response.nodes]
130+
self.edges: list[EdgeResult] = [EdgeResult(e) for e in proto_response.edges]
131+
132+
def __repr__(self) -> str:
133+
return f"TraverseResult(nodes={len(self.nodes)}, edges={len(self.edges)})"
134+
135+
88136
# ── Async client ─────────────────────────────────────────────────────────────
89137

90138

@@ -303,6 +351,72 @@ async def get_schema_text(self) -> str:
303351

304352
return "\n".join(lines)
305353

354+
async def get_labels(self) -> list[LabelInfo]:
355+
"""Return all node labels defined in the schema."""
356+
from coordinode._proto.coordinode.v1.graph.schema_pb2 import ListLabelsRequest # type: ignore[import]
357+
358+
resp = await self._schema_stub.ListLabels(ListLabelsRequest(), timeout=self._timeout)
359+
return [LabelInfo(label) for label in resp.labels]
360+
361+
async def get_edge_types(self) -> list[EdgeTypeInfo]:
362+
"""Return all edge types defined in the schema."""
363+
from coordinode._proto.coordinode.v1.graph.schema_pb2 import ListEdgeTypesRequest # type: ignore[import]
364+
365+
resp = await self._schema_stub.ListEdgeTypes(ListEdgeTypesRequest(), timeout=self._timeout)
366+
return [EdgeTypeInfo(et) for et in resp.edge_types]
367+
368+
async def traverse(
369+
self,
370+
start_node_id: int,
371+
edge_type: str,
372+
direction: str = "outbound",
373+
max_depth: int = 1,
374+
) -> TraverseResult:
375+
"""Traverse the graph from *start_node_id* following *edge_type* edges.
376+
377+
Args:
378+
start_node_id: ID of the node to start from.
379+
edge_type: Edge type label to follow (e.g. ``"KNOWS"``).
380+
direction: ``"outbound"`` (default), ``"inbound"``, or ``"both"``.
381+
max_depth: Maximum hop count (default 1).
382+
383+
Returns:
384+
:class:`TraverseResult` with ``nodes`` and ``edges`` lists.
385+
"""
386+
# Validate pure string/int inputs before importing proto stubs — ensures ValueError
387+
# is raised even when proto stubs have not been generated yet.
388+
# Type guards come first so that wrong types raise ValueError, not AttributeError/TypeError.
389+
if not isinstance(direction, str):
390+
raise ValueError(f"direction must be a str, got {type(direction).__name__!r}.")
391+
_valid_directions = {"outbound", "inbound", "both"}
392+
key = direction.lower()
393+
if key not in _valid_directions:
394+
raise ValueError(f"Invalid direction {direction!r}. Must be one of: 'outbound', 'inbound', 'both'.")
395+
# bool is a subclass of int in Python, so `isinstance(True, int)` is True — exclude it.
396+
if not isinstance(max_depth, int) or isinstance(max_depth, bool) or max_depth < 1:
397+
raise ValueError(f"max_depth must be an integer >= 1, got {max_depth!r}.")
398+
399+
from coordinode._proto.coordinode.v1.graph.graph_pb2 import ( # type: ignore[import]
400+
TraversalDirection,
401+
TraverseRequest,
402+
)
403+
404+
_direction_map = {
405+
"outbound": TraversalDirection.TRAVERSAL_DIRECTION_OUTBOUND,
406+
"inbound": TraversalDirection.TRAVERSAL_DIRECTION_INBOUND,
407+
"both": TraversalDirection.TRAVERSAL_DIRECTION_BOTH,
408+
}
409+
direction_value = _direction_map[key]
410+
411+
req = TraverseRequest(
412+
start_node_id=start_node_id,
413+
edge_type=edge_type,
414+
direction=direction_value,
415+
max_depth=max_depth,
416+
)
417+
resp = await self._graph_stub.Traverse(req, timeout=self._timeout)
418+
return TraverseResult(resp)
419+
306420
async def health(self) -> bool:
307421
from coordinode._proto.coordinode.v1.health.health_pb2 import ( # type: ignore[import]
308422
HealthCheckRequest,
@@ -422,6 +536,24 @@ def create_edge(
422536
def get_schema_text(self) -> str:
423537
return self._run(self._async.get_schema_text())
424538

539+
def get_labels(self) -> list[LabelInfo]:
540+
"""Return all node labels defined in the schema."""
541+
return self._run(self._async.get_labels())
542+
543+
def get_edge_types(self) -> list[EdgeTypeInfo]:
544+
"""Return all edge types defined in the schema."""
545+
return self._run(self._async.get_edge_types())
546+
547+
def traverse(
548+
self,
549+
start_node_id: int,
550+
edge_type: str,
551+
direction: str = "outbound",
552+
max_depth: int = 1,
553+
) -> TraverseResult:
554+
"""Traverse the graph from *start_node_id* following *edge_type* edges."""
555+
return self._run(self._async.traverse(start_node_id, edge_type, direction, max_depth))
556+
425557
def health(self) -> bool:
426558
return self._run(self._async.health())
427559

tests/integration/test_sdk.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import pytest
1414

15-
from coordinode import AsyncCoordinodeClient, CoordinodeClient
15+
from coordinode import AsyncCoordinodeClient, CoordinodeClient, EdgeTypeInfo, LabelInfo, TraverseResult
1616

1717
ADDR = os.environ.get("COORDINODE_ADDR", "localhost:7080")
1818

@@ -208,6 +208,132 @@ def test_get_schema_text(client):
208208
client.cypher("MATCH (n:SchemaTestLabel {tag: $tag}) DETACH DELETE n", params={"tag": tag})
209209

210210

211+
# ── get_labels / get_edge_types / traverse ────────────────────────────────────
212+
213+
214+
def test_get_labels_returns_list(client):
215+
"""get_labels() returns a non-empty list of LabelInfo after data is present."""
216+
tag = uid()
217+
label_name = f"GetLabelsTest{uid()}"
218+
client.cypher(f"CREATE (n:{label_name} {{tag: $tag}})", params={"tag": tag})
219+
try:
220+
labels = client.get_labels()
221+
assert isinstance(labels, list)
222+
assert len(labels) > 0
223+
assert all(isinstance(lbl, LabelInfo) for lbl in labels)
224+
names = [lbl.name for lbl in labels]
225+
assert label_name in names, f"{label_name} not in {names}"
226+
finally:
227+
client.cypher(f"MATCH (n:{label_name} {{tag: $tag}}) DETACH DELETE n", params={"tag": tag})
228+
229+
230+
def test_get_labels_has_property_definitions(client):
231+
"""LabelInfo.properties is a list (may be empty for schema-free labels)."""
232+
tag = uid()
233+
label_name = f"PropLabel{uid()}"
234+
client.cypher(f"CREATE (n:{label_name} {{tag: $tag}})", params={"tag": tag})
235+
try:
236+
labels = client.get_labels()
237+
found = next((lbl for lbl in labels if lbl.name == label_name), None)
238+
assert found is not None, f"{label_name} not returned by get_labels()"
239+
# Intentionally only check the type — CoordiNode is schema-free and may return
240+
# an empty properties list even when the node was created with properties.
241+
assert isinstance(found.properties, list)
242+
finally:
243+
client.cypher(f"MATCH (n:{label_name} {{tag: $tag}}) DETACH DELETE n", params={"tag": tag})
244+
245+
246+
def test_get_edge_types_returns_list(client):
247+
"""get_edge_types() returns a non-empty list of EdgeTypeInfo after data is present."""
248+
tag = uid()
249+
edge_type = f"GET_EDGE_TYPE_TEST_{uid()}".upper()
250+
client.cypher(
251+
f"CREATE (a:EdgeTypeTestNode {{tag: $tag}})-[:{edge_type}]->(b:EdgeTypeTestNode {{tag: $tag}})",
252+
params={"tag": tag},
253+
)
254+
try:
255+
edge_types = client.get_edge_types()
256+
assert isinstance(edge_types, list)
257+
assert len(edge_types) > 0
258+
assert all(isinstance(et, EdgeTypeInfo) for et in edge_types)
259+
type_names = [et.name for et in edge_types]
260+
assert edge_type in type_names, f"{edge_type} not in {type_names}"
261+
finally:
262+
client.cypher("MATCH (n:EdgeTypeTestNode {tag: $tag}) DETACH DELETE n", params={"tag": tag})
263+
264+
265+
def test_traverse_returns_neighbours(client):
266+
"""traverse() returns adjacent nodes reachable via the given edge type."""
267+
tag = uid()
268+
client.cypher(
269+
"CREATE (a:TraverseRPC {tag: $tag, role: 'hub'})-[:TRAVERSE_TEST]->(b:TraverseRPC {tag: $tag, role: 'leaf1'})",
270+
params={"tag": tag},
271+
)
272+
try:
273+
rows = client.cypher(
274+
"MATCH (a:TraverseRPC {tag: $tag, role: 'hub'}) RETURN a AS node_id",
275+
params={"tag": tag},
276+
)
277+
assert len(rows) >= 1, "hub node not found"
278+
start_id = rows[0]["node_id"]
279+
280+
# Fetch the leaf1 node ID so we can assert it specifically appears in the result.
281+
leaf_rows = client.cypher(
282+
"MATCH (b:TraverseRPC {tag: $tag, role: 'leaf1'}) RETURN b AS node_id",
283+
params={"tag": tag},
284+
)
285+
assert len(leaf_rows) >= 1, "leaf1 node not found"
286+
leaf1_id = leaf_rows[0]["node_id"]
287+
288+
result = client.traverse(start_id, "TRAVERSE_TEST", direction="outbound", max_depth=1)
289+
assert isinstance(result, TraverseResult)
290+
assert len(result.nodes) >= 1, "traverse() returned no neighbour nodes"
291+
node_ids = {n.id for n in result.nodes}
292+
assert leaf1_id in node_ids, f"traverse() did not return the expected leaf1 node ({leaf1_id}); got: {node_ids}"
293+
finally:
294+
client.cypher("MATCH (n:TraverseRPC {tag: $tag}) DETACH DELETE n", params={"tag": tag})
295+
296+
297+
@pytest.mark.xfail(
298+
strict=False,
299+
raises=AssertionError,
300+
# strict=False: XPASS is good news (server gained inbound support), not an error.
301+
# strict=True would break CI exactly when the server improves, which is undesirable.
302+
# The XPASS report in pytest output is the signal to remove this marker.
303+
# raises=AssertionError: narrows xfail to the known failure mode (empty result set →
304+
# assertion fails). Unexpected errors (gRPC RpcError, wrong enum, etc.) are NOT covered
305+
# and will still propagate as CI failures.
306+
reason="CoordiNode Traverse RPC does not yet support inbound direction — server returns empty result set",
307+
)
308+
def test_traverse_inbound_direction(client):
309+
"""traverse() with direction='inbound' reaches nodes that point TO start_id."""
310+
tag = uid()
311+
client.cypher(
312+
"CREATE (src:TraverseIn {tag: $tag})-[:INBOUND_TEST]->(dst:TraverseIn {tag: $tag})",
313+
params={"tag": tag},
314+
)
315+
try:
316+
# Capture both src and dst so that when the server gains inbound support
317+
# (XPASS), the assertion verifies the *correct* node was returned, not just any node.
318+
rows = client.cypher(
319+
"MATCH (src:TraverseIn {tag: $tag})-[:INBOUND_TEST]->(dst:TraverseIn {tag: $tag}) "
320+
"RETURN src AS src_id, dst AS dst_id",
321+
params={"tag": tag},
322+
)
323+
assert len(rows) >= 1
324+
src_id = rows[0]["src_id"]
325+
dst_id = rows[0]["dst_id"]
326+
result = client.traverse(dst_id, "INBOUND_TEST", direction="inbound", max_depth=1)
327+
assert isinstance(result, TraverseResult)
328+
assert len(result.nodes) >= 1, "inbound traverse returned no nodes"
329+
node_ids = {n.id for n in result.nodes}
330+
assert src_id in node_ids, (
331+
f"inbound traverse did not return the expected source node ({src_id}); got: {node_ids}"
332+
)
333+
finally:
334+
client.cypher("MATCH (n:TraverseIn {tag: $tag}) DETACH DELETE n", params={"tag": tag})
335+
336+
211337
# ── Hybrid search ─────────────────────────────────────────────────────────────
212338

213339

0 commit comments

Comments
 (0)