|
13 | 13 | from model2vec.train.base import BaseFinetuneable |
14 | 14 | from model2vec.train.dataset import TextDataset |
15 | 15 | from model2vec.train.similarity import StaticModelForSimilarity |
16 | | -from model2vec.train.utils import get_probable_pad_token_id, train_test_split |
| 16 | +from model2vec.train.utils import get_probable_pad_token_id, logit, train_test_split |
17 | 17 |
|
18 | 18 |
|
19 | 19 | @pytest.mark.parametrize("n_layers", [0, 1, 2, 3]) |
@@ -74,6 +74,17 @@ def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: To |
74 | 74 | assert s.w.shape[0] == mock_vectors.shape[0] |
75 | 75 |
|
76 | 76 |
|
def test_init_classifier_from_model_w(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
    """Test initialization from a static model that carries per-token weights."""
    # Build a static model whose per-token weights are all ones.
    model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, weights=np.ones(len(mock_vectors)))
    s = StaticModelForClassification.from_static_model(model=model)
    # The weights must be carried over from the static model as a torch tensor.
    assert s._weights is not None
    assert torch.all(s._weights == torch.ones(len(mock_vectors)))
    # construct_weights builds one weight per token row; with all-ones input
    # the result is logit(1) broadcast across the vocabulary.
    w = s.construct_weights()
    assert w.shape[0] == mock_vectors.shape[0]
    assert torch.all(w == logit(torch.ones(len(mock_vectors))))
| 87 | + |
77 | 88 | def test_pad_token(mock_tokenizer: Tokenizer) -> None: |
78 | 89 | """Test initializion from a static model.""" |
79 | 90 | tokenizer_model = TokenizerModel.from_tokenizer(mock_tokenizer) |
@@ -360,3 +371,9 @@ def test_determine_interval() -> None: |
360 | 371 | ) |
361 | 372 | assert val_check_interval == 100 |
362 | 373 | assert check_val_every_epoch is None |
| 374 | + |
| 375 | + |
def test_logit() -> None:
    """Test that logit is the inverse of sigmoid on a grid of probabilities."""
    # Deterministic grid 0.0, 0.1, ..., 0.9 — all in sigmoid's open range.
    x = torch.arange(10).float() / 10
    # Round-tripping through sigmoid then logit should recover x up to
    # float32 tolerance.
    assert torch.allclose(logit(torch.sigmoid(x)), x, atol=1e-6)
0 commit comments