Skip to content

Commit 39e2352

Browse files
superfartherbeanbunChenZiHong-Gavingemini-code-assist[bot]
authored
feat: support synthesizing masked fill_in_blank QA pairs (#173)
* feat: support synthesizing masked fill_in_blank QA pairs * style: fix formatting issues * feat: support partitioning the graph into quintuples * update README * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: beanbun <yuanzhonghang@pjlab.org.cn> Co-authored-by: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 9e994a4 commit 39e2352

12 files changed

Lines changed: 358 additions & 1 deletion

File tree

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Generate Masked Fill-in-blank QAs
2+
In this module, we generate fill-in-blank QAs from unstructured corpora by randomly masking core entities in a knowledge graph. The key is that a rule-based validator can automatically verify the answers to these questions. For example:
3+
> **Question:** Hematogenous long-bone osteomyelitis is an infection of the bone, primarily affecting the long bones, and often results from blood-borne pathogens. This condition is characterized by several key symptoms, including ___ and swelling. ___ is a prominent symptom in both primary and recurrent cases of hematogenous long-bone osteomyelitis, manifesting as persistent discomfort in the affected area.
4+
> **Answer:** pain
5+
6+
Because the answers to these questions can be easily verified, they are well-suited for RLVR (Reinforcement Learning with Verifiable Rewards).
7+
8+
For more details, please see our paper "Knowledge-to-Verification: Exploring RLVR for LLMs in Knowledge-Intensive Domains". It has been accepted to the ACL 2026 Main Conference, and we will update the link soon.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python3 -m graphgen.run \
2+
--config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
global_params:
2+
working_dir: cache
3+
graph_backend: networkx # graph database backend, support: kuzu, networkx
4+
kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv
5+
6+
nodes:
7+
- id: read_files # id is unique in the pipeline, and can be referenced by other steps
8+
op_name: read
9+
type: source
10+
dependencies: []
11+
params:
12+
input_path:
13+
- examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples
14+
15+
- id: chunk_documents
16+
op_name: chunk
17+
type: map_batch
18+
dependencies:
19+
- read_files
20+
execution_params:
21+
replicas: 4
22+
params:
23+
chunk_size: 1024 # chunk size for text splitting
24+
chunk_overlap: 100 # chunk overlap for text splitting
25+
26+
- id: build_kg
27+
op_name: build_kg
28+
type: map_batch
29+
dependencies:
30+
- chunk_documents
31+
execution_params:
32+
replicas: 1
33+
batch_size: 128
34+
35+
- id: partition
36+
op_name: partition
37+
type: aggregate
38+
dependencies:
39+
- build_kg
40+
params:
41+
method: quintuple
42+
43+
- id: generate
44+
op_name: generate
45+
type: map_batch
46+
dependencies:
47+
- partition
48+
execution_params:
49+
replicas: 1
50+
batch_size: 128
51+
save_output: true # save output
52+
params:
53+
method: masked_fill_in_blank # atomic, aggregated, multi_hop, cot, vqa
54+
data_format: QA_pairs # Alpaca, Sharegpt, ChatML, QA_pairs

graphgen/bases/base_generator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,10 @@ def format_generation_results(
7474
{"role": "assistant", "content": answer},
7575
]
7676
}
77+
78+
if output_data_format == "QA_pairs":
79+
return {
80+
"question": question,
81+
"answer": answer,
82+
}
7783
raise ValueError(f"Unknown output data format: {output_data_format}")

