Skip to content

Commit 8a2a1fc

Browse files
feat: add _meta.json to record processed chunks
1 parent ae31db9 commit 8a2a1fc

11 files changed

Lines changed: 87 additions & 42 deletions

graphgen/configs/aggregated_config.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
pipeline:
2-
- name: insert
2+
- name: read
33
params:
44
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
55
chunk_size: 1024 # chunk size for text splitting
66
chunk_overlap: 100 # chunk overlap for text splitting
77

8+
- name: build_kg
9+
810
- name: quiz_and_judge
911
params:
1012
quiz_samples: 2 # number of quiz samples to generate
1113
re_judge: false # whether to re-judge the existing quiz samples
1214

1315
- name: partition
14-
deps: [insert, quiz_and_judge] # ece depends on both insert and quiz_and_judge steps
16+
deps: [quiz_and_judge] # ece depends on the quiz_and_judge step
1517
params:
1618
method: ece # ece is a custom partition method based on comprehension loss
1719
method_params:

graphgen/configs/atomic_config.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
pipeline:
2-
- name: insert
2+
- name: read
33
params:
44
input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
55
chunk_size: 1024 # chunk size for text splitting
66
chunk_overlap: 100 # chunk overlap for text splitting
7+
8+
- name: build_kg
9+
710
- name: partition
811
params:
912
method: dfs # partition method, support: dfs, bfs, ece, leiden

graphgen/configs/cot_config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
pipeline:
2-
- name: insert
2+
- name: read
33
params:
44
input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
55
chunk_size: 1024 # chunk size for text splitting
66
chunk_overlap: 100 # chunk overlap for text splitting
77

8+
- name: build_kg
9+
810
- name: partition
911
params:
1012
method: leiden # leiden is a community detection algorithm

graphgen/configs/multi_hop_config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
pipeline:
2-
- name: insert
2+
- name: read
33
params:
44
input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
55
chunk_size: 1024 # chunk size for text splitting
66
chunk_overlap: 100 # chunk overlap for text splitting
77

8+
- name: build_kg
9+
810
- name: partition
911
params:
1012
method: ece # ece is a custom partition method based on comprehension loss

graphgen/configs/vqa_config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
pipeline:
2-
- name: insert
2+
- name: read
33
params:
44
input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
55
chunk_size: 1024 # chunk size for text splitting
66
chunk_overlap: 100 # chunk overlap for text splitting
77

8+
- name: build_kg
9+
810
- name: partition
911
params:
1012
method: anchor_bfs # partition method

graphgen/engine.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def collect_ops(config: dict, graph_gen) -> List[OpNode]:
113113
runtime_deps = stage.get("deps", op_node.deps)
114114
op_node.deps = runtime_deps
115115

116-
op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params"))
116+
if "params" in stage:
117+
op_node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params", {}))
118+
else:
119+
op_node.func = lambda self, ctx, m=method: m()
117120
ops.append(op_node)
118121
return ops

graphgen/graphgen.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
import asyncio
21
import os
32
import time
4-
from typing import Dict, cast
3+
from typing import Dict
54

65
import gradio as gr
76

87
from graphgen.bases import BaseLLMWrapper
9-
from graphgen.bases.base_storage import StorageNameSpace
108
from graphgen.bases.datatypes import Chunk
119
from graphgen.engine import op
1210
from graphgen.models import (
1311
JsonKVStorage,
1412
JsonListStorage,
13+
MetaJsonKVStorage,
1514
NetworkXStorage,
1615
OpenAIClient,
1716
Tokenizer,
@@ -55,6 +54,10 @@ def __init__(
5554
)
5655
self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
5756

57+
self.meta_storage: MetaJsonKVStorage = MetaJsonKVStorage(
58+
self.working_dir, namespace="_meta"
59+
)
60+
5861
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
5962
self.working_dir, namespace="full_docs"
6063
)
@@ -81,14 +84,13 @@ def __init__(
8184
# webui
8285
self.progress_bar: gr.Progress = progress_bar
8386

84-
@op("insert", deps=[])
87+
@op("read", deps=[])
8588
@async_to_sync_method
86-
async def insert(self, insert_config: Dict):
89+
async def read(self, read_config: Dict):
8790
"""
88-
insert chunks into the graph
91+
read files from input sources
8992
"""
90-
# Step 1: Read files
91-
data = read_files(insert_config["input_file"], self.working_dir)
93+
data = read_files(read_config["input_file"], self.working_dir)
9294
if len(data) == 0:
9395
logger.warning("No data to process")
9496
return
@@ -107,8 +109,8 @@ async def insert(self, insert_config: Dict):
107109

108110
inserting_chunks = await chunk_documents(
109111
new_docs,
110-
insert_config["chunk_size"],
111-
insert_config["chunk_overlap"],
112+
read_config["chunk_size"],
113+
read_config["chunk_overlap"],
112114
self.tokenizer_instance,
113115
self.progress_bar,
114116
)
@@ -124,9 +126,25 @@ async def insert(self, insert_config: Dict):
124126
logger.warning("All chunks are already in the storage")
125127
return
126128

127-
logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
129+
await self.full_docs_storage.upsert(new_docs)
130+
await self.full_docs_storage.index_done_callback()
128131
await self.chunks_storage.upsert(inserting_chunks)
132+
await self.chunks_storage.index_done_callback()
133+
134+
@op("build_kg", deps=["read"])
135+
@async_to_sync_method
136+
async def build_kg(self):
137+
"""
138+
build knowledge graph from text chunks
139+
"""
140+
# Step 1: get new chunks according to meta and chunks storage
141+
inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage)
142+
if len(inserting_chunks) == 0:
143+
logger.warning("All chunks are already in the storage")
144+
return
129145

