|
| 1 | +import json |
1 | 2 | import re |
2 | 3 | from typing import Any |
3 | 4 |
|
@@ -75,62 +76,46 @@ async def generate( |
75 | 76 | nodes, _ = batch |
76 | 77 | for node in nodes: |
77 | 78 | node_data = node[1] |
78 | | - if "image_data" in node_data and node_data["image_data"]: |
79 | | - img_path = node_data["image_data"]["img_path"] |
| 79 | + if "metadata" in node_data and node_data["metadata"]: |
| 80 | + metadata = json.loads(node_data["metadata"])["metadata"] |
| 81 | + img_path = metadata.get("path", "") |
80 | 82 | for qa in qa_pairs: |
81 | 83 | qa["img_path"] = img_path |
82 | 84 | return qa_pairs |
83 | 85 |
|
84 | 86 | @staticmethod |
85 | | - def format_generation_results( |
86 | | - result: list[dict], output_data_format: str |
87 | | - ) -> list[dict[str, Any]]: |
| 87 | + def format_generation_results(result: dict, output_data_format: str) -> dict: |
| 88 | + question = result.get("question", "") |
| 89 | + answer = result.get("answer", "") |
| 90 | + img_path = result.get("img_path", "") |
88 | 91 | if output_data_format == "Alpaca": |
89 | | - result = [ |
90 | | - { |
91 | | - "instruction": v["question"], |
92 | | - "input": "", |
93 | | - "output": v["answer"], |
94 | | - "image": v.get("img_path", ""), |
95 | | - } |
96 | | - for item in result |
97 | | - for k, v in item.items() |
98 | | - ] |
99 | | - elif output_data_format == "Sharegpt": |
100 | | - result = [ |
101 | | - { |
102 | | - "conversations": [ |
103 | | - { |
104 | | - "from": "human", |
105 | | - "value": [ |
106 | | - {"text": v["question"], "image": v.get("img_path", "")} |
107 | | - ], |
108 | | - }, |
109 | | - {"from": "gpt", "value": [{"text": v["answer"]}]}, |
110 | | - ] |
111 | | - } |
112 | | - for item in result |
113 | | - for k, v in item.items() |
114 | | - ] |
115 | | - elif output_data_format == "ChatML": |
116 | | - result = [ |
117 | | - { |
118 | | - "messages": [ |
119 | | - { |
120 | | - "role": "user", |
121 | | - "content": [ |
122 | | - {"text": v["question"], "image": v.get("img_path", "")} |
123 | | - ], |
124 | | - }, |
125 | | - { |
126 | | - "role": "assistant", |
127 | | - "content": [{"type": "text", "text": v["answer"]}], |
128 | | - }, |
129 | | - ] |
130 | | - } |
131 | | - for item in result |
132 | | - for k, v in item.items() |
133 | | - ] |
134 | | - else: |
135 | | - raise ValueError(f"Unknown output data format: {output_data_format}") |
136 | | - return result |
| 92 | + return { |
| 93 | + "instruction": question, |
| 94 | + "input": "", |
| 95 | + "output": answer, |
| 96 | + "image": img_path, |
| 97 | + } |
| 98 | + if output_data_format == "Sharegpt": |
| 99 | + return { |
| 100 | + "conversations": [ |
| 101 | + { |
| 102 | + "from": "human", |
| 103 | + "value": [{"text": question, "image": img_path}], |
| 104 | + }, |
| 105 | + {"from": "gpt", "value": [{"text": answer}]}, |
| 106 | + ] |
| 107 | + } |
| 108 | + if output_data_format == "ChatML": |
| 109 | + return { |
| 110 | + "messages": [ |
| 111 | + { |
| 112 | + "role": "user", |
| 113 | + "content": [{"text": question, "image": img_path}], |
| 114 | + }, |
| 115 | + { |
| 116 | + "role": "assistant", |
| 117 | + "content": [{"type": "text", "text": answer}], |
| 118 | + }, |
| 119 | + ] |
| 120 | + } |
| 121 | + raise ValueError(f"Unknown output data format: {output_data_format}") |
0 commit comments