graphgen/models/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
AtomicGenerator,
1616
CoTGenerator,
1717
FillInBlankGenerator,
18+
MaskedFillInBlankGenerator,
1819
MultiAnswerGenerator,
1920
MultiChoiceGenerator,
2021
MultiHopGenerator,
@@ -30,6 +31,8 @@
3031
DFSPartitioner,
3132
ECEPartitioner,
3233
LeidenPartitioner,
34+
QuintuplePartitioner,
35+
TriplePartitioner,
3336
)
3437
from .reader import (
3538
CSVReader,
@@ -73,6 +76,7 @@
7376
"QuizGenerator": ".generator",
7477
"TrueFalseGenerator": ".generator",
7578
"VQAGenerator": ".generator",
79+
"MaskedFillInBlankGenerator": ".generator",
7680
# KG Builder
7781
"LightRAGKGBuilder": ".kg_builder",
7882
"MMKGBuilder": ".kg_builder",
@@ -86,6 +90,8 @@
8690
"DFSPartitioner": ".partitioner",
8791
"ECEPartitioner": ".partitioner",
8892
"LeidenPartitioner": ".partitioner",
93+
"TriplePartitioner": ".partitioner",
94+
"QuintuplePartitioner": ".partitioner",
8995
# Reader
9096
"CSVReader": ".reader",
9197
"JSONReader": ".reader",

graphgen/models/generator/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .quiz_generator import QuizGenerator
99
from .true_false_generator import TrueFalseGenerator
1010
from .vqa_generator import VQAGenerator
11+
from .masked_fill_in_blank_generator import MaskedFillInBlankGenerator
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import random
import re
from typing import Any, Optional

from graphgen.bases import BaseGenerator
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger

# NOTE(review): this seeds the *process-wide* RNG at import time so masking is
# reproducible; it also affects every other user of `random` in the process.
random.seed(42)


class MaskedFillInBlankGenerator(BaseGenerator):
    """
    Generate masked fill-in-blank QA pairs from a quintuple batch.

    TWO-STEP process:
    1. rephrase: ask the LLM to rephrase the input nodes and edges into a
       coherent text that maintains the original meaning.
    2. mask: randomly select one of the input nodes and blank out every
       occurrence of its name in the rephrased text with "___"; the hidden
       name becomes the (rule-verifiable) answer.
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the REPHRASE prompt for a batch.

        :param batch: (nodes, edges) where each node is (name, attrs) and each
            edge is (src, dst, attrs); attrs must contain a "description" key.
        :return: the formatted rephrasing prompt.
        """
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index + 1}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes)
        )
        relations_str = "\n".join(
            f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges)
        )
        # Choose the prompt template matching the dominant input language.
        language = detect_main_language(entities_str + relations_str)

        # TODO: support a configurable `add_context` option that resolves the
        # nodes'/edges' `source_id` back to the original text chunks and
        # prepends them to the prompt.
        prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
            entities=entities_str, relationships=relations_str
        )
        return prompt

    @staticmethod
    def parse_rephrased_text(response: str) -> Optional[str]:
        """
        Extract the text between <rephrased_text> tags from the LLM response.

        :param response: raw LLM output.
        :return: the rephrased text with surrounding whitespace and quotes
            stripped, or None if the tags are missing.
        """
        rephrased_match = re.search(
            r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
        )
        if not rephrased_match:
            logger.warning("Failed to parse rephrased text from response: %s", response)
            return None
        return rephrased_match.group(1).strip().strip('"').strip("'")

    @staticmethod
    def parse_response(response: str) -> dict:
        # Unused stub — presumably required by the BaseGenerator interface
        # (TODO confirm): `generate` parses the LLM response inline via
        # `parse_rephrased_text` instead.
        pass

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> list[dict]:
        """
        Generate one masked fill-in-blank QA pair from a quintuple batch.

        :param batch: (nodes, edges) quintuple — exactly 3 nodes and 2 edges.
        :return: a single-element list of {"question", "answer"} dicts, or []
            when rephrasing or masking fails.
        """
        rephrasing_prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(rephrasing_prompt)
        context = self.parse_rephrased_text(response)
        if not context:
            return []

        nodes, edges = batch

        # NOTE(review): asserts are stripped under `python -O`; they guard the
        # quintuple shape (3 nodes, 2 edges) this generator expects upstream.
        assert len(nodes) == 3, (
            "MaskedFillInBlankGenerator currently only supports quintuples that has 3 nodes, "
            f"but got {len(nodes)} nodes."
        )
        assert len(edges) == 2, (
            "MaskedFillInBlankGenerator currently only supports quintuples that has 2 edges, "
            f"but got {len(edges)} edges."
        )

        node1, node2, node3 = nodes
        mask_node = random.choice([node1, node2, node3])
        mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")

        # FIX: an empty entity name would make re.escape("") match at every
        # position, so `sub` would insert "___" between all characters and the
        # answer would be the empty string — skip such batches instead.
        if not mask_node_name:
            logger.debug("Empty entity name for mask node; skipping batch.")
            return []

        mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)

        match = mask_pattern.search(context)
        if not match:
            logger.debug(
                "Regex Match Failed!\n"
                "Expected name of node: %s\n"
                "Actual context: %s\n",
                mask_node_name,
                context,
            )
            return []

        # Ground truth keeps the casing exactly as it appears in the context.
        gth = match.group(0)
        # Blank out every (case-insensitive) occurrence of the entity name.
        masked_context = mask_pattern.sub("___", context)

        logger.debug("masked_context: %s", masked_context)
        return [{"question": masked_context, "answer": gth}]

