Skip to content

Commit ee0639d

Browse files
refactor: refactor quiz to accommodate ray data engine
1 parent 3edbb81 commit ee0639d

7 files changed

Lines changed: 120 additions & 94 deletions

File tree

graphgen/operators/__init__.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1-
from .build_kg import build_kg
1+
from .build_kg import BuildKGService
2+
from .chunk import ChunkService
23
from .extract import extract_info
34
from .generate import generate_qas
4-
from .init import init_llm
55
from .partition import partition_kg
6-
from .quiz_and_judge import judge_statement, quiz
6+
from .quiz import QuizService
77
from .read import read
88
from .search import search_all
9-
from .split import chunk_documents
9+
10+
operators = {
11+
"read": read,
12+
"chunk": ChunkService,
13+
"build_kg": BuildKGService,
14+
"quiz": QuizService,
15+
"extract_info": extract_info,
16+
"search_all": search_all,
17+
"partition_kg": partition_kg,
18+
"generate_qas": generate_qas,
19+
}

graphgen/operators/build_kg/build_kg_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
2424

2525
# consume the chunks and build kg
2626
self.build_kg(docs)
27-
return pd.DataFrame()
27+
return pd.DataFrame([{"status": "kg_building_completed"}])
2828

2929
def build_kg(self, chunks: List[Chunk]) -> None:
3030
"""

graphgen/operators/evaluate/__init__.py

Whitespace-only changes.

graphgen/operators/judge/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .quiz import QuizService

graphgen/operators/quiz/quiz.py

Lines changed: 102 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,107 @@
1-
from collections import defaultdict
2-
3-
import gradio as gr
4-
5-
from graphgen.bases import BaseLLMWrapper
6-
from graphgen.models import JsonKVStorage, NetworkXStorage, QuizGenerator
7-
from graphgen.utils import logger, run_concurrent
8-
9-
10-
async def quiz(
11-
synth_llm_client: BaseLLMWrapper,
12-
graph_storage: NetworkXStorage,
13-
rephrase_storage: JsonKVStorage,
14-
max_samples: int = 1,
15-
progress_bar: gr.Progress = None,
16-
) -> JsonKVStorage:
17-
"""
18-
Get all edges and quiz them using QuizGenerator.
19-
20-
:param synth_llm_client: generate statements
21-
:param graph_storage: graph storage instance
22-
:param rephrase_storage: rephrase storage instance
23-
:param max_samples: max samples for each edge
24-
:param progress_bar
25-
:return:
26-
"""
27-
28-
generator = QuizGenerator(synth_llm_client)
29-
30-
async def _process_single_quiz(item: tuple[str, str, str]):
31-
description, template_type, gt = item
32-
try:
33-
# if rephrase_storage exists already, directly get it
34-
descriptions = rephrase_storage.get_by_id(description)
35-
if descriptions:
36-
return None
37-
38-
prompt = generator.build_prompt_for_description(description, template_type)
39-
new_description = await synth_llm_client.generate_answer(
40-
prompt, temperature=1
41-
)
42-
rephrased_text = generator.parse_rephrased_text(new_description)
43-
return {description: [(rephrased_text, gt)]}
44-
45-
except Exception as e: # pylint: disable=broad-except
46-
logger.error("Error when quizzing description %s: %s", description, e)
1+
from collections.abc import Iterable
2+
3+
import pandas as pd
4+
5+
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
6+
from graphgen.common import init_llm, init_storage
7+
from graphgen.models import QuizGenerator
8+
from graphgen.utils import compute_content_hash, logger, run_concurrent
9+
10+
11+
class QuizService:
12+
def __init__(self, working_dir: str = "cache", quiz_samples: int = 1):
13+
self.quiz_samples = quiz_samples
14+
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
15+
self.graph_storage: BaseGraphStorage = init_storage(
16+
backend="networkx", working_dir=working_dir, namespace="graph"
17+
)
18+
# { _description_id: { "description": str, "quizzes": List[Tuple[str, str]] } }
19+
self.quiz_storage: BaseKVStorage = init_storage(
20+
backend="json_kv", working_dir=working_dir, namespace="quiz"
21+
)
22+
self.generator = QuizGenerator(self.llm_client)
23+
24+
self.concurrency_limit = 20
25+
26+
def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
27+
# this operator does not consume any batch data
28+
# but for compatibility we keep the interface
29+
_ = batch.to_dict(orient="records")
30+
31+
yield from self.quiz()
32+
33+
async def _process_single_quiz(self, item: str) -> dict | None:
34+
# if quiz in quiz_storage exists already, directly get it
35+
_description_id = compute_content_hash(item)
36+
if self.quiz_storage.get_by_id(_description_id):
4737
return None
4838

49-
edges = graph_storage.get_all_edges()
50-
nodes = graph_storage.get_all_nodes()
51-
52-
results = defaultdict(list)
53-
items = []
54-
for edge in edges:
55-
edge_data = edge[2]
56-
description = edge_data["description"]
57-
58-
results[description] = [(description, "yes")]
59-
60-
for i in range(max_samples):
39+
tasks = []
40+
for i in range(self.quiz_samples):
6141
if i > 0:
62-
items.append((description, "TEMPLATE", "yes"))
63-
items.append((description, "ANTI_TEMPLATE", "no"))
64-
65-
for node in nodes:
66-
node_data = node[1]
67-
description = node_data["description"]
42+
tasks.append((item, "TEMPLATE", "yes"))
43+
tasks.append((item, "ANTI_TEMPLATE", "no"))
44+
try:
45+
quizzes = []
46+
for description, template_type, gt in tasks:
47+
prompt = self.generator.build_prompt_for_description(
48+
description, template_type
49+
)
50+
new_description = await self.llm_client.generate_answer(
51+
prompt, temperature=1
52+
)
53+
rephrased_text = self.generator.parse_rephrased_text(new_description)
54+
quizzes.append((rephrased_text, gt))
55+
return {
56+
"_description_id": _description_id,
57+
"description": item,
58+
"quizzes": quizzes,
59+
}
60+
except Exception as e:
61+
logger.error("Error when quizzing description %s: %s", item, e)
62+
return None
6863

69-
results[description] = [(description, "yes")]
64+
def quiz(self) -> Iterable[pd.DataFrame]:
65+
"""
66+
Get all nodes and edges and quiz their descriptions using QuizGenerator.
67+
"""
68+
edges = self.graph_storage.get_all_edges()
69+
nodes = self.graph_storage.get_all_nodes()
70+
71+
items = []
72+
73+
for edge in edges:
74+
edge_data = edge[2]
75+
description = edge_data["description"]
76+
items.append(description)
77+
78+
for node in nodes:
79+
node_data = node[1]
80+
description = node_data["description"]
81+
items.append(description)
82+
83+
logger.info("Total descriptions to quiz: %d", len(items))
84+
85+
for i in range(0, len(items), self.concurrency_limit):
86+
batch_items = items[i : i + self.concurrency_limit]
87+
batch_results = run_concurrent(
88+
self._process_single_quiz,
89+
batch_items,
90+
desc=f"Quizzing descriptions ({i} / {i + len(batch_items)})",
91+
unit="description",
92+
)
7093

71-
for i in range(max_samples):
72-
if i > 0:
73-
items.append((description, "TEMPLATE", "yes"))
74-
items.append((description, "ANTI_TEMPLATE", "no"))
75-
76-
quiz_results = await run_concurrent(
77-
_process_single_quiz,
78-
items,
79-
desc="Quizzing descriptions",
80-
unit="description",
81-
progress_bar=progress_bar,
82-
)
83-
84-
for new_result in quiz_results:
85-
if new_result:
86-
for key, value in new_result.items():
87-
results[key].extend(value)
88-
89-
for key, value in results.items():
90-
results[key] = list(set(value))
91-
rephrase_storage.upsert({key: results[key]})
92-
93-
return rephrase_storage
94+
final_results = []
95+
for new_result in batch_results:
96+
if new_result:
97+
self.quiz_storage.upsert(
98+
{
99+
new_result["_description_id"]: {
100+
"description": new_result["description"],
101+
"quizzes": new_result["quizzes"],
102+
}
103+
}
104+
)
105+
final_results.append(new_result)
106+
self.quiz_storage.index_done_callback()
107+
yield pd.DataFrame(final_results)

graphgen/utils/run_concurrent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from tqdm.asyncio import tqdm as tqdm_async
55

66
from graphgen.utils.log import logger
7+
78
from .loop import create_event_loop
89

910
T = TypeVar("T")
@@ -27,7 +28,7 @@ async def _run_all():
2728
try:
2829
result = await future
2930
results.append(result)
30-
except Exception as e: # pylint: disable=broad-except
31+
except Exception as e:
3132
logger.exception("Task failed: %s", e)
3233
results.append(e)
3334

0 commit comments

Comments
 (0)