diff --git a/asmtransformers/asmtransformers/models/asmsentencebert.py b/asmtransformers/asmtransformers/models/asmsentencebert.py
index fdb53f4..3c505c4 100644
--- a/asmtransformers/asmtransformers/models/asmsentencebert.py
+++ b/asmtransformers/asmtransformers/models/asmsentencebert.py
@@ -1,150 +1,93 @@
-from collections.abc import Mapping
 from typing import Any
 
 from sentence_transformers import SentenceTransformer
-from sentence_transformers.base.modality import InputFormatter
-from sentence_transformers.sentence_transformer.modules import Pooling, Transformer
+from sentence_transformers.sentence_transformer.modules import Pooling
 from torch import nn
 
 from .asmbert import ARM64Tokenizer, ASMBertModel
 
 
-class ASMSTTransformer(Transformer):
-    """Analogous to the sentence-transformers Transformer class,
-    managing our code transformers and tokenizer.
-
-    See ASMSentenceTransformer for an overall description."""
+class ASMTransformerModule(nn.Module):
+    """Minimal sentence-transformers module for ARM64BERT finetuning."""
 
     def __init__(
         self,
         model_name_or_path: str,
-        tokenizer,
         *,
-        max_seq_length: int | None = None,
-        do_lower_case: bool = False,
-        model_args: Mapping[str, Any] | None = None,
+        model_args: dict[str, Any] | None = None,
     ):
-        nn.Module.__init__(self)
-
-        model = ASMBertModel.from_pretrained(model_name_or_path, **(model_args or {}))
-        self.model = model
-        self.auto_model = model
-        self.processor = tokenizer
-        self.transformer_task = 'feature-extraction'
-        self.backend = 'torch'
-        self.processing_kwargs = {}
-        self.track_media_counts = False
-        self._prompt_length_mapping = {}
-        self._method_signature_cache = {}
-        self.model_forward_params = set(self.model.forward.__code__.co_varnames) | {
-            'input_ids',
-            'attention_mask',
-            'token_type_ids',
-            'inputs_embeds',
-            'return_dict',
-        }
-        self.modality_config = {'text': {'method': 'forward', 'method_output_name': 'last_hidden_state'}}
-        self.module_output_name = 'token_embeddings'
-        self.input_formatter = InputFormatter(
-            model_type=self.config.model_type,
-            message_format='auto',
-            processor=self.processor,
-        )
-        self.input_formatter.supported_modalities = ['text']
-        self.config_keys = ['max_seq_length', 'do_lower_case']
-        self.do_lower_case = do_lower_case
-
-        # No max_seq_length set. Try to infer from model
-        if (
-            max_seq_length is None
-            and hasattr(self.model, 'config')
-            and hasattr(self.model.config, 'max_position_embeddings')
-            and hasattr(self.tokenizer, 'model_max_length')
-        ):
-            max_seq_length = min(self.model.config.max_position_embeddings, self.tokenizer.model_max_length)
-
-        self.max_seq_length = max_seq_length
-
-        self.model.tokenizer = tokenizer
+        super().__init__()
+        self.model = ASMBertModel.from_pretrained(model_name_or_path, **(model_args or {}))
+        self.tokenizer = ARM64Tokenizer.from_pretrained(model_name_or_path)
+
+        self.model.tokenizer = self.tokenizer
         self.model.config.tokenizer_class = self.tokenizer.__class__.__name__
