Skip to content

Commit 1a90562

Browse files
committed
优化g2p流程,可以处理多音字,中英混合的情况等等
1 parent 3bdf822 commit 1a90562

4 files changed

Lines changed: 1032 additions & 79 deletions

File tree

projects/llm_framework/main_melotts/src/runner/Lexicon.hpp

Lines changed: 233 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -4,149 +4,303 @@
44
#include <vector>
55
#include <fstream>
66
#include <unordered_map>
7-
#include <assert.h>
7+
#include <algorithm>
8+
#include <sstream>
9+
#include <cassert>
10+
#include <iostream> // 用于日志输出
811

9-
std::vector<std::string> split (const std::string &s, char delim) {
12+
// 使用引用传参优化split函数,避免不必要的拷贝
13+
std::vector<std::string> split(const std::string &s, char delim) {
1014
std::vector<std::string> result;
11-
std::stringstream ss (s);
15+
std::stringstream ss(s);
1216
std::string item;
13-
while (getline (ss, item, delim)) {
14-
result.push_back (item);
17+
while (getline(ss, item, delim)) {
18+
if (!item.empty()) { // 避免添加空字符串
19+
result.push_back(item);
20+
}
1521
}
1622
return result;
1723
}
1824

1925
class Lexicon {
2026
private:
2127
std::unordered_map<std::string, std::pair<std::vector<int>, std::vector<int>>> lexicon;
28+
size_t max_phrase_length; // 追踪词典中最长的词组长度
29+
std::pair<std::vector<int>, std::vector<int>> unknown_token; // '_'的发音作为未知词的默认值
30+
std::unordered_map<int, std::string> reverse_tokens; // 用于将音素ID转回音素符号,用于日志
2231

2332
public:
24-
Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename) {
33+
Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename) : max_phrase_length(0) {
2534
std::unordered_map<std::string, int> tokens;
35+
36+
// 加载tokens
2637
std::ifstream ifs(tokens_filename);
2738
assert(ifs.is_open());
2839

2940
std::string line;
30-
while ( std::getline(ifs, line) ) {
41+
while (std::getline(ifs, line)) {
3142
auto splitted_line = split(line, ' ');
32-
tokens.insert({splitted_line[0], std::stoi(splitted_line[1])});
43+
if (splitted_line.size() >= 2) {
44+
int token_id = std::stoi(splitted_line[1]);
45+
tokens.insert({splitted_line[0], token_id});
46+
reverse_tokens[token_id] = splitted_line[0]; // 建立反向映射
47+
}
3348
}
3449
ifs.close();
3550

51+
// 加载lexicon
3652
ifs.open(lexicon_filename);
3753
assert(ifs.is_open());
38-
while ( std::getline(ifs, line) ) {
54+
while (std::getline(ifs, line)) {
3955
auto splitted_line = split(line, ' ');
56+
if (splitted_line.empty()) continue;
57+
4058
std::string word_or_phrase = splitted_line[0];
59+
60+
// 更新最长词组长度
61+
auto chars = splitEachChar(word_or_phrase);
62+
max_phrase_length = std::max(max_phrase_length, chars.size());
63+
4164
size_t phone_tone_len = splitted_line.size() - 1;
4265
size_t half_len = phone_tone_len / 2;
4366
std::vector<int> phones, tones;
67+
4468
for (size_t i = 0; i < phone_tone_len; i++) {
4569
auto phone_or_tone = splitted_line[i + 1];
4670
if (i < half_len) {
47-
phones.push_back(tokens[phone_or_tone]);
71+
if (tokens.find(phone_or_tone) != tokens.end()) {
72+
phones.push_back(tokens[phone_or_tone]);
73+
}
4874
} else {
4975
tones.push_back(std::stoi(phone_or_tone));
5076
}
5177
}
5278

53-
lexicon.insert({word_or_phrase, std::make_pair(phones, tones)});
79+
lexicon[word_or_phrase] = std::make_pair(phones, tones);
5480
}
5581

82+
// 添加特殊映射
5683
lexicon[""] = lexicon[""];
5784
lexicon[""] = lexicon[""];
5885

86+
// 添加标点符号
5987
const std::vector<std::string> punctuation{"!", "?", "", ",", ".", "'", "-"};
60-
for (auto p : punctuation) {
61-
int i = tokens[p];
62-
int tone = 0;
63-
lexicon[p] = std::make_pair(std::vector<int>{i}, std::vector<int>{tone});
88+
for (const auto& p : punctuation) {
89+
if (tokens.find(p) != tokens.end()) {
90+
int i = tokens[p];
91+
lexicon[p] = std::make_pair(std::vector<int>{i}, std::vector<int>{0});
92+
}
6493
}
65-
lexicon[" "] = std::make_pair(std::vector<int>{tokens["_"]}, std::vector<int>{0});
94+
95+
// 设置'_'作为未知词的发音
96+
assert(tokens.find("_") != tokens.end()); // 确保tokens中包含"_"
97+
unknown_token = std::make_pair(std::vector<int>{tokens["_"]}, std::vector<int>{0});
98+
99+
// 空格映射到'_'的发音
100+
lexicon[" "] = unknown_token;
101+
102+
// 中文标点转换映射
103+
lexicon[""] = lexicon[","];
104+
lexicon[""] = lexicon["."];
105+
lexicon[""] = lexicon["!"];
106+
lexicon[""] = lexicon["?"];
107+
108+
// 输出词典信息
109+
std::cout << "词典加载完成,包含 " << lexicon.size() << " 个条目,最长词组长度: " << max_phrase_length << std::endl;
66110
}
67111

68-
std::vector<std::string> splitEachChar(const std::string& text)
69-
{
112+
std::vector<std::string> splitEachChar(const std::string& text) {
70113
std::vector<std::string> words;
71-
std::string input(text);
72-
int len = input.length();
114+
int len = text.length();
73115
int i = 0;
74116

75117
while (i < len) {
76-
int next = 1;
77-
if ((input[i] & 0x80) == 0x00) {
78-
// std::cout << "one character: " << input[i] << std::endl;
79-
} else if ((input[i] & 0xE0) == 0xC0) {
80-
next = 2;
81-
// std::cout << "two character: " << input.substr(i, next) << std::endl;
82-
} else if ((input[i] & 0xF0) == 0xE0) {
83-
next = 3;
84-
// std::cout << "three character: " << input.substr(i, next) << std::endl;
85-
} else if ((input[i] & 0xF8) == 0xF0) {
86-
next = 4;
87-
// std::cout << "four character: " << input.substr(i, next) << std::endl;
88-
}
89-
words.push_back(input.substr(i, next));
90-
i += next;
118+
int next = 1;
119+
if ((text[i] & 0x80) == 0x00) {
120+
// ASCII
121+
} else if ((text[i] & 0xE0) == 0xC0) {
122+
next = 2; // 2字节UTF-8
123+
} else if ((text[i] & 0xF0) == 0xE0) {
124+
next = 3; // 3字节UTF-8
125+
} else if ((text[i] & 0xF8) == 0xF0) {
126+
next = 4; // 4字节UTF-8
127+
}
128+
words.push_back(text.substr(i, next));
129+
i += next;
91130
}
92131
return words;
93132
}
94133

95-
bool is_english(std::string s) {
96-
if (s.size() == 1)
97-
return (s[0] >= 'A' && s[0] <= 'Z') || (s[0] >= 'a' && s[0] <= 'z');
98-
else
99-
return false;
134+
bool is_english(const std::string& s) {
135+
return s.size() == 1 && ((s[0] >= 'A' && s[0] <= 'Z') || (s[0] >= 'a' && s[0] <= 'z'));
100136
}
101137

102-
std::vector<std::string> merge_english(const std::vector<std::string>& splitted_text) {
103-
std::vector<std::string> words;
138+
// 根据词典中的内容,使用最长匹配算法处理输入文本
139+
void convert(const std::string& text, std::vector<int>& phones, std::vector<int>& tones) {
140+
std::cout << "\n开始处理文本: \"" << text << "\"" << std::endl;
141+
std::cout << "=======匹配结果=======" << std::endl;
142+
std::cout << "单元\t|\t音素\t|\t声调" << std::endl;
143+
std::cout << "-----------------------------" << std::endl;
144+
145+
// 在开头添加'_'边界标记
146+
phones.insert(phones.end(), unknown_token.first.begin(), unknown_token.first.end());
147+
tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end());
148+
std::cout << "<BOS>\t|\t" << phonesToString(unknown_token.first) << "\t|\t"
149+
<< tonesToString(unknown_token.second) << std::endl;
150+
151+
auto chars = splitEachChar(text);
104152
int i = 0;
105-
while (i < splitted_text.size()) {
106-
std::string s;
107-
if (is_english(splitted_text[i])) {
108-
while (i < splitted_text.size()) {
109-
if (!is_english(splitted_text[i])) {
110-
break;
111-
}
112-
s += splitted_text[i];
113-
i++;
153+
154+
while (i < chars.size()) {
155+
// 处理英文单词
156+
if (is_english(chars[i])) {
157+
std::string eng_word;
158+
int start = i;
159+
while (i < chars.size() && is_english(chars[i])) {
160+
eng_word += chars[i++];
114161
}
115-
// to lowercase
116-
std::transform(s.begin(), s.end(), s.begin(),
162+
163+
// 英文转小写
164+
std::string orig_word = eng_word; // 保留原始单词用于日志
165+
std::transform(eng_word.begin(), eng_word.end(), eng_word.begin(),
117166
[](unsigned char c){ return std::tolower(c); });
118-
words.push_back(s);
119-
if (i >= splitted_text.size())
167+
168+
// 如果词典中有这个英文单词,使用它;否则使用'_'的发音
169+
if (lexicon.find(eng_word) != lexicon.end()) {
170+
auto& [eng_phones, eng_tones] = lexicon[eng_word];
171+
phones.insert(phones.end(), eng_phones.begin(), eng_phones.end());
172+
tones.insert(tones.end(), eng_tones.begin(), eng_tones.end());
173+
174+
// 打印匹配信息
175+
std::cout << orig_word << "\t|\t" << phonesToString(eng_phones) << "\t|\t"
176+
<< tonesToString(eng_tones) << std::endl;
177+
} else {
178+
// 未找到单词,使用'_'的发音
179+
phones.insert(phones.end(), unknown_token.first.begin(), unknown_token.first.end());
180+
tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end());
181+
182+
// 打印未匹配信息
183+
std::cout << orig_word << "\t|\t" << phonesToString(unknown_token.first) << " (未匹配)\t|\t"
184+
<< tonesToString(unknown_token.second) << std::endl;
185+
}
186+
continue;
187+
}
188+
// 处理非英文字符(如空格、标点)
189+
std::string c = chars[i++];
190+
if (c == " ") continue; // 跳过空格
191+
// 回退一步,用于最长匹配
192+
i--;
193+
194+
195+
// 最长匹配算法处理中文/日文
196+
bool matched = false;
197+
// 尝试从最长的词组开始匹配
198+
for (size_t len = std::min(max_phrase_length, chars.size() - i); len > 0 && !matched; --len) {
199+
std::string phrase;
200+
for (size_t j = 0; j < len; ++j) {
201+
phrase += chars[i + j];
202+
}
203+
204+
if (lexicon.find(phrase) != lexicon.end()) {
205+
auto& [phrase_phones, phrase_tones] = lexicon[phrase];
206+
phones.insert(phones.end(), phrase_phones.begin(), phrase_phones.end());
207+
tones.insert(tones.end(), phrase_tones.begin(), phrase_tones.end());
208+
209+
// 打印匹配信息
210+
std::cout << phrase << "\t|\t" << phonesToString(phrase_phones) << "\t|\t"
211+
<< tonesToString(phrase_tones) << std::endl;
212+
213+
i += len;
214+
matched = true;
120215
break;
216+
}
121217
}
122-
else {
123-
words.push_back(splitted_text[i]);
124-
i++;
218+
219+
// 如果没有匹配到任何词组,使用'_'的发音
220+
if (!matched) {
221+
std::string c = chars[i++];
222+
std::string s = c;
223+
224+
// 中文标点符号转换
225+
std::string orig_char = s; // 保留原始字符用于日志
226+
if (s == "") s = ",";
227+
else if (s == "") s = ".";
228+
else if (s == "") s = "!";
229+
else if (s == "") s = "?";
230+
231+
// 如果词典中找不到,则使用'_'的发音
232+
if (lexicon.find(s) != lexicon.end()) {
233+
auto& [char_phones, char_tones] = lexicon[s];
234+
phones.insert(phones.end(), char_phones.begin(), char_phones.end());
235+
tones.insert(tones.end(), char_tones.begin(), char_tones.end());
236+
237+
// 打印匹配信息
238+
std::cout << orig_char << "\t|\t" << phonesToString(char_phones) << "\t|\t"
239+
<< tonesToString(char_tones) << std::endl;
240+
} else {
241+
phones.insert(phones.end(), unknown_token.first.begin(), unknown_token.first.end());
242+
tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end());
243+
244+
// 打印未匹配信息
245+
std::cout << orig_char << "\t|\t" << phonesToString(unknown_token.first) << " (未匹配)\t|\t"
246+
<< tonesToString(unknown_token.second) << std::endl;
247+
}
125248
}
126249
}
127-
return words;
250+
251+
// 在末尾添加'_'边界标记
252+
phones.insert(phones.end(), unknown_token.first.begin(), unknown_token.first.end());
253+
tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end());
254+
std::cout << "<EOS>\t|\t" << phonesToString(unknown_token.first) << "\t|\t"
255+
<< tonesToString(unknown_token.second) << std::endl;
256+
257+
// 汇总打印最终结果
258+
std::cout << "\n处理结果汇总:" << std::endl;
259+
std::cout << "原文: " << text << std::endl;
260+
std::cout << "音素: " << phonesToString(phones) << std::endl;
261+
std::cout << "声调: " << tonesToString(tones) << std::endl;
262+
std::cout << "====================" << std::endl;
128263
}
129264

130-
void convert(const std::string& text, std::vector<int>& phones, std::vector<int>& tones) {
131-
auto splitted_text = splitEachChar(text);
132-
auto zh_mix_en = merge_english(splitted_text);
133-
for (auto c : zh_mix_en) {
134-
std::string s{c};
135-
if (s == "")
136-
s = ",";
137-
else if (s == "")
138-
s = ".";
139-
else if (s == "")
140-
s = "!";
141-
else if (s == "")
142-
s = "?";
143-
144-
auto phones_and_tones = lexicon[" "];
145-
if (lexicon.find(s) != lexicon.end()) {
146-
phones_and_tones = lexicon[s];
265+
private:
266+
// 处理单个字符
267+
void processChar(const std::string& c, std::vector<int>& phones, std::vector<int>& tones) {
268+
std::string s = c;
269+
270+
// 中文标点符号转换
271+
if (s == "") s = ",";
272+
else if (s == "") s = ".";
273+
else if (s == "") s = "!";
274+
else if (s == "") s = "?";
275+
276+
// 如果词典中找不到,则使用'_'的发音
277+
auto& phones_and_tones = (lexicon.find(s) != lexicon.end()) ? lexicon[s] : unknown_token;
278+
279+
phones.insert(phones.end(), phones_and_tones.first.begin(), phones_and_tones.first.end());
280+
tones.insert(tones.end(), phones_and_tones.second.begin(), phones_and_tones.second.end());
281+
}
282+
283+
// 将音素ID数组转换为字符串用于日志输出
284+
std::string phonesToString(const std::vector<int>& phones) {
285+
std::string result;
286+
for (auto id : phones) {
287+
if (!result.empty()) result += " ";
288+
if (reverse_tokens.find(id) != reverse_tokens.end()) {
289+
result += reverse_tokens[id];
290+
} else {
291+
result += "<" + std::to_string(id) + ">";
147292
}
148-
phones.insert(phones.end(), phones_and_tones.first.begin(), phones_and_tones.first.end());
149-
tones.insert(tones.end(), phones_and_tones.second.begin(), phones_and_tones.second.end());
150293
}
294+
return result;
295+
}
296+
297+
// 将声调数组转换为字符串用于日志输出
298+
std::string tonesToString(const std::vector<int>& tones) {
299+
std::string result;
300+
for (auto tone : tones) {
301+
if (!result.empty()) result += " ";
302+
result += std::to_string(tone);
303+
}
304+
return result;
151305
}
152-
};
306+
};

0 commit comments

Comments
 (0)