
Commit 07fc078

refactor: refactor graphgen to integrate orchestration engine
1 parent 379ba46 commit 07fc078

5 files changed

Lines changed: 111 additions & 62 deletions


Lines changed: 23 additions & 22 deletions
@@ -1,22 +1,23 @@
-read:
-  input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-split:
-  chunk_size: 1024 # chunk size for text splitting
-  chunk_overlap: 100 # chunk overlap for text splitting
-search: # web search configuration
-  enabled: false # whether to enable web search
-  search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
-  enabled: true
-  quiz_samples: 2 # number of quiz samples to generate
-  re_judge: false # whether to re-judge the existing quiz samples
-partition: # graph partition configuration
-  method: ece # ece is a custom partition method based on comprehension loss
-  method_params:
-    max_units_per_community: 20 # max nodes and edges per community
-    min_units_per_community: 5 # min nodes and edges per community
-    max_tokens_per_community: 10240 # max tokens per community
-    unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
-generate:
-  mode: aggregated # atomic, aggregated, multi_hop, cot, vqa
-  data_format: ChatML # Alpaca, Sharegpt, ChatML
+pipeline:
+  - name: insert
+    params:
+      input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+  - name: quiz_and_judge
+    params:
+      quiz_samples: 2 # number of quiz samples to generate
+      re_judge: false # whether to re-judge the existing quiz samples
+  - name: partition
+    deps: [insert, quiz_and_judge] # ece depends on both insert and quiz_and_judge steps
+    params:
+      method: ece # ece is a custom partition method based on comprehension loss
+      method_params:
+        max_units_per_community: 20 # max nodes and edges per community
+        min_units_per_community: 5 # min nodes and edges per community
+        max_tokens_per_community: 10240 # max tokens per community
+        unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
+  - name: generate
+    params:
+      method: aggregated # atomic, aggregated, multi_hop, cot, vqa
+      data_format: ChatML # Alpaca, Sharegpt, ChatML
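
The new schema replaces the per-step top-level sections (read, split, search, ...) with a single ordered pipeline list, where each stage carries a name, an optional deps list, and a params mapping. A minimal sketch of consuming this schema (the file name pipeline_demo.yaml is hypothetical, not part of the commit):

    import yaml

    # Load a config in the new pipeline schema and list its stages.
    with open("pipeline_demo.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    for stage in config["pipeline"]:
        # Each stage: required "name", optional "deps" and "params".
        print(stage["name"], stage.get("deps", []), stage.get("params", {}))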

graphgen/engine.py

Lines changed: 41 additions & 5 deletions
@@ -3,8 +3,12 @@
 """
 
 import threading
+import traceback
+from functools import wraps
 from typing import Any, Callable, List
 
+from graphgen.utils import logger
+
 
 class Context(dict):
     _lock = threading.Lock()
@@ -25,9 +29,16 @@ def __init__(
         self.name, self.deps, self.func = name, deps, func
 
 
-def op(name: str, deps: List[str] = None):
-    def decorator(f: Callable[["OpNode", Context], Any]):
-        return OpNode(name, deps or [], f)
+def op(name: str, deps=None):
+    deps = deps or []
+
+    def decorator(func):
+        @wraps(func)
+        def _wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        _wrapper.op_node = OpNode(name, deps, lambda self, ctx: func(self, **ctx))
+        return _wrapper
 
     return decorator
 
@@ -73,7 +84,8 @@ def _exec(n: str):
             try:
                 name2op[n].func(name2op[n], ctx)
             except Exception as e:  # pylint: disable=broad-except
-                exc[n] = e
+                logger.error("Operation %s failed: %s", n, e)
+                exc[n] = traceback.format_exc()
             done[n].set()
 
         ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo]
@@ -82,4 +94,28 @@ def _exec(n: str):
         for t in ts:
             t.join()
         if exc:
-            raise RuntimeError(f"Some operations failed: {exc}")
+            raise RuntimeError(
+                "Some operations failed:\n"
+                + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items())
+            )
+
+
+def collect_ops(config: dict, graph_gen) -> List[OpNode]:
+    """
+    build operation nodes from yaml config
+    :param config
+    :param graph_gen
+    """
+    ops: List[OpNode] = []
+    for stage in config["pipeline"]:
+        name = stage["name"]
+        method = getattr(graph_gen, name)
+        op_node = method.op_node
+
+        # if there are runtime dependencies, override them
+        runtime_deps = stage.get("deps", op_node.deps)
+        op_node.deps = runtime_deps
+
+        op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params"))
+        ops.append(op_node)
+    return ops
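
Taken together: @op attaches an OpNode to the decorated function, the engine runs each node in its own thread once the done events of its deps are set, and collect_ops rebinds every node's func so the engine calls the matching graph_gen method with that stage's params. Note the m=method, sc=stage default arguments in the lambda: they freeze the loop variables per iteration and sidestep Python's late-binding closure pitfall. A toy sketch of the wiring, assuming the package is importable and that Engine(max_workers=...).run(ops, ctx) behaves as graphgen/run.py below uses it:

    from graphgen.engine import Context, Engine, collect_ops, op

    class Toy:
        @op("fetch", deps=[])
        def fetch(self, params=None):
            print("fetch:", params)

        @op("report", deps=["fetch"])  # waits for "fetch" to finish
        def report(self, params=None):
            print("report:", params)

    config = {"pipeline": [
        {"name": "fetch", "params": {"n": 1}},
        {"name": "report"},
    ]}
    toy = Toy()
    ops = collect_ops(config, toy)  # rebinds each op_node.func to the bound method
    Engine(max_workers=2).run(ops, Context(config=config, graph_gen=toy))
    # expected output: "fetch: {'n': 1}" then "report: None"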

graphgen/graphgen.py

Lines changed: 31 additions & 11 deletions
@@ -8,6 +8,7 @@
 from graphgen.bases import BaseLLMWrapper
 from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
+from graphgen.engine import op
 from graphgen.models import (
     JsonKVStorage,
     JsonListStorage,
@@ -69,6 +70,9 @@ def __init__(
         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="rephrase"
         )
+        self.partition_storage: JsonListStorage = JsonListStorage(
+            self.working_dir, namespace="partition"
+        )
         self.qa_storage: JsonListStorage = JsonListStorage(
             os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
@@ -77,13 +81,14 @@
         # webui
         self.progress_bar: gr.Progress = progress_bar
 
+    @op("insert", deps=[])
     @async_to_sync_method
-    async def insert(self, read_config: Dict, split_config: Dict):
+    async def insert(self, insert_config: Dict):
         """
         insert chunks into the graph
         """
         # Step 1: Read files
-        data = read_files(read_config["input_file"], self.working_dir)
+        data = read_files(insert_config["input_file"], self.working_dir)
         if len(data) == 0:
             logger.warning("No data to process")
             return
@@ -102,8 +107,8 @@ async def insert(self, read_config: Dict, split_config: Dict):
 
         inserting_chunks = await chunk_documents(
             new_docs,
-            split_config["chunk_size"],
-            split_config["chunk_overlap"],
+            insert_config["chunk_size"],
+            insert_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
         )
@@ -148,6 +153,7 @@ async def _insert_done(self):
             tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
         await asyncio.gather(*tasks)
 
+    @op("search", deps=["insert"])
     @async_to_sync_method
     async def search(self, search_config: Dict):
         logger.info(
@@ -183,13 +189,13 @@
         # TODO: fix insert after search
         await self.insert()
 
+    @op("quiz_and_judge", deps=["insert"])
     @async_to_sync_method
     async def quiz_and_judge(self, quiz_and_judge_config: Dict):
-        if quiz_and_judge_config is None or not quiz_and_judge_config.get(
-            "enabled", False
-        ):
-            logger.warning("Quiz and Judge is not used in this pipeline.")
-            return
+        logger.warning(
+            "Quiz and Judge operation needs trainee LLM client."
+            " Make sure to provide one."
+        )
         max_samples = quiz_and_judge_config["quiz_samples"]
         await quiz(
             self.synthesizer_llm_client,
@@ -222,15 +228,26 @@
         logger.info("Restarting synthesizer LLM client.")
         self.synthesizer_llm_client.restart()
 
+    @op("partition", deps=["insert"])
     @async_to_sync_method
-    async def generate(self, partition_config: Dict, generate_config: Dict):
-        # Step 1: partition the graph
+    async def partition(self, partition_config: Dict):
         batches = await partition_kg(
             self.graph_storage,
             self.chunks_storage,
             self.tokenizer_instance,
             partition_config,
         )
+        await self.partition_storage.upsert(batches)
+        return batches
+
+    @op("generate", deps=["insert", "partition"])
+    @async_to_sync_method
+    async def generate(self, generate_config: Dict):
+
+        batches = self.partition_storage.data
+        if not batches:
+            logger.warning("No partitions found for QA generation")
+            return
 
         # Step 2: generate QA pairs
         results = await generate_qas(
@@ -258,3 +275,6 @@ async def clear(self):
         await self.qa_storage.drop()
 
         logger.info("All caches are cleared")
+
+    # TODO: add data filtering step here in the future
+    # graph_gen.filter(filter_config=config["filter"])
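
Two details worth calling out. First, the old generate was split in two: partition now persists its batches through the new partition_storage, and generate reads partition_storage.data back, which is what allows the generate op to declare deps=["insert", "partition"]. Second, because @op's _wrapper simply forwards to the original function, the decorated methods stay directly callable outside the engine. A direct-call sketch under that assumption (constructor arguments and params are illustrative, mirroring the demo config above):

    from graphgen.graphgen import GraphGen

    gg = GraphGen(unique_id=1700000000, working_dir="cache")  # hypothetical values
    gg.insert({
        "input_file": "resources/input_examples/jsonl_demo.jsonl",
        "chunk_size": 1024,
        "chunk_overlap": 100,
    })
    batches = gg.partition({
        "method": "ece",
        "method_params": {
            "max_units_per_community": 20,
            "min_units_per_community": 5,
            "max_tokens_per_community": 10240,
            "unit_sampling": "max_loss",
        },
    })  # assumes async_to_sync_method returns the coroutine's result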

graphgen/operators/generate/generate_qas.py

Lines changed: 8 additions & 8 deletions
@@ -29,21 +29,21 @@ async def generate_qas(
     :param progress_bar
     :return: QA pairs
     """
-    mode = generation_config["mode"]
-    logger.info("[Generation] mode: %s, batches: %d", mode, len(batches))
+    method = generation_config["method"]
+    logger.info("[Generation] mode: %s, batches: %d", method, len(batches))
 
-    if mode == "atomic":
+    if method == "atomic":
         generator = AtomicGenerator(llm_client)
-    elif mode == "aggregated":
+    elif method == "aggregated":
         generator = AggregatedGenerator(llm_client)
-    elif mode == "multi_hop":
+    elif method == "multi_hop":
         generator = MultiHopGenerator(llm_client)
-    elif mode == "cot":
+    elif method == "cot":
         generator = CoTGenerator(llm_client)
-    elif mode in ["vqa"]:
+    elif method in ["vqa"]:
         generator = VQAGenerator(llm_client)
     else:
-        raise ValueError(f"Unsupported generation mode: {mode}")
+        raise ValueError(f"Unsupported generation mode: {method}")
 
     results = await run_concurrent(
         generator.generate,
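
The mode-to-method rename leaves the if/elif dispatch intact (note the log message still reads "mode: %s", a small leftover). The same mapping can be read at a glance as a lookup table; a drop-in sketch for the branch above, reusing the file's existing imports and locals rather than a standalone program:

    # Same method-to-generator mapping as the elif chain above.
    GENERATORS = {
        "atomic": AtomicGenerator,
        "aggregated": AggregatedGenerator,
        "multi_hop": MultiHopGenerator,
        "cot": CoTGenerator,
        "vqa": VQAGenerator,
    }

    try:
        generator = GENERATORS[method](llm_client)
    except KeyError as err:
        raise ValueError(f"Unsupported generation mode: {method}") from err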

graphgen/run.py

Lines changed: 8 additions & 16 deletions
@@ -6,6 +6,7 @@
 import yaml
 from dotenv import load_dotenv
 
+from graphgen.engine import Context, Engine, collect_ops
 from graphgen.graphgen import GraphGen
 from graphgen.utils import logger, set_logger
 
@@ -50,38 +51,29 @@ def main():
     with open(args.config_file, "r", encoding="utf-8") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
-    mode = config["generate"]["mode"]
     unique_id = int(time.time())
 
     output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
     set_working_dir(output_path)
 
     set_logger(
-        os.path.join(output_path, f"{unique_id}_{mode}.log"),
+        os.path.join(output_path, f"{unique_id}.log"),
         if_stream=True,
     )
     logger.info(
         "GraphGen with unique ID %s logging to %s",
         unique_id,
-        os.path.join(working_dir, f"{unique_id}_{mode}.log"),
+        os.path.join(working_dir, f"{unique_id}.log"),
     )
 
     graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
 
-    graph_gen.insert(read_config=config["read"], split_config=config["split"])
+    # share context between different steps
+    ctx = Context(config=config, graph_gen=graph_gen)
+    ops = collect_ops(config, graph_gen)
 
-    graph_gen.search(search_config=config["search"])
-
-    if config.get("quiz_and_judge", {}).get("enabled"):
-        graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
-
-    # TODO: add data filtering step here in the future
-    # graph_gen.filter(filter_config=config["filter"])
-
-    graph_gen.generate(
-        partition_config=config["partition"],
-        generate_config=config["generate"],
-    )
+    # run operations
+    Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
 
     save_config(os.path.join(output_path, "config.yaml"), config)
     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
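
Net effect on the entry point: per-step toggles such as search.enabled and quiz_and_judge.enabled no longer gate execution; a step runs exactly when it appears in the config's pipeline list, with ordering derived from each op's deps. Concurrency is capped by an optional top-level max_workers key (defaulting to 4), and the log file name drops the mode suffix, since the generation method now lives inside the generate stage's params.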