-        self.unpad_inputs = False
-
-    @property
-    def tokenizer(self):
-        return self.processor
-
-    @tokenizer.setter
-    def tokenizer(self, tokenizer):
-        self.processor = tokenizer
-        if hasattr(self, 'model'):
-            self.model.tokenizer = tokenizer
-
-
-class ASMSentenceTransformer(SentenceTransformer):
-    """Convenience class that allows for easy finetuning and inference for a
-    semantic code similarity model.
-
-    Exposes the same interface as sentence-transformer's class SentenceTransformer
-    does; particularly the `fit()` method for finetuning and the `encode()` method
-    for inference.
-
-    SentenceTransformer internally composes model and tokenizer classes from the
-    Hugging Face transformer library, so we do the same for our asmtransformers
-    models and tokenizer.
-
-    Graphically:
-    ("ST.": comes from the sentence-transformers package
-     "T.": comes from the HuggingFace transformers package)
-
-    ASMSentenceTransformer                  ST.SentenceTransformer
-      |                                       |
-      ----------------------                  -------------------
-      |                    |                  |                 |
-    ASMSTTransformer    ST.Pooling          ST.Transformer    ST.Pooling
-      |                                       |
-      --------------------                    -----------------
-      |                  |                    |               |
-    ASMBertModel    ARM64BertTokenizer      T.BertModel    T.BertTokenizer
-    """
-
-    @staticmethod
-    def _build_tokenizer(model_name_or_path):
-        return ARM64Tokenizer.from_pretrained(model_name_or_path)
-
-    @classmethod
-    def _build_modules(cls, model_name_or_path, model_args=None):
-        tokenizer = cls._build_tokenizer(model_name_or_path)
-        embedding_model = ASMSTTransformer(model_name_or_path, tokenizer, model_args=model_args or {})
-        pooling_model = Pooling(embedding_model.get_embedding_dimension())
-        return embedding_model, pooling_model
-
-    @classmethod
-    def from_pretrained(cls, model_name_or_path, model_args=None):
-        embedding_model, pooling_model = cls._build_modules(model_name_or_path, model_args=model_args)
-        return cls(modules=[embedding_model, pooling_model])
-
-    @classmethod
-    def from_basemodel(cls, base_model_name_or_path, model_args=None):
-        embedding_model, pooling_model = cls._build_modules(base_model_name_or_path, model_args=model_args)
-        bert_model = embedding_model.model.base_model
-
-        # The jTrans architecture shares weights between positional and word embeddings
-        # Make sure we have done this properly.
-        if bert_model.embeddings.position_embeddings is not bert_model.embeddings.word_embeddings:
-            raise RuntimeError('Word embeddings and position embeddings not shared')
-
-        # Now freeze layers, like jTrans. Embedding plus 10 layers is the default, so we'll use that.
-        for param in bert_model.embeddings.parameters():
-            param.requires_grad = False
-        freeze_layer_count = 10
-        if freeze_layer_count:
-            for layer in bert_model.encoder.layer[:freeze_layer_count]:
-                for param in layer.parameters():
-                    param.requires_grad = False
+    def get_embedding_dimension(self) -> int:
+        return self.model.config.hidden_size
+
+    def preprocess(self, inputs, prompt=None, **kwargs):
+        if prompt:
+            inputs = [prompt + text for text in inputs]
+        return self.tokenizer(inputs, **kwargs)
+
+    def tokenize(self, texts, **kwargs):
+        return self.preprocess(texts, **kwargs)
 
-        return cls(modules=[embedding_model, pooling_model])
+    def forward(self, features: dict[str, Any], **kwargs) -> dict[str, Any]:
+        model_inputs = {
+            'input_ids': features['input_ids'],
+            'attention_mask': features['attention_mask'],
+            'return_dict': True,
+        }
+        if 'token_type_ids' in features:
+            model_inputs['token_type_ids'] = features['token_type_ids']
+
+        outputs = self.model(**model_inputs)
+        features['token_embeddings'] = outputs.last_hidden_state
+        return features
+
+
+def apply_freeze_policy(bert_model, *, freeze_embeddings=True, freeze_layer_count=10):
+    if freeze_embeddings:
+        for param in bert_model.embeddings.parameters():
+            param.requires_grad = False
 
