@@ -313,15 +313,15 @@ def map(self, element):
313313class SFTPromptMaskingVision (grain .MapTransform ):
314314 """SFT prompt masking for multimodal"""
315315
316- def __init__ (self , query_column , response_column , max_target_length , unk_id ):
316+ def __init__ (self , query_column , response_column , max_target_length , pad_id ):
317317 self .query_column = query_column
318318 self .response_column = response_column
319319 self .max_target_length = max_target_length
320- self .unk_id = unk_id
320+ self .pad_id = pad_id
321321
322322 def map (self , element ):
323323 inputs = np .concatenate ((element [self .query_column ], element [self .response_column ]))
324- targets = np .concatenate ((np .asarray ([self .unk_id ] * len (element [self .query_column ])), element [self .response_column ]))
324+ targets = np .concatenate ((np .asarray ([self .pad_id ] * len (element [self .query_column ])), element [self .response_column ]))
325325 return {
326326 "inputs" : np .asarray (inputs [: self .max_target_length ], dtype = np .int32 ),
327327 "targets" : np .asarray (targets [: self .max_target_length ], dtype = np .int32 ),
@@ -559,13 +559,13 @@ def __init__(
559559 self ,
560560 max_length : int ,
561561 pad_id : int = 0 ,
562- model_name : str | None = None ,
562+ config = None ,
563563 add_true_length : bool = False ,
564564 max_num_images_per_example : int = - 1 ,
565565 ):
566566 self .max_length = max_length
567567 self .pad_id = pad_id
568- self .model_name = model_name
568+ self .config = config
569569 self .add_true_length = add_true_length
570570 self .max_num_images_per_example = max_num_images_per_example
571571
@@ -614,7 +614,7 @@ def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) -
614614 raise ValueError ("Input preprocessed_image must have pixel_values to pad images." )
615615
616616 # Determine the maximum number of images/masks allowed.
617- image_offsets = mm_processor .get_image_offsets (self .model_name , preprocessed_image )
617+ image_offsets = mm_processor .get_image_offsets (self .config , preprocessed_image )
618618 single_image_offset = image_offsets // preprocessed_image .pixel_values .shape [0 ]
619619
620620 # Reserve space for at least one text token.
@@ -680,7 +680,7 @@ def map(
680680
681681 for key , _ in element .items ():
682682 if key == "images" :
683- if self .model_name is None :
683+ if self .config . model_name is None :
684684 raise ValueError ("model_name must be provided when padding images" )
685685
686686 element ["images" ] = self ._pad_image_and_mask (element ["images" ])
0 commit comments