1+ """Tests for Elastic Weight Consolidation (EWC) module."""
2+
3+ import pytest
4+ import torch
5+ import numpy as np
6+ from adaptive_classifier import AdaptiveClassifier
7+ from adaptive_classifier .ewc import EWC
8+ import torch .nn as nn
9+
10+
@pytest.fixture
def simple_model():
    """Provide a fresh single-layer linear classifier for the EWC tests.

    The network maps 10 input features to 3 class logits by default.
    """

    class SimpleModel(nn.Module):
        """Minimal network: one fully connected layer and nothing else."""

        def __init__(self, input_dim=10, num_classes=3):
            super().__init__()
            self.fc = nn.Linear(input_dim, num_classes)

        def forward(self, x):
            logits = self.fc(x)
            return logits

    model = SimpleModel()
    return model
23+
24+
@pytest.fixture
def small_dataset():
    """Provide a 33-sample dataset of random 10-dim embeddings over 3 classes.

    The sample count of 33 is deliberate: it is the edge case the tests rely
    on (presumably a trailing batch of one sample when batching by 32 --
    confirm against the batch size EWC actually uses).
    """
    num_samples, feature_dim, num_classes = 33, 10, 3

    features = torch.randn(num_samples, feature_dim)
    # Labels cycle 0, 1, 2, 0, 1, 2, ... so every class appears 11 times.
    targets = torch.tensor([idx % num_classes for idx in range(num_samples)])

    return torch.utils.data.TensorDataset(features, targets)
32+
33+
def test_ewc_single_batch_edge_case(simple_model, small_dataset):
    """EWC must initialise on a dataset that yields a single-sample batch.

    Regression test for the squeeze() bug that occurred when the last batch
    contained only one sample.
    """
    # Construction itself is the check: it used to raise for this dataset.
    consolidator = EWC(
        simple_model,
        small_dataset,
        device='cpu',
        ewc_lambda=100.0,
    )

    assert consolidator is not None
    assert consolidator.fisher_info is not None
    assert consolidator.old_params is not None
53+
54+
@pytest.mark.parametrize("size", [1, 31, 32, 33, 64, 65, 100])
def test_ewc_various_batch_sizes(simple_model, size):
    """EWC must initialise and produce a valid loss for a range of dataset sizes.

    The sizes sit on and around multiples of 32 (plus the single-sample case)
    so that full, partial and one-element trailing batches are exercised --
    presumably 32 is the batch size EWC uses internally; confirm against the
    EWC implementation.

    Each size is a separate parametrized case, so a failure names the
    offending size and does not hide failures for the remaining sizes.

    Args:
        simple_model: Fresh 10 -> 3 linear model from the ``simple_model``
            fixture (function-scoped, so every size gets its own instance).
        size: Number of samples in the generated dataset.
    """
    embeddings = torch.randn(size, 10)
    labels = torch.randint(0, 3, (size,))
    dataset = torch.utils.data.TensorDataset(embeddings, labels)

    # Should not raise for any of the sizes.
    ewc = EWC(simple_model, dataset, device='cpu', ewc_lambda=100.0)

    # EWC was initialised properly.
    assert ewc.fisher_info is not None
    assert len(ewc.fisher_info) > 0

    # The penalty is computable and non-negative.
    loss = ewc.ewc_loss(batch_size=32)
    assert loss is not None
    assert loss.item() >= 0
85+
86+
def test_adaptive_classifier_with_many_classes():
    """Feed 20 classes x 3 examples to AdaptiveClassifier in batches of 10.

    Simulates a many-class scenario (Banking77-like). New classes keep
    appearing from batch to batch, which is expected to trigger EWC; the test
    passes when no batch raises and ``predict`` returns well-formed results.
    """
    # Fixed seeds for reproducibility.
    np.random.seed(42)
    torch.manual_seed(42)

    classifier = AdaptiveClassifier('distilbert-base-uncased', device='cpu')

    num_classes = 20
    examples_per_class = 3

    # NOTE(review): the trailing space inside these f-strings is reproduced
    # from the source as-is; confirm whether it is intentional.
    class_names = [f"class_{class_id} " for class_id in range(num_classes)]
    texts = [
        f"This is example {example_id} for {class_name} "
        for class_name in class_names
        for example_id in range(examples_per_class)
    ]
    labels = [
        class_name
        for class_name in class_names
        for _ in range(examples_per_class)
    ]

    # Add examples chunk by chunk; none of these calls should raise.
    batch_size = 10
    for start in range(0, len(texts), batch_size):
        stop = start + batch_size
        classifier.add_examples(texts[start:stop], labels[start:stop])

    # The classifier must still answer queries afterwards.
    predictions = classifier.predict("This is a test example", k=3)

    assert predictions is not None
    assert len(predictions) <= 3
    for prediction in predictions:
        assert isinstance(prediction[0], str)  # Labels are strings
    for prediction in predictions:
        assert isinstance(prediction[1], float)  # Scores are floats
