|
| 1 | +import json |
| 2 | + |
1 | 3 | from graphgen.bases import BaseExtractor, BaseLLMWrapper |
| 4 | +from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT |
| 5 | +from graphgen.utils import compute_dict_hash, detect_main_language |
2 | 6 |
|
3 | 7 |
|
4 | 8 | class SchemaGuidedExtractor(BaseExtractor): |
@@ -33,9 +37,42 @@ class SchemaGuidedExtractor(BaseExtractor): |
def __init__(self, llm_client: BaseLLMWrapper, schema: dict):
    """Keep the extraction schema and decide which of its keys are required.

    Args:
        llm_client: LLM wrapper forwarded to the base extractor's initializer.
        schema: Schema dict. Its optional ``"required"`` entry names the keys
            an extraction must contain; its ``"properties"`` entry maps field
            names to their details.
    """
    super().__init__(llm_client)
    self.schema = schema
    # A missing or empty "required" entry means every declared property
    # is treated as required.
    declared_required = self.schema.get("required")
    self.required_keys = declared_required or list(
        self.schema.get("properties", {}).keys()
    )
36 | 44 |
|
def build_prompt(self, text: str) -> str:
    """Render the schema-guided extraction prompt for ``text``.

    Each schema property becomes one ``- "<field>": <description>`` line; a
    property without a description gets a placeholder sentence. The prompt
    template is selected by the language detected in ``text`` (the lookup
    fails if the template table has no entry for that language).

    Args:
        text: The passage the LLM should extract information from.

    Returns:
        The formatted prompt string.
    """
    properties = self.schema.get("properties", {})
    field_lines = [
        f'- "{name}": {spec.get("description", "No description provided.")}\n'
        for name, spec in properties.items()
    ]

    template = SCHEMA_GUIDED_EXTRACTION_PROMPT[detect_main_language(text)]
    return template.format(
        field=self.schema.get("name", "the document"),
        schema_explanation="".join(field_lines),
        examples="",
        text=text,
    )
39 | 60 |
|
async def extract(self, chunk: dict) -> dict:
    """Extract the schema's fields from one chunk using the LLM.

    Args:
        chunk: Mapping whose ``"text"`` entry holds the text to process; a
            missing entry is treated as an empty string.

    Returns:
        A single-entry dict mapping the hash of the required-key values to
        the full parsed object. An empty dict is returned when the reply is
        not parseable JSON, is not a JSON object, or lacks a non-empty value
        for any required key.
    """
    text = chunk.get("text", "")
    prompt = self.build_prompt(text)
    response = await self.llm_client.generate_answer(prompt)

    # Only the parse itself is guarded. TypeError covers a non-string reply
    # (e.g. None), which json.loads rejects before it ever decodes.
    try:
        extracted_info = json.loads(response)
    except (json.JSONDecodeError, TypeError):
        return {}

    # Valid JSON is not necessarily an object: a list, number or string
    # cannot carry the schema's keys, so treat it as a failed extraction.
    if not isinstance(extracted_info, dict):
        return {}

    # A required key that is absent counts the same as one set to "".
    if any(extracted_info.get(key, "") == "" for key in self.required_keys):
        return {}

    main_keys_info = {key: extracted_info[key] for key in self.required_keys}
    return {compute_dict_hash(main_keys_info): extracted_info}
| 77 | + |
| 78 | + # async def merge_extractions(self): |
0 commit comments