Skip to content

Commit f89a320

Browse files
wip
1 parent f6f99fa commit f89a320

8 files changed

Lines changed: 77 additions & 13 deletions

File tree

graphgen/bases/base_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self, llm_client: BaseLLMWrapper):
1414
self.llm_client = llm_client
1515

1616
@abstractmethod
17-
def extract(self, text_or_documents: str) -> Any:
17+
async def extract(self, chunk: dict) -> Any:
1818
"""Extract information from the given text"""
1919

2020
@abstractmethod
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
pipeline:
2-
- name: insert
2+
- name: read
33
params:
4-
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
5-
chunk_size: 10240 # chunk size for text splitting
4+
input_file: resources/input_examples/extract_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
5+
chunk_size: 20480 # chunk size for text splitting
66
chunk_overlap: 100 # chunk overlap for text splitting
77

88
- name: extract
99
params:
1010
method: schema_guided # extraction method, support: schema_guided
11-
schema_file: resources/schemas/legal_contract.json # schema file path for schema_guided method
11+
schema_file: graphgen/templates/extraction/schemas/legal_contract.json # schema file path for schema_guided method

graphgen/graphgen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ async def partition(self, partition_config: Dict):
249249
await self.partition_storage.upsert(batches)
250250
return batches
251251

252-
@op("extract", deps=["insert"])
252+
@op("extract", deps=["read"])
253253
@async_to_sync_method
254254
async def extract(self, extract_config: Dict):
255255
logger.info("Extracting information from given chunks...")

graphgen/models/extractor/schema_guided_extractor.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import json
2+
13
from graphgen.bases import BaseExtractor, BaseLLMWrapper
4+
from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT
5+
from graphgen.utils import compute_dict_hash, detect_main_language
26

37

48
class SchemaGuidedExtractor(BaseExtractor):
@@ -33,9 +37,42 @@ class SchemaGuidedExtractor(BaseExtractor):
3337
def __init__(self, llm_client: BaseLLMWrapper, schema: dict):
    """Bind the LLM client and the extraction schema to this extractor.

    :param llm_client: LLM wrapper handed to the base extractor.
    :param schema: schema dict; its "required" and "properties" entries
        are read here to decide which keys an extraction must contain.
    """
    super().__init__(llm_client)
    self.schema = schema

    # When "required" is missing or empty, every key declared under
    # "properties" is treated as required.
    declared_required = self.schema.get("required")
    self.required_keys = (
        declared_required
        if declared_required
        else list(self.schema.get("properties", {}).keys())
    )
3644

3745
def build_prompt(self, text: str) -> str:
    """Render the schema-guided extraction prompt for ``text``.

    Each entry under the schema's "properties" becomes one bullet line of
    the form ``- "<field>": <description>``; fields without a description
    get a placeholder sentence. The template is selected by the language
    detected from ``text``.

    :param text: document text to embed in the prompt.
    :return: the fully formatted prompt string.
    """
    properties = self.schema.get("properties", {})
    schema_explanation = "".join(
        '- "{}": {}\n'.format(
            name, spec.get("description", "No description provided.")
        )
        for name, spec in properties.items()
    )

    template = SCHEMA_GUIDED_EXTRACTION_PROMPT[detect_main_language(text)]
    return template.format(
        field=self.schema.get("name", "the document"),
        schema_explanation=schema_explanation,
        examples="",
        text=text,
    )
3960

4061
async def extract(self, chunk: dict) -> dict:
    """Extract schema fields from one chunk via the LLM.

    The chunk's "text" entry (empty string when absent) is turned into a
    prompt, sent to the LLM client, and the reply is parsed as JSON.

    :param chunk: chunk dict; only its "text" entry is read here.
    :return: ``{hash_of_required_fields: extracted_info}`` on success, or
        an empty dict when the reply is not valid JSON, is not a JSON
        object, or any required key is missing or equal to "".
    """
    text = chunk.get("text", "")
    prompt = self.build_prompt(text)
    response = await self.llm_client.generate_answer(prompt)

    # Keep the try body minimal: only the parse itself can raise here.
    try:
        extracted_info = json.loads(response)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a non-string reply such as None.
        return {}

    # Valid JSON is not necessarily an object: a list, string, number or
    # null reply cannot carry the schema keys and must not reach the
    # key lookups below.
    if not isinstance(extracted_info, dict):
        return {}

    # A required key that is absent is treated the same as an empty one.
    if any(extracted_info.get(key, "") == "" for key in self.required_keys):
        return {}

    # The record is keyed by a hash of its required fields only.
    main_keys_info = {key: extracted_info[key] for key in self.required_keys}
    return {compute_dict_hash(main_keys_info): extracted_info}
