Skip to content

Commit 157f0b0

Browse files
fix: reload graph before quizzing
1 parent ee0639d commit 157f0b0

7 files changed

Lines changed: 93 additions & 3 deletions

File tree

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
from .json_storage import JsonKVStorage, JsonListStorage
2-
from .networkx_storage import NetworkXStorage
1+
from graphgen.models.storage.graph.networkx_storage import NetworkXStorage
2+
from graphgen.models.storage.kv.json_storage import JsonKVStorage
3+
34
from .rocksdb_cache import RocksDBCache

graphgen/models/storage/graph/__init__.py

Whitespace-only changes.

graphgen/models/storage/graph/networkx_storage.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,9 @@ def clear(self):
170170
"""
171171
self._graph.clear()
172172
logger.info("Graph %s cleared.", self.namespace)
173+
174+
def reload(self):
    """
    Reload the graph from the GraphML file.

    This is implemented by re-running ``__post_init__`` on the existing
    instance, so whatever state that hook initialises is initialised again.

    NOTE(review): ``__post_init__`` is defined outside this view; presumably
    it reads the GraphML file into ``self._graph`` -- confirm, and confirm
    that it is safe to run more than once on the same instance.
    """
    self.__post_init__()

graphgen/models/storage/kv/__init__.py

Whitespace-only changes.

graphgen/models/storage/kv/json_storage.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,7 @@ def upsert(self, data: dict):
5454
def drop(self):
    """Empty the in-memory data in place; do nothing if it is already empty."""
    if not self._data:
        return
    self._data.clear()
57+
58+
def reload(self):
    """Replace the in-memory data with the contents of the backing JSON file.

    A falsy result from ``load_json`` is treated as an empty store.
    """
    loaded = load_json(self._file_name)
    self._data = loaded if loaded else {}
    entry_count = len(self._data)
    logger.info("Reload KV %s with %d data", self.namespace, entry_count)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import os
2+
from dataclasses import dataclass
3+
from typing import Any, Dict, List, Set
4+
5+
# rocksdict is a lightweight native wrapper around RocksDB for Python; pylint may not be able to resolve its names
6+
# pylint: disable=no-name-in-module
7+
from rocksdict import Rdict
8+
9+
from graphgen.bases.base_storage import BaseKVStorage
10+
from graphgen.utils import logger
11+
12+
13+
@dataclass
class RocksDBKVStorage(BaseKVStorage):
    """Key-value storage backed by a RocksDB database opened via ``rocksdict``.

    The database is located at ``<working_dir>/<namespace>.db`` and is opened
    in ``__post_init__``.
    """

    # Open database handle; assigned in __post_init__ and re-created by drop().
    _db: Rdict = None
    # Filesystem path of the database, derived from working_dir and namespace.
    _db_path: str = None

    def __post_init__(self):
        """Derive the database path and open the database."""
        self._db_path = os.path.join(self.working_dir, f"{self.namespace}.db")
        self._db = Rdict(self._db_path)
        logger.info("Load KV (RocksDB) %s at %s", self.namespace, self._db_path)

    @property
    def data(self):
        """Return the underlying ``Rdict`` handle."""
        return self._db

    def all_keys(self) -> List[str]:
        """Return every key currently in the database as a list."""
        return list(self._db.keys())

    def index_done_callback(self):
        """Flush the database and log that the flush happened."""
        self._db.flush()
        logger.info("RocksDB flushed for %s", self.namespace)

    def get_by_id(self, id: str) -> Any:
        """Return the value stored under ``id``, or ``None`` if it is absent."""
        return self._db.get(id, None)

    def get_by_ids(self, ids: List[str], fields: List[str] = None) -> List[Any]:
        """Look up several keys at once, preserving the order of ``ids``.

        Args:
            ids: Keys to look up.
            fields: When given, each found item is reduced to the entries
                whose key is in ``fields``; when ``None`` items are returned
                unchanged.

        Returns:
            One entry per requested key; ``None`` for keys that are missing.
        """
        result = []
        for key in ids:
            item = self._db.get(key, None)
            if item is None or fields is None:
                # Missing keys yield None; without a projection the stored
                # item is returned as-is.
                result.append(item)
            else:
                result.append({k: v for k, v in item.items() if k in fields})
        return result

    def get_all(self) -> Dict[str, Dict]:
        """Materialise the whole database as a plain ``dict``."""
        return dict(self._db)

    def filter_keys(self, data: List[str]) -> Set[str]:
        """Return the subset of ``data`` that is not yet present in the database."""
        return {s for s in data if s not in self._db}

    def upsert(self, data: Dict[str, Any]):
        """Insert the entries of ``data`` whose keys are not already stored.

        Existing keys are left untouched.

        Returns:
            A dict holding only the entries that were newly written.
        """
        left_data = {}
        for k, v in data.items():
            # Keys of a dict are unique, so writing one entry cannot change
            # the membership result for any other key in this loop.
            if k not in self._db:
                self._db[k] = v
                left_data[k] = v

        # if left_data is very large, it is recommended to use self._db.write_batch() for optimization

        return left_data

    def drop(self):
        """Close the database, destroy it at its path, and open a fresh one there."""
        self._db.close()
        Rdict.destroy(self._db_path)
        self._db = Rdict(self._db_path)
        logger.info("Dropped RocksDB %s", self.namespace)

    def close(self):
        """Close the database handle if one has been assigned."""
        # Compare against None explicitly: the truthiness of a handle object
        # is not a reliable signal for "a database has been opened".
        if self._db is not None:
            self._db.close()

graphgen/operators/quiz/quiz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
2727
# this operator does not consume any batch data
2828
# but for compatibility we keep the interface
2929
_ = batch.to_dict(orient="records")
30-
30+
self.graph_storage.reload()
3131
yield from self.quiz()
3232

3333
async def _process_single_quiz(self, item: str) -> dict | None:

0 commit comments

Comments
 (0)