-    def encode(self, sentences, *args, normalize_embeddings=True, **kwargs):
-        # Change the default for normalize_embeddings.
-        return super().encode(sentences, *args, normalize_embeddings=normalize_embeddings, **kwargs)
+    if freeze_layer_count:
+        for layer in bert_model.encoder.layer[:freeze_layer_count]:
+            for param in layer.parameters():
+                param.requires_grad = False
+
+
+def build_finetuning_model(
+    base_model_name_or_path,
+    model_args=None,
+    *,
+    freeze_embeddings=True,
+    freeze_layer_count=10,
+):
+    embedding_model = ASMTransformerModule(base_model_name_or_path, model_args=model_args)
+    pooling_model = Pooling(embedding_model.get_embedding_dimension())
+    model = SentenceTransformer(modules=[embedding_model, pooling_model])
+    bert_model = model[0].model.base_model
+
+    if bert_model.embeddings.position_embeddings is not bert_model.embeddings.word_embeddings:
+        raise RuntimeError('Word embeddings and position embeddings not shared')
+
+    apply_freeze_policy(
+        bert_model,
+        freeze_embeddings=freeze_embeddings,
+        freeze_layer_count=freeze_layer_count,
+    )
+    return model
+
+
+def __getattr__(name):
+    if name in {'ASMSentenceTransformer', 'ASMSTTransformer'}:
+        # Temporary checkpoint compatibility bridge: published ST-format checkpoints
+        # reference asmtransformers.models.asmsentencebert.ASMSTTransformer in modules.json.
+        # Delete this after those checkpoints are converted to native embedder format.
+        from .st_compat import ASMSentenceTransformer, ASMSTTransformer
+
+        return {'ASMSentenceTransformer': ASMSentenceTransformer, 'ASMSTTransformer': ASMSTTransformer}[name]
+    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
diff --git a/asmtransformers/asmtransformers/models/embedder.py b/asmtransformers/asmtransformers/models/embedder.py
new file mode 100644
index 0000000..3e47df5
--- /dev/null
+++ b/asmtransformers/asmtransformers/models/embedder.py
@@ -0,0 +1,58 @@
+import numpy as np
+import torch
+
+from asmtransformers.models.asmbert import ARM64Tokenizer, ASMBertModel
+
+
+class ASMEmbedder:
+    """Native inference wrapper for ARM64BERT-style embedding checkpoints."""
+
+    def __init__(self, model, tokenizer, *, device=None, normalize_embeddings=True):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.normalize_embeddings = normalize_embeddings
+        self.device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
+        self.model.to(self.device)
+        self.model.eval()
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path, *, model_args=None, tokenizer_args=None, device=None):
+        tokenizer = ARM64Tokenizer.from_pretrained(model_name_or_path, **(tokenizer_args or {}))
+        model = ASMBertModel.from_pretrained(model_name_or_path, **(model_args or {}))
+        return cls(model, tokenizer, device=device)
+
+    @staticmethod
+    def mean_pool(token_embeddings, attention_mask):
+        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
+        return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9)
+
+    def encode(
+        self,
+        sentences,
+        *,
+        batch_size=32,
+        normalize_embeddings=None,
+        convert_to_numpy=True,
+    ):
+        single_input = isinstance(sentences, str)
+        sentences = [sentences] if single_input else list(sentences)
+        normalize_embeddings = self.normalize_embeddings if normalize_embeddings is None else normalize_embeddings
+
+        embeddings = []
+        with torch.no_grad():
+            for start in range(0, len(sentences), batch_size):
+                batch = sentences[start : start + batch_size]
+                inputs = self.tokenizer(batch)
+                inputs = {key: value.to(self.device) for key, value in inputs.items()}
+                outputs = self.model(**inputs)
+                pooled = self.mean_pool(outputs.last_hidden_state, inputs['attention_mask'])
+                if normalize_embeddings:
+                    pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
+                embeddings.append(pooled.cpu())
+
+        embeddings = torch.cat(embeddings, dim=0)
+        if single_input:
+            embeddings = embeddings[0]
+        if convert_to_numpy:
+            return embeddings.numpy().astype(np.float32, copy=False)
+        return embeddings
diff --git a/asmtransformers/asmtransformers/models/st_compat.py b/asmtransformers/asmtransformers/models/st_compat.py
new file mode 100644
index 0000000..b71e9df
--- /dev/null
+++ b/asmtransformers/asmtransformers/models/st_compat.py
@@ -0,0 +1,145 @@
+import json
+from pathlib import Path
+from typing import Any
+
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.sentence_transformer.modules import Pooling
+
+from .asmbert import ARM64Tokenizer
+from .asmsentencebert import ASMTransformerModule, build_finetuning_model
+from .embedder import ASMEmbedder
+
+
+class STCheckpointTransformerModule(ASMTransformerModule):
+    """Compatibility module for existing sentence-transformers checkpoints.
+
+    Remove this when published embedding checkpoints are converted away from the
+    old sentence-transformers module format.
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        *,
+        max_seq_length: int | None = None,
+        do_lower_case: bool = False,
+        model_args: dict[str, Any] | None = None,
+        processor_kwargs: dict[str, Any] | None = None,
+    ):
+        super().__init__(model_name_or_path, model_args=model_args)
+        if processor_kwargs:
+            self.tokenizer = ARM64Tokenizer.from_pretrained(model_name_or_path, **processor_kwargs)
+            self.model.tokenizer = self.tokenizer
+        self.max_seq_length = max_seq_length or min(
+            self.model.config.max_position_embeddings,
+            self.tokenizer.model_max_length,
+        )
+        self.do_lower_case = do_lower_case
+
+    def save(self, output_path: str, *args, safe_serialization: bool = True, **kwargs) -> None:
+        output_path = Path(output_path)
+        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
+        self.tokenizer.save_pretrained(output_path)
+        (output_path / 'sentence_bert_config.json').write_text(
+            json.dumps(
+                {
+                    'max_seq_length': self.max_seq_length,
+                    'do_lower_case': self.do_lower_case,
+                }
+            )
+        )
+
+    @classmethod
+    def load(cls, model_name_or_path: str, model_kwargs=None, processor_kwargs=None, **kwargs):
+        config_path = Path(model_name_or_path) / 'sentence_bert_config.json'
+        config = json.loads(config_path.read_text()) if config_path.exists() else {}
+        return cls(
+            model_name_or_path,
+            max_seq_length=config.get('max_seq_length'),
+            do_lower_case=config.get('do_lower_case', False),
+            model_args=model_kwargs or {},
+            processor_kwargs=processor_kwargs or {},
+        )
+
+
+# Older sentence-transformers checkpoints reference this exact symbol in modules.json.
+ASMSTTransformer = STCheckpointTransformerModule
+
+
+def build_sentence_transformer(model_name_or_path, model_args=None):
+    embedding_model = STCheckpointTransformerModule(model_name_or_path, model_args=model_args or {})
+    pooling_model = Pooling(embedding_model.get_embedding_dimension())
+    return _normalize_encode_by_default(SentenceTransformer(modules=[embedding_model, pooling_model]))
+
+
+def _normalize_encode_by_default(model):
+    encode = model.encode
+
+    def encode_with_normalization_default(sentences, *args, normalize_embeddings=True, **kwargs):
+        return encode(sentences, *args, normalize_embeddings=normalize_embeddings, **kwargs)
+
+    model.encode = encode_with_normalization_default
+    return model
+
+
+class ASMSentenceTransformer:
+    """Compatibility factory for existing callers.
+
+    New code should use build_finetuning_model().
+    Remove this when old caller imports are migrated.
+    """
+
+    from_pretrained = staticmethod(build_sentence_transformer)
+    from_basemodel = staticmethod(build_finetuning_model)
+
+
+def load_st_embedding_as_native_embedder(model_name_or_path, *, model_args=None, tokenizer_args=None, device=None):
+    """Load an old ST-format embedding checkpoint through the native embedder.
+
+    Remove this when published embedding checkpoints are converted to native
+    ASMBertModel/ARM64Tokenizer checkpoint directories.
+    """
+
+    model_path = _resolve_transformer_path(model_name_or_path)
+    _validate_pooling(model_name_or_path)
+    return ASMEmbedder.from_pretrained(
+        model_path,
+        model_args=model_args,
+        tokenizer_args=tokenizer_args,
+        device=device,
+    )
+
+
+def _resolve_transformer_path(model_name_or_path):
+    path = Path(model_name_or_path)
+    modules_path = path / 'modules.json'
+    if not modules_path.exists():
+        return model_name_or_path
+
+    modules = json.loads(modules_path.read_text())
+    transformer = next((module for module in modules if module.get('idx') == 0), None)
+    if transformer is None:
+        return model_name_or_path
+
+    transformer_path = transformer.get('path') or ''
+    return str(path / transformer_path)
+
+
+def _validate_pooling(model_name_or_path):
+    path = Path(model_name_or_path)
+    modules_path = path / 'modules.json'
+    if not modules_path.exists():
+        return
+
+    modules = json.loads(modules_path.read_text())
+    pooling = next((module for module in modules if 'Pooling' in module.get('type', '')), None)
+    if pooling is None:
+        return
+
+    config_path = path / pooling.get('path', '') / 'config.json'
+    if not config_path.exists():
+        return
+
+    config = json.loads(config_path.read_text())
+    if config.get('pooling_mode', 'mean') != 'mean':
+        raise ValueError(f'Unsupported pooling mode: {config["pooling_mode"]}')
diff --git a/asmtransformers/docs/architecture.md b/asmtransformers/docs/architecture.md
index 818968b..188bcd5 100644
--- a/asmtransformers/docs/architecture.md
+++ b/asmtransformers/docs/architecture.md
@@ -47,7 +47,10 @@ Model integration lives in [asmtransformers.models.asmbert](../asmtransformers/m
 There are two layers:
 
 - `ASMBertForMaskedLM` and `ASMBertModel` adapt Hugging Face BERT classes to the jTrans-style setup, including shared word/position embeddings and jump-target prediction support during pretraining.
-- `ASMSentenceTransformer` adapts the pretrained transformer into a sentence-transformers style embedding model for finetuning and inference.
+- `build_finetuning_model()` adapts the pretrained transformer into a plain `SentenceTransformer` model for triplet-loss finetuning.
+- `ASMEmbedder` provides native inference without requiring sentence-transformers at deployment time.
+- `st_compat` contains temporary compatibility shims for old sentence-transformers checkpoints/imports.
+- `asmsentencebert.__getattr__` only exists because old ST checkpoints reference `asmtransformers.models.asmsentencebert.ASMSTTransformer`; remove it after checkpoint conversion.
 
 The current tokenizer integration is ARM64-specific:
 
@@ -90,7 +93,7 @@ The current end-to-end flow is:
 4. A tokenizer converts the token stream into model inputs with the expected context length.
 5. Pretraining uses those inputs for masked language modeling plus jump target prediction.
 6. Finetuning wraps the transformer in a sentence-transformers pipeline and optimizes embedding similarity.
-7. Inference encodes previously unseen functions into embeddings for downstream similarity search.
+7. Inference uses the native embedder to encode previously unseen functions for downstream similarity search.
 
 ## What Is ARM64-Specific Today
 
@@ -112,7 +115,8 @@ The following patterns are reusable across instruction sets:
 - the general CFG-to-token-to-transformer pipeline
 - the `Preprocessor` hook model for custom operand formatting
 - Hugging Face BERT wrapping in `ASMBertModel` and `ASMBertForMaskedLM`
-- sentence-transformers integration in `ASMSentenceTransformer`
+- sentence-transformers finetuning integration via `build_finetuning_model()`
+- native embedding inference in `ASMEmbedder`
 - label-grouped dataset sampling in `LazySentenceLabelDataset`
 - the script-level workflow stages: preprocess, vocab build, pretrain, finetune, evaluate, infer
 
diff --git a/asmtransformers/scripts/finetune.py b/asmtransformers/scripts/finetune.py
index 55a9654..1792f9f 100644
--- a/asmtransformers/scripts/finetune.py
+++ b/asmtransformers/scripts/finetune.py
@@ -14,7 +14,7 @@
 from tzlocal import get_localzone
 
 from asmtransformers.datasets import LazySentenceLabelDataset
-from asmtransformers.models.asmsentencebert import ASMSentenceTransformer
+from asmtransformers.models.asmsentencebert import build_finetuning_model
 
 
 def timestamp():
@@ -69,7 +69,7 @@ def main(data_folder, model, batch_size):
         handlers=[LoggingHandler(), logging.FileHandler(filename=f'{model_save_path}/training_logging.log')],
     )
 
-    model = ASMSentenceTransformer.from_basemodel(
+    model = build_finetuning_model(
         base_model_name_or_path=model_name_or_path, model_args={'torch_dtype': torch.bfloat16}
     )
     logging.info(f'pre-trained model {model_name} loaded')
diff --git a/asmtransformers/scripts/inference.py b/asmtransformers/scripts/inference.py
index 94ae5ef..5adb50b 100644
--- a/asmtransformers/scripts/inference.py
+++ b/asmtransformers/scripts/inference.py
@@ -3,7 +3,7 @@
 import torch
 
 from datasets import Dataset, concatenate_datasets
 
-from asmtransformers.models.asmsentencebert import ASMSentenceTransformer
+from asmtransformers.models.st_compat import load_st_embedding_as_native_embedder
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -14,7 +14,7 @@ def main(data_folder, output_folder, model_path):
     print('Opening dataset')
     eval_functions = Dataset.load_from_disk(data_folder)
     print('Load model')
-    model = ASMSentenceTransformer.from_pretrained(model_path)
+    model = load_st_embedding_as_native_embedder(model_path)
     print('Start creating embeddings')
     embedded_functions = Dataset.from_dict({'embeddings': model.encode(eval_functions['cfg'])})
     embedded_dataset = concatenate_datasets([eval_functions, embedded_functions], axis=1)
diff --git a/asmtransformers/tests/test_asmsentencebert_freeze.py b/asmtransformers/tests/test_asmsentencebert_freeze.py
new file mode 100644
index 0000000..cb015ba
--- /dev/null
+++ b/asmtransformers/tests/test_asmsentencebert_freeze.py
@@ -0,0 +1,82 @@
+import json
+
+import pytest
+import torch
+from transformers import BertConfig
+
+from asmtransformers.models.asmbert import ASMBertModel
+from asmtransformers.models.asmsentencebert import build_finetuning_model
+
+
+@pytest.fixture
+def checkpoint_path(tmp_path):
+    vocab = [f'JUMP_ADDR_{index}' for index in range(512)] + [
+        '[PAD]',
+        '[UNK]',
+        '[CLS]',
+        '[SEP]',
+        '[MASK]',
+        'ret',
+    ]
+    (tmp_path / 'vocab.txt').write_text('\n'.join(vocab))
+    (tmp_path / 'tokenizer_config.json').write_text(
+        json.dumps(
+            {
+                'do_lower_case': False,
+                'do_basic_tokenize': False,
+                'tokenize_chinese_chars': False,
+                'tokenizer_class': 'ARM64Tokenizer',
+                'model_max_length': 512,
+            }
+        )
+    )
+
+    torch.manual_seed(0)
+    config = BertConfig(
+        vocab_size=len(vocab),
+        hidden_size=8,
+        num_hidden_layers=2,
+        num_attention_heads=2,
+        intermediate_size=16,
+        max_position_embeddings=512,
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        pad_token_id=vocab.index('[PAD]'),
+    )
+    model = ASMBertModel(config)
+    model.save_pretrained(tmp_path)
+    return tmp_path
+
+
+def test_from_basemodel_freezes_jtrans_default_layers(checkpoint_path):
+    model = build_finetuning_model(checkpoint_path)
+    bert_model = model[0].model.base_model
+
+    assert all(not param.requires_grad for param in bert_model.embeddings.parameters())
+    assert all(not param.requires_grad for param in bert_model.encoder.layer[0].parameters())
+    assert all(not param.requires_grad for param in bert_model.encoder.layer[1].parameters())
+
+
+def test_from_basemodel_can_disable_freezing(checkpoint_path):
+    model = build_finetuning_model(
+        checkpoint_path,
+        freeze_embeddings=False,
+        freeze_layer_count=0,
+    )
+    bert_model = model[0].model.base_model
+
+    assert all(param.requires_grad for param in bert_model.embeddings.parameters())
+    assert all(param.requires_grad for layer in bert_model.encoder.layer for param in layer.parameters())
+
+
+def test_from_basemodel_freezes_configured_layer_count(checkpoint_path):
+    model = build_finetuning_model(
+        checkpoint_path,
+        freeze_embeddings=False,
+        freeze_layer_count=1,
+    )
+    bert_model = model[0].model.base_model
+
+    assert all(param.requires_grad for param in bert_model.embeddings.parameters())
+    assert all(not param.requires_grad for param in bert_model.encoder.layer[0].parameters())
+    assert all(param.requires_grad for param in bert_model.encoder.layer[1].parameters())
diff --git a/asmtransformers/tests/test_embedder.py b/asmtransformers/tests/test_embedder.py
new file mode 100644
index 0000000..8f58c9b
--- /dev/null
+++ b/asmtransformers/tests/test_embedder.py
@@ -0,0 +1,94 @@
+import json
+
+import numpy as np
+import pytest
+import torch
+from transformers import BertConfig
+
+from asmtransformers.models.asmbert import ASMBertModel
+from asmtransformers.models.embedder import ASMEmbedder
+
+
+@pytest.fixture
+def cfg():
+    return json.dumps([[4096, ['mov x0,#0x0', 'ret']]])
+
+
+@pytest.fixture
+def checkpoint_path(tmp_path):
+    vocab = [f'JUMP_ADDR_{index}' for index in range(512)] + [
+        '[PAD]',
+        '[UNK]',
+        '[CLS]',
+        '[SEP]',
+        '[MASK]',
+        'mov',
+        'x0',
+        '#0x0',
+        'ret',
+    ]
+    (tmp_path / 'vocab.txt').write_text('\n'.join(vocab))
+    (tmp_path / 'tokenizer_config.json').write_text(
+        json.dumps(
+            {
+                'do_lower_case': False,
+                'do_basic_tokenize': False,
+                'tokenize_chinese_chars': False,
+                'tokenizer_class': 'ARM64Tokenizer',
+                'model_max_length': 512,
+            }
+        )
+    )
+
+    torch.manual_seed(0)
+    config = BertConfig(
+        vocab_size=len(vocab),
+        hidden_size=8,
+        num_hidden_layers=2,
+        num_attention_heads=2,
+        intermediate_size=16,
+        max_position_embeddings=512,
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        pad_token_id=vocab.index('[PAD]'),
+    )
+    model = ASMBertModel(config)
+    model.save_pretrained(tmp_path)
+    return tmp_path
+
+
+def test_native_embedder_returns_normalized_embedding(checkpoint_path, cfg):
+    embedder = ASMEmbedder.from_pretrained(checkpoint_path)
+
+    embedding = embedder.encode(cfg)
+
+    assert embedding.shape == (8,)
+    assert embedding.dtype == np.float32
+    assert np.isclose(np.linalg.norm(embedding), 1.0)
+
+
+def test_native_embedder_batch_size_does_not_change_embeddings(checkpoint_path, cfg):
+    embedder = ASMEmbedder.from_pretrained(checkpoint_path)
+
+    single = embedder.encode(cfg)
+    batched = embedder.encode([cfg, cfg], batch_size=2)
+
+    assert np.allclose(single, batched[0])
+    assert np.allclose(single, batched[1])
+
+
+def test_mean_pool_ignores_padding():
+    token_embeddings = torch.tensor(
+        [
+            [
+                [1.0, 2.0],
+                [3.0, 4.0],
+                [100.0, 200.0],
+            ]
+        ]
+    )
+    attention_mask = torch.tensor([[1, 1, 0]])
+
+    pooled = ASMEmbedder.mean_pool(token_embeddings, attention_mask)
+
+    assert torch.equal(pooled, torch.tensor([[2.0, 3.0]]))
diff --git a/asmtransformers/tests/test_asmsentencebert.py b/asmtransformers/tests/test_st_compat.py
similarity index 59%
rename from asmtransformers/tests/test_asmsentencebert.py
rename to asmtransformers/tests/test_st_compat.py
index 6803558..38e538a 100644
--- a/asmtransformers/tests/test_asmsentencebert.py
+++ b/asmtransformers/tests/test_st_compat.py
@@ -1,9 +1,62 @@
+import json
 import os
 
 import numpy as np
 import pytest
+import torch
+from transformers import BertConfig
 
-from asmtransformers.models.asmsentencebert import ASMSentenceTransformer
+from asmtransformers.models.asmbert import ASMBertModel
+from asmtransformers.models.embedder import ASMEmbedder
+from asmtransformers.models.st_compat import build_sentence_transformer, load_st_embedding_as_native_embedder
+
+
+@pytest.fixture
+def cfg():
+    return json.dumps([[4096, ['mov x0,#0x0', 'ret']]])
+
+
+@pytest.fixture
+def checkpoint_path(tmp_path):
+    vocab = [f'JUMP_ADDR_{index}' for index in range(512)] + [
+        '[PAD]',
+        '[UNK]',
+        '[CLS]',
+        '[SEP]',
+        '[MASK]',
+        'mov',
+        'x0',
+        '#0x0',
+        'ret',
+    ]
+    (tmp_path / 'vocab.txt').write_text('\n'.join(vocab))
+    (tmp_path / 'tokenizer_config.json').write_text(
+        json.dumps(
+            {
+                'do_lower_case': False,
+                'do_basic_tokenize': False,
+                'tokenize_chinese_chars': False,
+                'tokenizer_class': 'ARM64Tokenizer',
+                'model_max_length': 512,
+            }
+        )
+    )
+
+    torch.manual_seed(0)
+    config = BertConfig(
+        vocab_size=len(vocab),
+        hidden_size=8,
+        num_hidden_layers=2,
+        num_attention_heads=2,
+        intermediate_size=16,
+        max_position_embeddings=512,
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        pad_token_id=vocab.index('[PAD]'),
+    )
+    model = ASMBertModel(config)
+    model.save_pretrained(tmp_path)
+    return tmp_path
 
 
 @pytest.fixture
@@ -46,7 +99,32 @@ def anchor():
 @pytest.fixture(scope='session')
 def model():
     path = 'NetherlandsForensicInstitute/ARM64BERT-embedding'
-    return ASMSentenceTransformer.from_pretrained(path)
+    return build_sentence_transformer(path)
+
+
+@pytest.fixture(scope='session')
+def native_model():
+    path = 'NetherlandsForensicInstitute/ARM64BERT-embedding'
+    return load_st_embedding_as_native_embedder(path)
+
+
+def test_st_compat_pooling_matches_native_embedder(checkpoint_path, cfg):
+    native = ASMEmbedder.from_pretrained(checkpoint_path)
+    sentence_transformer = build_sentence_transformer(checkpoint_path)
+
+    native_embedding = native.encode(cfg)
+    st_embedding = sentence_transformer.encode(cfg)
+
+    assert np.allclose(native_embedding, st_embedding)
+
+
+def test_st_compat_loader_returns_native_embedder(checkpoint_path, cfg):
+    embedder = load_st_embedding_as_native_embedder(checkpoint_path)
+
+    embedding = embedder.encode(cfg)
+
+    assert embedding.shape == (8,)
+    assert np.isclose(np.linalg.norm(embedding), 1.0)
 
 
 @pytest.mark.skipif(os.environ.get('CI') == 'true', reason="don't run this test on CI")
@@ -59,6 +137,15 @@ def test_single_embedding(anchor, model):
     assert np.isclose(embedding.sum(), -0.09272218)
     assert np.isclose(embedding.min(), -0.10641833)
     assert np.isclose(embedding.max(), 0.116405316)
 
 
+@pytest.mark.skipif(os.environ.get('CI') == 'true', reason="don't run this test on CI")
+def test_native_embedder_single_embedding(anchor, native_model):
+    """Native inference should reproduce the published model's golden embedding."""
+    embedding = native_model.encode(anchor)
+    assert np.isclose(embedding.sum(), -0.09272218)
+    assert np.isclose(embedding.min(), -0.10641833)
+    assert np.isclose(embedding.max(), 0.116405316)
+
+
 @pytest.mark.skipif(os.environ.get('CI') == 'true', reason="don't run this test on CI")
 def test_compare_identical(anchor, model):
     """This test ensures that batch size in inference does not affect the embeddings that come out