77+
78+
# async def merge_extractions(self):

graphgen/operators/extract/extract_info.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
import json
22

33
import gradio as gr
44

@@ -25,7 +25,9 @@ async def extract_info(
2525

2626
method = extract_config.get("method")
2727
if method == "schema_guided":
28-
schema = extract_config.get("schema")
28+
schema_file = extract_config.get("schema_file")
29+
with open(schema_file, "r", encoding="utf-8") as f:
30+
schema = json.load(f)
2931
extractor = SchemaGuidedExtractor(llm_client, schema)
3032
else:
3133
raise ValueError(f"Unsupported extraction method: {method}")
@@ -41,6 +43,7 @@ async def extract_info(
4143
unit="chunk",
4244
progress_bar=progress_bar,
4345
)
46+
print(results)
4447

4548
# TODO: merge and deduplicate the results
4649

graphgen/templates/extraction/schema_guided_extraction.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,21 @@
1111
- Consider the context of the entire document when determining relevance.
1212
- Do not be verbose, only respond with the correct format and information.
1313
- Some docs may have multiple relevant excerpts -- include all that apply.
14-
- Some questions may have no relevant excerpts -- just return ["N/A"].
14+
- Some questions may have no relevant excerpts -- just return "".
1515
- Do not include additional JSON keys beyond the ones listed here.
1616
- Do not include the same key multiple times in the JSON.
1717
- Use English for your response.
1818
1919
Expected JSON keys and explanation of what they are:
2020
{schema_explanation}
2121
22+
Expected format:
23+
{{
24+
"key1": "value1",
25+
"key2": "value2",
26+
...
27+
}}
28+
2229
{examples}
2330
2431
Document to extract from:
@@ -37,14 +44,21 @@
3744
- 在确定相关性时,考虑整份文件的上下文。
3845
- 不要冗长,只需以正确的格式和信息进行回应。
3946
- 有些文件可能有多个相关摘录——请包含所有适用的内容。
40-
- 有些问题可能没有相关摘录——只需返回["N/A"]
47+
- 有些问题可能没有相关摘录——只需返回""
4148
- 不要在JSON中包含除列出的键之外的其他键。
4249
- 不要多次包含同一个键。
4350
- 使用中文回答。
4451
4552
预期的JSON键及其说明:
4653
{schema_explanation}
4754
55+
预期格式:
56+
{{
57+
"key1": "value1",
58+
"key2": "value2",
59+
...
60+
}}
61+
4862
{examples}
4963
要提取的文件:
5064
{text}

graphgen/utils/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
split_string_by_multi_markers,
1010
write_json,
1111
)
12-
from .hash import compute_args_hash, compute_content_hash, compute_mm_hash
12+
from .hash import (
13+
compute_args_hash,
14+
compute_content_hash,
15+
compute_dict_hash,
16+
compute_mm_hash,
17+
)
1318
from .help_nltk import NLTKHelper
1419
from .log import logger, parse_log, set_logger
1520
from .loop import create_event_loop

graphgen/utils/hash.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,8 @@ def compute_mm_hash(item, prefix: str = ""):
2121
else:
2222
content = str(item)
2323
return prefix + md5(content.encode()).hexdigest()
24+
25+
26+
def compute_dict_hash(d: dict, prefix: str = ""):
    """Return ``prefix`` plus the MD5 hex digest of the dict's sorted items.

    The items are sorted and rendered as a tuple string before hashing, so
    two dicts with the same items hash identically whatever their
    insertion order.

    :param d: dict to hash; its items must be sortable.
    :param prefix: string prepended to the hex digest.
    :return: ``prefix`` followed by a 32-character hex digest.
    """
    ordered_pairs = tuple(sorted(d.items()))
    hasher = md5()
    hasher.update(str(ordered_pairs).encode())
    return prefix + hasher.hexdigest()

0 commit comments

Comments
 (0)