def add_examples(self, texts: List[str], labels: List[str]) -> None:
    """Add labeled training examples and (re)train the adaptive head.

    Embeds each text, stores it in prototype memory, updates the
    label<->id mappings and per-label training history, then trains the
    adaptive head. New classes added to an *existing* classifier take the
    incremental-learning path (head expansion + EWC-style focused
    training); first-time training and same-class updates take the
    regular path.

    Args:
        texts: Input texts to learn from.
        labels: String label for each text (same length as ``texts``).

    Raises:
        ValueError: If either list is empty or their lengths differ.
    """
    # NOTE(review): this guard was reconstructed from the error message —
    # the original condition sits just above the visible diff hunk; confirm.
    if not texts or not labels:
        raise ValueError("Empty input lists")
    if len(texts) != len(labels):
        raise ValueError("Mismatched text and label lists")

    # Check if classifier has any existing classes (before updating mappings)
    has_existing_classes = len(self.label_to_id) > 0

    # Check for new classes
    new_classes = set(labels) - set(self.label_to_id.keys())
    is_adding_new_classes = len(new_classes) > 0

    # Update label mappings - sort new classes alphabetically for consistent IDs
    for label in sorted(new_classes):
        idx = len(self.label_to_id)
        self.label_to_id[label] = idx
        self.id_to_label[idx] = label

    # Get embeddings for all texts
    embeddings = self._get_embeddings(texts)

    # Add examples to memory and update training history
    for text, embedding, label in zip(texts, embeddings, labels):
        example = Example(text, label, embedding)
        self.memory.add_example(example, label)

        # Update training history
        if label not in self.training_history:
            self.training_history[label] = 0
        self.training_history[label] += 1

    # Determine training strategy: only use special new-class handling when
    # adding classes to a classifier that already knows some classes.
    is_incremental_learning = is_adding_new_classes and has_existing_classes

    if is_incremental_learning:
        # Adding new classes to an existing classifier - use special handling.
        # Store old head for EWC before modifying structure.
        old_head = copy.deepcopy(self.adaptive_head) if self.adaptive_head is not None else None

        # Expand existing head to accommodate new classes (preserves weights).
        # NOTE(review): assumes the head is non-None whenever existing classes
        # are present; a None head here would raise AttributeError — confirm.
        num_classes = len(self.label_to_id)
        self.adaptive_head.update_num_classes(num_classes)
        # Move to correct device after update
        self.adaptive_head = self.adaptive_head.to(self.device)

        # Train with focus on new classes
        self._train_new_classes(old_head, new_classes)
    else:
        # Initial training or regular updates - use normal training.
        if self.adaptive_head is None:
            self._initialize_adaptive_head()
        elif is_adding_new_classes:
            # Edge case: expanding head for new classes but treating as
            # regular training (no existing classes to preserve via EWC).
            num_classes = len(self.label_to_id)
            self.adaptive_head.update_num_classes(num_classes)
            self.adaptive_head = self.adaptive_head.to(self.device)

        # Regular training
        self._train_adaptive_head()

    # Strategic training step if enabled
    if self.strategic_mode and self.train_steps % self.config.strategic_training_frequency == 0:
        self._perform_strategic_training()

    # Ensure FAISS index is up to date after adding examples
    self.memory._rebuild_index()