Skip to content

Commit bcfd551

Browse files
committed
Update classifier.py
1 parent 71be238 commit bcfd551

1 file changed

Lines changed: 30 additions & 18 deletions

File tree

src/adaptive_classifier/classifier.py

Lines changed: 30 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -83,55 +83,67 @@ def add_examples(self, texts: List[str], labels: List[str]):
8383
raise ValueError("Empty input lists")
8484
if len(texts) != len(labels):
8585
raise ValueError("Mismatched text and label lists")
86-
86+
87+
# Check if classifier has any existing classes (before updating mappings)
88+
has_existing_classes = len(self.label_to_id) > 0
89+
8790
# Check for new classes
8891
new_classes = set(labels) - set(self.label_to_id.keys())
8992
is_adding_new_classes = len(new_classes) > 0
90-
93+
9194
# Update label mappings - sort new classes alphabetically for consistent IDs
9295
for label in sorted(new_classes):
9396
idx = len(self.label_to_id)
9497
self.label_to_id[label] = idx
9598
self.id_to_label[idx] = label
96-
99+
97100
# Get embeddings for all texts
98101
embeddings = self._get_embeddings(texts)
99-
102+
100103
# Add examples to memory and update training history
101104
for text, embedding, label in zip(texts, embeddings, labels):
102105
example = Example(text, label, embedding)
103106
self.memory.add_example(example, label)
104-
107+
105108
# Update training history
106109
if label not in self.training_history:
107110
self.training_history[label] = 0
108111
self.training_history[label] += 1
109-
110-
# Special handling for new classes
111-
if is_adding_new_classes:
112+
113+
# Determine training strategy: only use special new class handling for incremental learning
114+
is_incremental_learning = is_adding_new_classes and has_existing_classes
115+
116+
if is_incremental_learning:
117+
# Adding new classes to existing classifier - use special handling
112118
# Store old head for EWC before modifying structure
113119
old_head = copy.deepcopy(self.adaptive_head) if self.adaptive_head is not None else None
114120

121+
# Expand existing head to accommodate new classes (preserves weights)
122+
num_classes = len(self.label_to_id)
123+
self.adaptive_head.update_num_classes(num_classes)
124+
# Move to correct device after update
125+
self.adaptive_head = self.adaptive_head.to(self.device)
126+
127+
# Train with focus on new classes
128+
self._train_new_classes(old_head, new_classes)
129+
else:
130+
# Initial training or regular updates - use normal training
131+
# Initialize head if needed
115132
if self.adaptive_head is None:
116-
# First time initialization
117133
self._initialize_adaptive_head()
118-
else:
119-
# Expand existing head to accommodate new classes (preserves weights)
134+
elif is_adding_new_classes:
135+
# Edge case: a head already exists but no labels were registered before this call; expand it for the new classes, then fall through to regular training
120136
num_classes = len(self.label_to_id)
121137
self.adaptive_head.update_num_classes(num_classes)
122-
# Move to correct device after update
123138
self.adaptive_head = self.adaptive_head.to(self.device)
124139

125-
# Train with focus on new classes
126-
self._train_new_classes(old_head, new_classes)
127-
else:
128-
# Regular training for existing classes
140+
# Regular training
129141
self._train_adaptive_head()
130-
142+
131143
# Strategic training step if enabled
132144
if self.strategic_mode and self.train_steps % self.config.strategic_training_frequency == 0:
133145
self._perform_strategic_training()
134-
146+
135147
# Ensure FAISS index is up to date after adding examples
136148
self.memory._rebuild_index()
137149

0 commit comments

Comments
 (0)