graphgen/models/partitioner/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33
from .dfs_partitioner import DFSPartitioner
44
from .ece_partitioner import ECEPartitioner
55
from .leiden_partitioner import LeidenPartitioner
6+
from .quintuple_partitioner import QuintuplePartitioner
7+
from .triple_partitioner import TriplePartitioner
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import random
from collections import deque
from typing import Any, Iterable, Set

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community

# NOTE(review): seeds the *process-wide* RNG at import time so partitioning is
# reproducible; it also affects every other user of `random` in the process.
random.seed(42)


class QuintuplePartitioner(BasePartitioner):
    """
    Partition the graph into distinct quintuples (node, edge, node, edge, node).

    1. Isolated nodes are ignored automatically (they never appear as an
       edge endpoint, so they yield nothing).
    2. Within each connected component, quintuples are yielded in BFS order.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        # Shuffle the seed order so component traversal is randomized but
        # deterministic under the module-level seed.
        nodes = [n[0] for n in g.get_all_nodes()]
        random.shuffle(nodes)

        visited_nodes: Set[str] = set()
        # Edges are undirected: a frozenset key makes (u, v) == (v, u).
        used_edges: Set[frozenset[str]] = set()

        for seed in nodes:
            if seed in visited_nodes:
                continue

            # start BFS in a connected component
            queue = deque([seed])
            visited_nodes.add(seed)

            while queue:
                u = queue.popleft()

                # collect all neighbors connected to node u via unused edges
                available_neighbors = []
                for v in g.get_neighbors(u):
                    edge_key = frozenset((u, v))
                    if edge_key not in used_edges:
                        available_neighbors.append(v)

                    # standard BFS queue maintenance
                    if v not in visited_nodes:
                        visited_nodes.add(v)
                        queue.append(v)

                # Randomize which neighbors get paired together (deterministic
                # under the module-level seed).
                random.shuffle(available_neighbors)

                # every two neighbors paired with the center node u creates one quintuple
                # Note: If available_neighbors has an odd length, the remaining edge
                # stays unused for now. It may be matched into a quintuple later
                # when its other endpoint is processed as a center node.
                for i in range(0, len(available_neighbors) // 2 * 2, 2):
                    v1 = available_neighbors[i]
                    v2 = available_neighbors[i + 1]

                    edge1 = frozenset((u, v1))
                    edge2 = frozenset((u, v2))

                    used_edges.add(edge1)
                    used_edges.add(edge2)

                    # Sort the two outer endpoints so the community id is
                    # canonical regardless of pairing order.
                    v1_s, v2_s = sorted((v1, v2))

                    yield Community(
                        id=f"{v1_s}-{u}-{v2_s}",
                        nodes=[v1_s, u, v2_s],
                        edges=[tuple(sorted((v1_s, u))), tuple(sorted((u, v2_s)))],
                    )
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import random
from collections import deque
from typing import Any, Iterable, Set

from graphgen.bases import BaseGraphStorage, BasePartitioner
from graphgen.bases.datatypes import Community

# Seed the global RNG so the shuffle below is reproducible across runs.
random.seed(42)


class TriplePartitioner(BasePartitioner):
    """
    Split the graph into distinct (node, edge, node) triples.

    Isolated nodes contribute no triples (they never appear as an edge
    endpoint), and each connected component is walked breadth-first, so
    triples are emitted in BFS order per component.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        # Randomize which node starts each component's traversal.
        starts = [pair[0] for pair in g.get_all_nodes()]
        random.shuffle(starts)

        seen: Set[str] = set()
        # An undirected edge is keyed by the frozenset of its endpoints,
        # so (u, v) and (v, u) consume the same entry.
        consumed: Set[frozenset[str]] = set()

        for start in starts:
            if start in seen:
                continue

            # Breadth-first walk of the component rooted at `start`.
            seen.add(start)
            frontier = deque((start,))

            while frontier:
                center = frontier.popleft()

                for nbr in g.get_neighbors(center):
                    key = frozenset((center, nbr))

                    if key not in consumed:
                        # First encounter of this edge -> emit a new triple.
                        consumed.add(key)

                        # Sorted endpoint names give the triple a unique,
                        # direction-independent id.
                        lo, hi = sorted((center, nbr))
                        yield Community(
                            id=f"{lo}-{hi}",
                            nodes=[lo, hi],
                            edges=[(lo, hi)],
                        )

                    # Keep the BFS moving through unvisited neighbors.
                    if nbr not in seen:
                        seen.add(nbr)
                        frontier.append(nbr)

0 commit comments

Comments
 (0)