Skip to content

Commit 10ebc37

Browse files
fix: fix partition service
1 parent 90c0a59 commit 10ebc37

5 files changed

Lines changed: 59 additions & 97 deletions

File tree

graphgen/bases/base_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def store(self, results: list, meta_update: dict):
152152
self.kv_storage.index_done_callback()
153153

154154
@abstractmethod
155-
def process(self, batch: list) -> Tuple[Union[list, Iterable[list]], dict]:
155+
def process(self, batch: list) -> Tuple[Union[list, Iterable[dict]], dict]:
156156
"""
157157
Process the input batch and return the result.
158158
:param batch

graphgen/models/generator/vqa_generator.py

Lines changed: 38 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import re
23
from typing import Any
34

@@ -75,62 +76,46 @@ async def generate(
7576
nodes, _ = batch
7677
for node in nodes:
7778
node_data = node[1]
78-
if "image_data" in node_data and node_data["image_data"]:
79-
img_path = node_data["image_data"]["img_path"]
79+
if "metadata" in node_data and node_data["metadata"]:
80+
metadata = json.loads(node_data["metadata"])["metadata"]
81+
img_path = metadata.get("path", "")
8082
for qa in qa_pairs:
8183
qa["img_path"] = img_path
8284
return qa_pairs
8385

8486
@staticmethod
85-
def format_generation_results(
86-
result: list[dict], output_data_format: str
87-
) -> list[dict[str, Any]]:
87+
def format_generation_results(result: dict, output_data_format: str) -> dict:
88+
question = result.get("question", "")
89+
answer = result.get("answer", "")
90+
img_path = result.get("img_path", "")
8891
if output_data_format == "Alpaca":
89-
result = [
90-
{
91-
"instruction": v["question"],
92-
"input": "",
93-
"output": v["answer"],
94-
"image": v.get("img_path", ""),
95-
}
96-
for item in result
97-
for k, v in item.items()
98-
]
99-
elif output_data_format == "Sharegpt":
100-
result = [
101-
{
102-
"conversations": [
103-
{
104-
"from": "human",
105-
"value": [
106-
{"text": v["question"], "image": v.get("img_path", "")}
107-
],
108-
},
109-
{"from": "gpt", "value": [{"text": v["answer"]}]},
110-
]
111-
}
112-
for item in result
113-
for k, v in item.items()
114-
]
115-
elif output_data_format == "ChatML":
116-
result = [
117-
{
118-
"messages": [
119-
{
120-
"role": "user",
121-
"content": [
122-
{"text": v["question"], "image": v.get("img_path", "")}
123-
],
124-
},
125-
{
126-
"role": "assistant",
127-
"content": [{"type": "text", "text": v["answer"]}],
128-
},
129-
]
130-
}
131-
for item in result
132-
for k, v in item.items()
133-
]
134-
else:
135-
raise ValueError(f"Unknown output data format: {output_data_format}")
136-
return result
92+
return {
93+
"instruction": question,
94+
"input": "",
95+
"output": answer,
96+
"image": img_path,
97+
}
98+
if output_data_format == "Sharegpt":
99+
return {
100+
"conversations": [
101+
{
102+
"from": "human",
103+
"value": [{"text": question, "image": img_path}],
104+
},
105+
{"from": "gpt", "value": [{"text": answer}]},
106+
]
107+
}
108+
if output_data_format == "ChatML":
109+
return {
110+
"messages": [
111+
{
112+
"role": "user",
113+
"content": [{"text": question, "image": img_path}],
114+
},
115+
{
116+
"role": "assistant",
117+
"content": [{"type": "text", "text": answer}],
118+
},
119+
]
120+
}
121+
raise ValueError(f"Unknown output data format: {output_data_format}")

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import re
23
from collections import Counter, defaultdict
34
from typing import Dict, List, Tuple
@@ -130,15 +131,25 @@ async def merge_nodes(
130131
set([dp["source_id"] for dp in node_data] + source_ids)
131132
)
132133

133-
node_data = {
134+
node_data_dict = {
134135
"entity_type": entity_type,
135136
"entity_name": entity_name,
136137
"description": description,
137138
"source_id": source_id,
138139
"length": self.tokenizer.count_tokens(description),
139140
}
140-
kg_instance.upsert_node(entity_name, node_data=node_data)
141-
return node_data
141+
142+
if entity_type in ("IMAGE", "TABLE", "FORMULA"):
143+
metadata = next(
144+
(dp["metadata"] for dp in node_data if dp.get("metadata")), None
145+
)
146+
if metadata:
147+
node_data_dict["metadata"] = json.dumps(
148+
metadata, ensure_ascii=False, default=str
149+
)
150+
151+
kg_instance.upsert_node(entity_name, node_data=node_data_dict)
152+
return node_data_dict
142153

143154
async def merge_edges(
144155
self,

graphgen/models/kg_builder/mm_kg_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ async def extract(
7070

7171
entity = await handle_single_entity_extraction(attributes, chunk_id)
7272
if entity is not None:
73+
if entity["entity_type"] == "IMAGE":
74+
entity["metadata"] = chunk.metadata
7375
nodes[entity["entity_name"]].append(entity)
7476
continue
7577

graphgen/operators/partition/partition_service.py

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
else:
5656
raise ValueError(f"Unsupported partition method: {method}")
5757

58-
def process(self, batch: list) -> Tuple[Iterable[list], dict]:
58+
def process(self, batch: list) -> Tuple[Iterable[dict], dict]:
5959
# this operator does not consume any batch data
6060
# but for compatibility we keep the interface
6161
self.kg_instance.reload()
@@ -68,50 +68,14 @@ def generator():
6868
count = 0
6969
for community in communities:
7070
count += 1
71-
batch = self.partitioner.community2batch(community, g=self.kg_instance)
72-
# batch = self._attach_additional_data_to_node(batch)
71+
b = self.partitioner.community2batch(community, g=self.kg_instance)
7372

7473
result = {
75-
"nodes": batch[0],
76-
"edges": batch[1],
74+
"nodes": b[0],
75+
"edges": b[1],
7776
}
7877
result["_trace_id"] = self.get_trace_id(result)
7978
yield result
8079
logger.info("Total communities partitioned: %d", count)
8180

8281
return generator(), {}
83-
84-
# def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
85-
# """
86-
# Attach additional data from chunk_storage to nodes in the batch.
87-
# :param batch: tuple of (nodes_data, edges_data)
88-
# :return: updated batch with additional data attached to nodes
89-
# """
90-
# nodes_data, edges_data = batch
91-
#
92-
# for node_id, node_data in nodes_data:
93-
# entity_type = (node_data.get("entity_type") or "").lower()
94-
# if not entity_type:
95-
# continue
96-
#
97-
# source_ids = [
98-
# sid.strip()
99-
# for sid in node_data.get("source_id", "").split("<SEP>")
100-
# if sid.strip()
101-
# ]
102-
#
103-
# # Handle images
104-
# if "image" in entity_type:
105-
# image_chunks = [
106-
# data
107-
# for sid in source_ids
108-
# if "image" in sid.lower()
109-
# and (data := self.chunk_storage.get_by_id(sid))
110-
# ]
111-
# if image_chunks:
112-
# # The generator expects a dictionary with an 'img_path' key, not a list of captions.
113-
# # We'll use the first image chunk found for this node.
114-
# node_data["image_data"] = json.loads(image_chunks[0]["content"])
115-
# logger.debug("Attached image data to node %s", node_id)
116-
#
117-
# return nodes_data, edges_data

0 commit comments

Comments (0)