Skip to content

Commit 0cefddf

Browse files
wip: add extract_info
1 parent 011a1e5 commit 0cefddf

9 files changed

Lines changed: 85 additions & 57 deletions

File tree

graphgen/bases/base_storage.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ async def get_by_ids(
4545
) -> list[Union[T, None]]:
4646
raise NotImplementedError
4747

    async def get_all(self) -> dict[str, T]:
        """Return every stored record as an ``id -> value`` mapping.

        Abstract hook: concrete storage backends override this; the base
        implementation always raises.
        """
        raise NotImplementedError
4851
    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the subset of *data* keys that do not yet exist in storage."""
        raise NotImplementedError
Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
read:
2-
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
3-
split:
4-
chunk_size: 10240 # chunk size for text splitting
5-
chunk_overlap: 100 # chunk overlap for text splitting
6-
extract:
7-
method: schema_guided # extraction method, support: schema_guided
8-
schema_file: resources/schemas/legal_contract.json # schema file path for schema_guided method
1+
pipeline:
2+
- name: insert
3+
params:
4+
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
5+
chunk_size: 10240 # chunk size for text splitting
6+
chunk_overlap: 100 # chunk overlap for text splitting
7+
8+
- name: extract
9+
params:
10+
method: schema_guided # extraction method, support: schema_guided
11+
schema_file: resources/schemas/legal_contract.json # schema file path for schema_guided method

graphgen/graphgen.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from graphgen.operators import (
2020
build_kg,
2121
chunk_documents,
22+
extract_info,
2223
generate_qas,
2324
init_llm,
2425
judge_statement,
@@ -240,6 +241,22 @@ async def partition(self, partition_config: Dict):
240241
await self.partition_storage.upsert(batches)
241242
return batches
242243

244+
@op("extract", deps=["insert"])
245+
@async_to_sync_method
246+
async def extract(self, extract_config: Dict):
247+
logger.info("Extracting information from given chunks...")
248+
249+
results = await extract_info(
250+
self.synthesizer_llm_client,
251+
self.chunks_storage,
252+
extract_config,
253+
progress_bar=self.progress_bar,
254+
)
255+
if not results:
256+
logger.warning("No information extracted")
257+
return
258+
print(results)
259+
243260
@op("generate", deps=["insert", "partition"])
244261
@async_to_sync_method
245262
async def generate(self, generate_config: Dict):

graphgen/models/extractor/schema_guided_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@ def __init__(self, llm_client: BaseLLMWrapper, schema: dict):
3737
    def build_prompt(self, text: str) -> str:
        # TODO(wip): build the schema-guided extraction prompt for *text*.
        pass

    async def extract(self, chunk: dict) -> dict:
        # TODO(wip): prompt the LLM with the chunk and parse its answer.
        # NOTE(review): currently a stub — it prints the chunk and returns
        # None despite the ``-> dict`` annotation; confirm intended contract.
        print(chunk)

graphgen/models/storage/json_storage.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ async def get_by_ids(self, ids, fields=None) -> list:
3939
for id in ids
4040
]
4141

    async def get_all(self) -> dict[str, str]:
        # NOTE(review): returns the live internal dict, not a copy — callers
        # mutating the result mutate the storage. Confirm this aliasing is
        # intended before relying on it.
        return self._data
4245
async def filter_keys(self, data: list[str]) -> set[str]:
4346
return {s for s in data if s not in self._data}
4447

graphgen/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .build_kg import build_kg
2+
from .extract import extract_info
23
from .generate import generate_qas
34
from .init import init_llm
45
from .judge import judge_statement
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .extract_info import extract_info

graphgen/operators/extract/extract.py

Lines changed: 0 additions & 47 deletions
This file was deleted.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
import json
from typing import List

import gradio as gr

from graphgen.bases import BaseKVStorage, BaseLLMWrapper
from graphgen.bases.datatypes import Chunk
from graphgen.models.extractor import SchemaGuidedExtractor
from graphgen.utils import logger, run_concurrent
9+
10+
11+
async def extract_info(
    llm_client: BaseLLMWrapper,
    chunk_storage: BaseKVStorage,
    extract_config: dict,
    progress_bar: gr.Progress = None,
):
    """Extract structured information from every stored chunk.

    :param llm_client: LLM client used by the extractor
    :param chunk_storage: KV storage holding the chunks to process
    :param extract_config: extraction settings; supports ``method``
        ("schema_guided") plus either an inline ``schema`` dict or a
        ``schema_file`` path to a JSON schema
    :param progress_bar: optional gradio progress bar
    :return: list of per-chunk extraction results (empty when no chunks)
    :raises ValueError: if ``method`` is not supported
    """
    method = extract_config.get("method")
    if method == "schema_guided":
        # The shipped pipeline config provides "schema_file", not an inline
        # "schema" — accept both so the config actually reaches the extractor.
        schema = extract_config.get("schema")
        if schema is None and extract_config.get("schema_file"):
            with open(extract_config["schema_file"], encoding="utf-8") as f:
                schema = json.load(f)
        extractor = SchemaGuidedExtractor(llm_client, schema)
    else:
        raise ValueError(f"Unsupported extraction method: {method}")

    chunks = await chunk_storage.get_all()
    # Wrap each (id, chunk) pair as a single-entry dict so the extractor
    # keeps the chunk id alongside its content.
    chunks = [{k: v} for k, v in chunks.items()]
    logger.info(f"Start extracting information from {len(chunks)} chunks")

    results = await run_concurrent(
        extractor.extract,
        chunks,
        desc="Extracting information",
        unit="chunk",
        progress_bar=progress_bar,
    )

    # Return the per-chunk results instead of discarding them: the previous
    # `return []` made the caller always log "No information extracted".
    # TODO: merge and deduplicate results across chunks.
    return results

0 commit comments

Comments
 (0)