146+
logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
147+
# Step 2: build knowledge graph from new chunks
130148
_add_entities_and_relations = await build_kg(
131149
llm_client=self.synthesizer_llm_client,
132150
kg_instance=self.graph_storage,
@@ -137,23 +155,13 @@ async def insert(self, insert_config: Dict):
137155
logger.warning("No entities or relations extracted from text chunks")
138156
return
139157

140-
await self._insert_done()
158+
# Step 3: mark meta
159+
await self.meta_storage.mark_done(self.chunks_storage)
160+
await self.meta_storage.index_done_callback()
161+
141162
return _add_entities_and_relations
142163

143-
async def _insert_done(self):
144-
tasks = []
145-
for storage_instance in [
146-
self.full_docs_storage,
147-
self.chunks_storage,
148-
self.graph_storage,
149-
self.search_storage,
150-
]:
151-
if storage_instance is None:
152-
continue
153-
tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
154-
await asyncio.gather(*tasks)
155-
156-
@op("search", deps=["insert"])
164+
@op("search", deps=["read"])
157165
@async_to_sync_method
158166
async def search(self, search_config: Dict):
159167
logger.info(
@@ -187,9 +195,9 @@ async def search(self, search_config: Dict):
187195
]
188196
)
189197
# TODO: fix insert after search
190-
await self.insert()
198+
# await self.insert()
191199

192-
@op("quiz_and_judge", deps=["insert"])
200+
@op("quiz_and_judge", deps=["build_kg"])
193201
@async_to_sync_method
194202
async def quiz_and_judge(self, quiz_and_judge_config: Dict):
195203
logger.warning(
@@ -228,7 +236,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
228236
logger.info("Restarting synthesizer LLM client.")
229237
self.synthesizer_llm_client.restart()
230238

231-
@op("partition", deps=["insert"])
239+
@op("partition", deps=["build_kg"])
232240
@async_to_sync_method
233241
async def partition(self, partition_config: Dict):
234242
batches = await partition_kg(
@@ -240,7 +248,7 @@ async def partition(self, partition_config: Dict):
240248
await self.partition_storage.upsert(batches)
241249
return batches
242250

243-
@op("generate", deps=["insert", "partition"])
251+
@op("generate", deps=["partition"])
244252
@async_to_sync_method
245253
async def generate(self, generate_config: Dict):
246254

graphgen/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@
3030
from .search.web.bing_search import BingSearch
3131
from .search.web.google_search import GoogleSearch
3232
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
33-
from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage
33+
from .storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage, NetworkXStorage
3434
from .tokenizer import Tokenizer
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .json_storage import JsonKVStorage, JsonListStorage
1+
from .json_storage import JsonKVStorage, JsonListStorage, MetaJsonKVStorage
22
from .networkx_storage import NetworkXStorage

graphgen/models/storage/json_storage.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ async def filter_keys(self, data: list[str]) -> set[str]:
4444

4545
async def upsert(self, data: dict):
4646
left_data = {k: v for k, v in data.items() if k not in self._data}
47-
self._data.update(left_data)
47+
if left_data:
48+
self._data.update(left_data)
4849
return left_data
4950

5051
async def drop(self):
51-
self._data = {}
52+
if self._data:
53+
self._data.clear()
5254

5355

5456
@dataclass
@@ -87,3 +89,23 @@ async def upsert(self, data: list):
8789

8890
async def drop(self):
8991
self._data = []
92+
93+
94+
@dataclass
95+
class MetaJsonKVStorage(JsonKVStorage):
96+
def __post_init__(self):
97+
self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
98+
self._data = load_json(self._file_name) or {}
99+
logger.info("Load KV %s with %d data", self.namespace, len(self._data))
100+
101+
async def get_new_data(self, storage_instance: "JsonKVStorage") -> dict:
102+
new_data = {}
103+
for k, v in storage_instance.data.items():
104+
if k not in self._data:
105+
new_data[k] = v
106+
return new_data
107+
108+
async def mark_done(self, storage_instance: "JsonKVStorage"):
109+
new_data = await self.get_new_data(storage_instance)
110+
if new_data:
111+
self._data.update(new_data)

0 commit comments

Comments
 (0)