Skip to content

Commit 3f79260

Browse files
feat: complete extract_info pipeline
1 parent 912508c commit 3f79260

10 files changed

Lines changed: 94 additions & 30 deletions

File tree

graphgen/configs/aggregated_config.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: chunk
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
 
   - name: build_kg

graphgen/configs/atomic_config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
+
+  - name: chunk
+    params:
       chunk_size: 1024 # chunk size for text splitting
       chunk_overlap: 100 # chunk overlap for text splitting
 

graphgen/configs/cot_config.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ pipeline:
   - name: read
     params:
      input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: chunk
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
 
   - name: build_kg

graphgen/configs/multi_hop_config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
+
+  - name: chunk
+    params:
       chunk_size: 1024 # chunk size for text splitting
       chunk_overlap: 100 # chunk overlap for text splitting
 

graphgen/configs/schema_guided_config.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/extract_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-      chunk_size: 20480 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: chunk
+    params:
+      chunk_size: 20480
+      chunk_overlap: 2000
+      separators: []
 
   - name: extract
     params:

graphgen/configs/vqa_config.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ pipeline:
   - name: read
     params:
       input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - name: chunk
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
 
   - name: build_kg

graphgen/graphgen.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
         self.search_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="search"
         )
+
         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="rephrase"
         )
@@ -81,6 +82,10 @@ def __init__(
             os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
         )
+        self.extract_storage: JsonKVStorage = JsonKVStorage(
+            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
+            namespace="extraction",
+        )
 
         # webui
         self.progress_bar: gr.Progress = progress_bar
@@ -104,16 +109,30 @@ async def read(self, read_config: Dict):
         _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
 
