 from graphgen.bases import BaseLLMWrapper
 from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
+from graphgen.engine import op
 from graphgen.models import (
     JsonKVStorage,
     JsonListStorage,
@@ -69,6 +70,9 @@ def __init__(
         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="rephrase"
         )
+        self.partition_storage: JsonListStorage = JsonListStorage(
+            self.working_dir, namespace="partition"
+        )
         self.qa_storage: JsonListStorage = JsonListStorage(
             os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
@@ -77,13 +81,14 @@ def __init__(
         # webui
         self.progress_bar: gr.Progress = progress_bar

+    @op("insert", deps=[])
     @async_to_sync_method
-    async def insert(self, read_config: Dict, split_config: Dict):
+    async def insert(self, insert_config: Dict):
         """
         insert chunks into the graph
         """
         # Step 1: Read files
-        data = read_files(read_config["input_file"], self.working_dir)
+        data = read_files(insert_config["input_file"], self.working_dir)
         if len(data) == 0:
             logger.warning("No data to process")
             return
@@ -102,8 +107,8 @@ async def insert(self, read_config: Dict, split_config: Dict):

         inserting_chunks = await chunk_documents(
             new_docs,
-            split_config["chunk_size"],
-            split_config["chunk_overlap"],
+            insert_config["chunk_size"],
+            insert_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
         )
@@ -148,6 +153,7 @@ async def _insert_done(self):
             tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
         await asyncio.gather(*tasks)

+    @op("search", deps=["insert"])
     @async_to_sync_method
     async def search(self, search_config: Dict):
         logger.info(
@@ -183,13 +189,13 @@ async def search(self, search_config: Dict):
         # TODO: fix insert after search
         await self.insert()

+    @op("quiz_and_judge", deps=["insert"])
     @async_to_sync_method
     async def quiz_and_judge(self, quiz_and_judge_config: Dict):
-        if quiz_and_judge_config is None or not quiz_and_judge_config.get(
-            "enabled", False
-        ):
-            logger.warning("Quiz and Judge is not used in this pipeline.")
-            return
+        logger.warning(
+            "Quiz and Judge operation needs trainee LLM client."
+            " Make sure to provide one."
+        )
         max_samples = quiz_and_judge_config["quiz_samples"]
         await quiz(
             self.synthesizer_llm_client,
@@ -222,15 +228,26 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
         logger.info("Restarting synthesizer LLM client.")
         self.synthesizer_llm_client.restart()

+    @op("partition", deps=["insert"])
     @async_to_sync_method
-    async def generate(self, partition_config: Dict, generate_config: Dict):
-        # Step 1: partition the graph
+    async def partition(self, partition_config: Dict):
         batches = await partition_kg(
             self.graph_storage,
             self.chunks_storage,
             self.tokenizer_instance,
             partition_config,
         )
+        await self.partition_storage.upsert(batches)
+        return batches
+
+    @op("generate", deps=["insert", "partition"])
+    @async_to_sync_method
+    async def generate(self, generate_config: Dict):
+
+        batches = self.partition_storage.data
+        if not batches:
+            logger.warning("No partitions found for QA generation")
+            return

         # Step 2: generate QA pairs
         results = await generate_qas(
@@ -258,3 +275,6 @@ async def clear(self):
         await self.qa_storage.drop()

         logger.info("All caches are cleared")
+
+    # TODO: add data filtering step here in the future
+    # graph_gen.filter(filter_config=config["filter"])