1- import re
21import random
2+ import re
33from typing import Any , Optional
44
55from graphgen .bases import BaseGenerator
88
99random .seed (42 )
1010
11+
1112class MaskedFillInBlankGenerator (BaseGenerator ):
1213 """
1314 Masked Fill-in-blank Generator follows a TWO-STEP process:
@@ -94,18 +95,22 @@ async def generate(
9495 context = self .parse_rephrased_text (response )
9596 if not context :
9697 return []
97-
98+
9899 nodes , edge = batch
99- assert len (nodes ) == 2 , "MaskedFillInBlankGenerator currently only supports triples, which should has 2 nodes."
100- assert len (edge ) == 1 , "MaskedFillInBlankGenerator currently only supports triples, which should has 1 edge."
100+ assert (
101+ len (nodes ) == 2
102+ ), "MaskedFillInBlankGenerator currently only supports triples, which should has 2 nodes."
103+ assert (
104+ len (edge ) == 1
105+ ), "MaskedFillInBlankGenerator currently only supports triples, which should has 1 edge."
101106
102107 node1 , node2 = nodes
103108 mask_node = random .choice ([node1 , node2 ])
104- mask_node_name = mask_node [1 ]["entity_name" ].strip (' \' " \n \r \t ' )
109+ mask_node_name = mask_node [1 ]["entity_name" ].strip ("' \ " \n \r \t " )
105110
106111 mask_pattern = re .compile (re .escape (mask_node_name ), re .IGNORECASE )
107112 masked_context = mask_pattern .sub ("___" , context )
108- # For accuracy, extract the actual replaced text from the context as the ground truth (keeping the original case)
113+ # For accuracy, extract the actual replaced text from the context as the ground truth
109114 gth = re .search (mask_pattern , context ).group (0 )
110115
111116 logger .debug ("masked_context: %s" , masked_context )
@@ -114,4 +119,3 @@ async def generate(
114119 "answer" : gth ,
115120 }
116121 return [qa_pairs ]
117-
0 commit comments