+        if len(new_docs) == 0:
+            logger.warning("All documents are already in the storage")
+            return
+
+        await self.full_docs_storage.upsert(new_docs)
+        await self.full_docs_storage.index_done_callback()
+
+    @op("chunk", deps=["read"])
+    @async_to_sync_method
+    async def chunk(self, chunk_config: Dict):
+        """
+        chunk documents into smaller pieces from full_docs_storage if not already present
+        """
+
+        new_docs = await self.meta_storage.get_new_data(self.full_docs_storage)
         if len(new_docs) == 0:
             logger.warning("All documents are already in the storage")
             return
 
         inserting_chunks = await chunk_documents(
             new_docs,
-            read_config["chunk_size"],
-            read_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
+            **chunk_config,
         )
 
         _add_chunk_keys = await self.chunks_storage.filter_keys(
@@ -127,12 +146,12 @@ async def read(self, read_config: Dict):
             logger.warning("All chunks are already in the storage")
             return
 
-        await self.full_docs_storage.upsert(new_docs)
-        await self.full_docs_storage.index_done_callback()
         await self.chunks_storage.upsert(inserting_chunks)
         await self.chunks_storage.index_done_callback()
+        await self.meta_storage.mark_done(self.full_docs_storage)
+        await self.meta_storage.index_done_callback()
 
-    @op("build_kg", deps=["read"])
+    @op("build_kg", deps=["chunk"])
     @async_to_sync_method
     async def build_kg(self):
         """
@@ -162,7 +181,7 @@ async def build_kg(self):
 
         return _add_entities_and_relations
 
-    @op("search", deps=["read"])
+    @op("search", deps=["chunk"])
     @async_to_sync_method
     async def search(self, search_config: Dict):
         logger.info(
@@ -249,7 +268,7 @@ async def partition(self, partition_config: Dict):
         await self.partition_storage.upsert(batches)
         return batches
 
-    @op("extract", deps=["read"])
+    @op("extract", deps=["chunk"])
    @async_to_sync_method
     async def extract(self, extract_config: Dict):
         logger.info("Extracting information from given chunks...")
@@ -263,7 +282,11 @@ async def extract(self, extract_config: Dict):
         if not results:
             logger.warning("No information extracted")
             return
-        print(results)
+
+        await self.extract_storage.upsert(results)
+        await self.extract_storage.index_done_callback()
+        await self.meta_storage.mark_done(self.chunks_storage)
+        await self.meta_storage.index_done_callback()
 
     @op("generate", deps=["partition"])
     @async_to_sync_method

graphgen/models/extractor/schema_guided_extractor.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
 import json
+from typing import Dict, List
 
 from graphgen.bases import BaseExtractor, BaseLLMWrapper
 from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT
-from graphgen.utils import compute_dict_hash, detect_main_language
+from graphgen.utils import compute_dict_hash, detect_main_language, logger
 
 
 class SchemaGuidedExtractor(BaseExtractor):
@@ -69,10 +70,32 @@ async def extract(self, chunk: dict) -> dict:
                 if key not in extracted_info:
                     extracted_info[key] = ""
             if any(extracted_info[key] == "" for key in self.required_keys):
+                logger.debug("Missing required keys in extraction: %s", extracted_info)
                 return {}
             main_keys_info = {key: extracted_info[key] for key in self.required_keys}
-            return {compute_dict_hash(main_keys_info): extracted_info}
+            logger.debug("Extracted info: %s", extracted_info)
+            return {compute_dict_hash(main_keys_info, prefix="extract"): extracted_info}
         except json.JSONDecodeError:
+            logger.error("Failed to parse extraction response: %s", response)
             return {}
 
-    # async def merge_extractions(self):
+    async def merge_extractions(
+        self, extraction_list: List[Dict[str, dict]]
+    ) -> Dict[str, dict]:
+        """
+        Merge multiple extraction results based on their hashes.
+        :param extraction_list: List of extraction results, each is a dict with hash as key and record as value.
+        :return: Merged extraction results.
+        """
+        merged: Dict[str, dict] = {}
+        for ext in extraction_list:
+            for h, rec in ext.items():
+                if h not in merged:
+                    merged[h] = rec.copy()
+                else:
+                    for k, v in rec.items():
+                        if k not in merged[h] or merged[h][k] == v:
+                            merged[h][k] = v
+                        else:
+                            merged[h][k] = f"{merged[h][k]}<SEP>{v}"
+        return merged

graphgen/operators/extract/extract_info.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
 import gradio as gr
 
 from graphgen.bases import BaseKVStorage, BaseLLMWrapper
-from graphgen.bases.datatypes import Chunk
 from graphgen.models.extractor import SchemaGuidedExtractor
 from graphgen.utils import logger, run_concurrent
 
@@ -34,7 +33,7 @@ async def extract_info(
 
     chunks = await chunk_storage.get_all()
     chunks = [{k: v} for k, v in chunks.items()]
-    logger.info(f"Start extracting information from {len(chunks)} chunks")
+    logger.info("Start extracting information from %d chunks", len(chunks))
 
     results = await run_concurrent(
         extractor.extract,
@@ -43,8 +42,6 @@ async def extract_info(
         unit="chunk",
         progress_bar=progress_bar,
     )
-    print(results)
 
-    # TODO: 对results合并,去重
-
-    return []
+    results = await extractor.merge_extractions(results)
+    return results

graphgen/operators/split/split_chunks.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,18 @@ def split_chunks(text: str, language: str = "en", **kwargs) -> list:
             f"Unsupported language: {language}. "
             f"Supported languages are: {list(_MAPPING.keys())}"
         )
-    splitter = _get_splitter(language, frozenset(kwargs.items()))
+    frozen_kwargs = frozenset(
+        (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items()
+    )
+    splitter = _get_splitter(language, frozen_kwargs)
     return splitter.split_text(text)
 
 
 async def chunk_documents(
     new_docs: dict,
-    chunk_size: int = 1024,
-    chunk_overlap: int = 100,
     tokenizer_instance: Tokenizer = None,
     progress_bar=None,
+    **kwargs,
 ) -> dict:
     inserting_chunks = {}
     cur_index = 1
@@ -51,11 +53,11 @@ async def chunk_documents(
         doc_type = doc.get("type")
         if doc_type == "text":
             doc_language = detect_main_language(doc["content"])
+
             text_chunks = split_chunks(
                 doc["content"],
                 language=doc_language,
-                chunk_size=chunk_size,
-                chunk_overlap=chunk_overlap,
+                **kwargs,
             )
 
             chunks = {

0 commit comments

Comments (0)