
Commit fa30df9

fix: correctly set weights when w is initialized (#322)
* fix: correctly set weights when w is initialized
* fix tests

1 parent 582da09

3 files changed: 36 additions & 3 deletions

model2vec/train/base.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -18,7 +18,13 @@
 from model2vec import StaticModel
 from model2vec.inference import StaticModelPipeline
 from model2vec.train.dataset import TextDataset
-from model2vec.train.utils import get_probable_pad_token_id, suppress_lightning_warnings, to_pipeline, train_test_split
+from model2vec.train.utils import (
+    get_probable_pad_token_id,
+    logit,
+    suppress_lightning_warnings,
+    to_pipeline,
+    train_test_split,
+)

 logger = logging.getLogger(__name__)

@@ -79,11 +85,15 @@ def __init__(
         self.freeze = freeze
         self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=self.freeze, padding_idx=pad_id)
         self.head = self.construct_head()
-        self.w = self.construct_weights() if weights is None else nn.Parameter(weights.float(), requires_grad=True)
+        self._weights = weights
+        self.w = self.construct_weights()
         self.tokenizer = tokenizer

     def construct_weights(self) -> nn.Parameter:
         """Construct the weights for the model."""
+        if self._weights is not None:
+            w = logit(self._weights)
+            return nn.Parameter(w.float(), requires_grad=True)
         weights = torch.zeros(len(self.token_mapping))
         weights[self.pad_id] = -10_000
         return nn.Parameter(weights, requires_grad=not self.freeze)
```
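The substance of the fix is in `construct_weights`: before this commit, user-supplied `weights` were wrapped in an `nn.Parameter` as-is, while the default weights were built in logit space, so the two initialization paths disagreed about which space `self.w` lives in. Storing the raw weights on `self._weights` and routing both paths through `construct_weights` makes the logit transform explicit. Below is a minimal sketch of the round-trip this enables, assuming (the forward pass is not part of this diff) that the model applies a sigmoid to `self.w` before using it to weight token embeddings:

```python
import torch

def logit(x: torch.Tensor) -> torch.Tensor:
    """Invert a sigmoid (matches the helper added in utils.py)."""
    return -torch.log((1 / x) - 1)

# Hypothetical per-token weights in (0, 1), as a user might pass them in.
weights = torch.tensor([0.1, 0.5, 0.9])

# What construct_weights now stores as the trainable parameter.
w = torch.nn.Parameter(logit(weights).float(), requires_grad=True)

# Assumed forward-pass behavior: a sigmoid recovers the original weights.
assert torch.allclose(torch.sigmoid(w), weights, atol=1e-6)
```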

model2vec/train/utils.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Any, Callable

 import numpy as np
+import torch
 from sklearn.model_selection import train_test_split as sklearn_split
 from sklearn.neural_network import MLPClassifier, MLPRegressor
 from sklearn.pipeline import make_pipeline
@@ -111,3 +112,8 @@ class TipFilter(logging.Filter):
     def filter(self, record: logging.LogRecord) -> bool:
         """Filter out tip messages from lightning."""
         return "💡 Tip" not in record.getMessage()
+
+
+def logit(x: torch.Tensor) -> torch.Tensor:
+    """Invert a sigmoid."""
+    return -torch.log((1 / x) - 1)
```
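For reference, the new helper is the standard log-odds function: since sigmoid(x) = 1 / (1 + e^(-x)), solving for x gives logit(p) = log(p / (1 - p)) = -log(1/p - 1). A quick sanity check against the closed form, and against PyTorch's built-in `torch.logit` (a possible drop-in whose `eps` argument clamps inputs near the 0/1 boundaries); the comparison is my addition, not part of the commit:

```python
import torch

def logit(x: torch.Tensor) -> torch.Tensor:
    """Invert a sigmoid."""
    return -torch.log((1 / x) - 1)

p = torch.linspace(0.01, 0.99, steps=5)

# -log(1/p - 1) is algebraically log(p / (1 - p)), the log-odds.
assert torch.allclose(logit(p), torch.log(p / (1 - p)), atol=1e-6)

# torch.logit computes the same quantity; eps= would clamp values near 0 and 1.
assert torch.allclose(logit(p), torch.logit(p), atol=1e-6)
```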

tests/test_trainable.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -13,7 +13,7 @@
 from model2vec.train.base import BaseFinetuneable
 from model2vec.train.dataset import TextDataset
 from model2vec.train.similarity import StaticModelForSimilarity
-from model2vec.train.utils import get_probable_pad_token_id, train_test_split
+from model2vec.train.utils import get_probable_pad_token_id, logit, train_test_split


 @pytest.mark.parametrize("n_layers", [0, 1, 2, 3])
@@ -74,6 +74,17 @@ def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: To
     assert s.w.shape[0] == mock_vectors.shape[0]


+def test_init_classifier_from_model_w(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
+    """Test initialization from a static model with preset weights."""
+    model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, weights=np.ones(len(mock_vectors)))
+    s = StaticModelForClassification.from_static_model(model=model)
+    assert s._weights is not None
+    assert torch.all(s._weights == torch.ones(len(mock_vectors)))
+    w = s.construct_weights()
+    assert w.shape[0] == mock_vectors.shape[0]
+    assert torch.all(w == logit(torch.ones(len(mock_vectors))))
+
+
 def test_pad_token(mock_tokenizer: Tokenizer) -> None:
     """Test initialization from a static model."""
     tokenizer_model = TokenizerModel.from_tokenizer(mock_tokenizer)
@@ -360,3 +371,9 @@ def test_determine_interval() -> None:
     )
     assert val_check_interval == 100
     assert check_val_every_epoch is None
+
+
+def test_logit() -> None:
+    """Test that logit inverts sigmoid."""
+    x = torch.arange(10).float() / 10
+    assert torch.allclose(logit(torch.sigmoid(x)), x, atol=1e-6)
```
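One boundary case worth noting in the new tests: `test_init_classifier_from_model_w` initializes the weights to exactly 1.0, and logit(1.0) is +inf. The equality assertion still passes because inf == inf holds element-wise, and a sigmoid in the forward pass would map +inf back to exactly 1.0. A small sketch of that behavior:

```python
import torch

one = torch.ones(3)
w = -torch.log((1 / one) - 1)  # logit(1.0): -log(0) -> +inf

assert torch.all(torch.isinf(w) & (w > 0))  # all entries are +inf
assert torch.all(w == w)                    # inf == inf, so the equality check holds
assert torch.all(torch.sigmoid(w) == 1.0)   # sigmoid maps +inf back to exactly 1.0
```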