126+
127+
def test_ewc_loss_computation(simple_model, small_dataset):
    """The EWC penalty must be positive once parameters drift from the snapshot.

    Steps:
        1. Build EWC from the model and dataset.
        2. Shift every model parameter by 0.1.
        3. Check that the plain loss is positive, and that the loss computed
           with ``batch_size=32`` is positive and differs from the plain one.

    Args:
        simple_model: 10 -> 3 linear model from the ``simple_model`` fixture.
        small_dataset: 33-sample dataset from the ``small_dataset`` fixture.
    """
    ewc = EWC(
        simple_model,
        small_dataset,
        device='cpu',
        ewc_lambda=100.0,
    )

    # Perturb the weights in place. Doing this under no_grad() keeps autograd
    # out of the update without reaching for the discouraged ``.data`` alias.
    with torch.no_grad():
        for param in simple_model.parameters():
            param.add_(0.1)

    loss = ewc.ewc_loss()

    # Loss should be positive since we changed parameters.
    assert loss.item() > 0

    # With batch-size normalization the value should still be positive but
    # different from the un-normalized one.
    loss_normalized = ewc.ewc_loss(batch_size=32)
    assert loss_normalized.item() > 0
    assert loss_normalized.item() != loss.item()
154+
155+
def test_progressive_class_addition():
    """Add classes to AdaptiveClassifier in four phases and check all are kept.

    Phases 2 and 4 introduce previously unseen labels (expected to trigger
    EWC); phase 3 only adds examples for labels that already exist.
    """
    classifier = AdaptiveClassifier('distilbert-base-uncased', device='cpu')

    phases = [
        # Phase 1: initial classes
        (
            ["Good product", "Bad service", "Average quality"],
            ["positive", "negative", "neutral"],
        ),
        # Phase 2: new classes (should trigger EWC)
        (
            ["Need help", "Bug report", "Feature request"],
            ["support", "bug", "feature"],
        ),
        # Phase 3: more examples for existing classes
        (
            ["Excellent!", "Terrible!", "It's okay"],
            ["positive", "negative", "neutral"],
        ),
        # Phase 4: more new classes (should trigger EWC again)
        (
            ["Urgent issue", "Question about pricing"],
            ["urgent", "inquiry"],
        ),
    ]
    for phase_texts, phase_labels in phases:
        classifier.add_examples(phase_texts, phase_labels)

    # Every label seen in any phase must be registered.
    expected_classes = {
        "positive", "negative", "neutral", "support",
        "bug", "feature", "urgent", "inquiry",
    }
    for expected_label in expected_classes:
        assert expected_label in classifier.label_to_id

    # Prediction still works after all phases.
    predictions = classifier.predict("This is wonderful!", k=3)
    assert predictions is not None
    assert len(predictions) > 0
192+
193+
def test_ewc_with_empty_batch_edge_case():
    """EWC must cope with the smallest dataset used here: one single sample.

    Despite the name, the dataset is not empty -- it holds exactly one
    5-dimensional sample with label 0, fed to a 5 -> 2 linear model.
    """

    class TinyModel(nn.Module):
        """One fully connected layer mapping 5 features to 2 classes."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(5, 2)

        def forward(self, x):
            return self.fc(x)

    model = TinyModel()

    # A dataset holding exactly one sample.
    single_embedding = torch.randn(1, 5)
    single_label = torch.tensor([0])
    dataset = torch.utils.data.TensorDataset(single_embedding, single_label)

    # Should handle the single sample without errors.
    ewc = EWC(model, dataset, device='cpu', ewc_lambda=50.0)
    assert ewc is not None

    penalty = ewc.ewc_loss()
    assert penalty.item() >= 0
0 commit comments