207 changes: 75 additions & 132 deletions asmtransformers/asmtransformers/models/asmsentencebert.py
@@ -1,150 +1,93 @@
from collections.abc import Mapping
from typing import Any

from sentence_transformers import SentenceTransformer
from sentence_transformers.base.modality import InputFormatter
from sentence_transformers.sentence_transformer.modules import Pooling, Transformer
from sentence_transformers.sentence_transformer.modules import Pooling
from torch import nn

from .asmbert import ARM64Tokenizer, ASMBertModel


class ASMSTTransformer(Transformer):
    """Analogous to the sentence-transformers Transformer class,
    managing our code transformers and tokenizer.

    See ASMSentenceTransformer for an overall description."""
class ASMTransformerModule(nn.Module):
    """Minimal sentence-transformers module for ARM64BERT finetuning."""

    def __init__(
        self,
        model_name_or_path: str,
        tokenizer,
        *,
        max_seq_length: int | None = None,
        do_lower_case: bool = False,
        model_args: Mapping[str, Any] | None = None,
        model_args: dict[str, Any] | None = None,
    ):
        nn.Module.__init__(self)

        model = ASMBertModel.from_pretrained(model_name_or_path, **(model_args or {}))
        self.model = model
        self.auto_model = model
        self.processor = tokenizer
        self.transformer_task = 'feature-extraction'
        self.backend = 'torch'
        self.processing_kwargs = {}
        self.track_media_counts = False
        self._prompt_length_mapping = {}
        self._method_signature_cache = {}
        self.model_forward_params = set(self.model.forward.__code__.co_varnames) | {
            'input_ids',
            'attention_mask',
            'token_type_ids',
            'inputs_embeds',
            'return_dict',
        }
        self.modality_config = {'text': {'method': 'forward', 'method_output_name': 'last_hidden_state'}}
        self.module_output_name = 'token_embeddings'
        self.input_formatter = InputFormatter(
            model_type=self.config.model_type,
            message_format='auto',
            processor=self.processor,
        )
        self.input_formatter.supported_modalities = ['text']
        self.config_keys = ['max_seq_length', 'do_lower_case']
        self.do_lower_case = do_lower_case

        # No max_seq_length set. Try to infer from model
        if (
            max_seq_length is None
            and hasattr(self.model, 'config')
            and hasattr(self.model.config, 'max_position_embeddings')
            and hasattr(self.tokenizer, 'model_max_length')
        ):
            max_seq_length = min(self.model.config.max_position_embeddings, self.tokenizer.model_max_length)

        self.max_seq_length = max_seq_length

        self.model.tokenizer = tokenizer
        super().__init__()
        self.model = ASMBertModel.from_pretrained(model_name_or_path, **(model_args or {}))
        self.tokenizer = ARM64Tokenizer.from_pretrained(model_name_or_path)

        self.model.tokenizer = self.tokenizer
        self.model.config.tokenizer_class = self.tokenizer.__class__.__name__
        self.unpad_inputs = False

    @property
    def tokenizer(self):
        return self.processor

    @tokenizer.setter
    def tokenizer(self, tokenizer):
        self.processor = tokenizer
        if hasattr(self, 'model'):
            self.model.tokenizer = tokenizer


class ASMSentenceTransformer(SentenceTransformer):
    """Convenience class that allows for easy finetuning and inference for a
    semantic code similarity model.

    Exposes the same interface as the sentence-transformers class SentenceTransformer
    does, particularly the `fit()` method for finetuning and the `encode()` method
    for inference.

    SentenceTransformer internally composes model and tokenizer classes from the
    Hugging Face transformers library, so we do the same for our asmtransformers
    models and tokenizer.

    Graphically:
    ("ST.": comes from the sentence-transformers package,
     "T.": comes from the Hugging Face transformers package)

             ASMSentenceTransformer                     ST.SentenceTransformer
                       |                                          |
             ---------------------                      ---------------------
             |                   |                      |                   |
        ASMSTTransformer     ST.Pooling           ST.Transformer        ST.Pooling
             |                                          |
        -----------------                         -----------------
        |               |                         |               |
    ASMBertModel  ARM64BertTokenizer         T.BertModel    T.BertTokenizer
    """

    @staticmethod
    def _build_tokenizer(model_name_or_path):
        return ARM64Tokenizer.from_pretrained(model_name_or_path)

    @classmethod
    def _build_modules(cls, model_name_or_path, model_args=None):
        tokenizer = cls._build_tokenizer(model_name_or_path)
        embedding_model = ASMSTTransformer(model_name_or_path, tokenizer, model_args=model_args or {})
        pooling_model = Pooling(embedding_model.get_embedding_dimension())
        return embedding_model, pooling_model

    @classmethod
    def from_pretrained(cls, model_name_or_path, model_args=None):
        embedding_model, pooling_model = cls._build_modules(model_name_or_path, model_args=model_args)
        return cls(modules=[embedding_model, pooling_model])

    @classmethod
    def from_basemodel(cls, base_model_name_or_path, model_args=None):
        embedding_model, pooling_model = cls._build_modules(base_model_name_or_path, model_args=model_args)
        bert_model = embedding_model.model.base_model

        # The jTrans architecture shares weights between positional and word embeddings.
        # Make sure we have done this properly.
        if bert_model.embeddings.position_embeddings is not bert_model.embeddings.word_embeddings:
            raise RuntimeError('Word embeddings and position embeddings not shared')

        # Now freeze layers, like jTrans. Embedding plus 10 layers is the default, so we'll use that.
        for param in bert_model.embeddings.parameters():
            param.requires_grad = False

        freeze_layer_count = 10
        if freeze_layer_count:
            for layer in bert_model.encoder.layer[:freeze_layer_count]:
                for param in layer.parameters():
                    param.requires_grad = False
    def get_embedding_dimension(self) -> int:
        return self.model.config.hidden_size

    def preprocess(self, inputs, prompt=None, **kwargs):
        if prompt:
            inputs = [prompt + text for text in inputs]
        return self.tokenizer(inputs, **kwargs)

    def tokenize(self, texts, **kwargs):
        return self.preprocess(texts, **kwargs)

        return cls(modules=[embedding_model, pooling_model])
    def forward(self, features: dict[str, Any], **kwargs) -> dict[str, Any]:
        model_inputs = {
            'input_ids': features['input_ids'],
            'attention_mask': features['attention_mask'],
            'return_dict': True,
        }
        if 'token_type_ids' in features:
            model_inputs['token_type_ids'] = features['token_type_ids']

        outputs = self.model(**model_inputs)
        features['token_embeddings'] = outputs.last_hidden_state
        return features
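For orientation, a minimal sketch of the module contract (the checkpoint path and assembly snippet below are placeholders, not part of this change): `tokenize()` returns the tokenizer's feature dict, and calling the module adds the `token_embeddings` entry that the downstream Pooling module consumes.

module = ASMTransformerModule('path/to/arm64bert-base')   # placeholder checkpoint path
features = module.tokenize(['mov x0, x1'])                # input_ids, attention_mask, ...
features = module(features)                               # adds 'token_embeddings'
print(features['token_embeddings'].shape)                 # (batch, seq_len, hidden_size)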


def apply_freeze_policy(bert_model, *, freeze_embeddings=True, freeze_layer_count=10):
    if freeze_embeddings:
        for param in bert_model.embeddings.parameters():
            param.requires_grad = False

    def encode(self, sentences, *args, normalize_embeddings=True, **kwargs):
        # Change the default for normalize_embeddings.
        return super().encode(sentences, *args, normalize_embeddings=normalize_embeddings, **kwargs)
    if freeze_layer_count:
        for layer in bert_model.encoder.layer[:freeze_layer_count]:
            for param in layer.parameters():
                param.requires_grad = False


def build_finetuning_model(
    base_model_name_or_path,
    model_args=None,
    *,
    freeze_embeddings=True,
    freeze_layer_count=10,
):
    embedding_model = ASMTransformerModule(base_model_name_or_path, model_args=model_args)
    pooling_model = Pooling(embedding_model.get_embedding_dimension())
    model = SentenceTransformer(modules=[embedding_model, pooling_model])
    bert_model = model[0].model.base_model

    if bert_model.embeddings.position_embeddings is not bert_model.embeddings.word_embeddings:
        raise RuntimeError('Word embeddings and position embeddings not shared')

    apply_freeze_policy(
        bert_model,
        freeze_embeddings=freeze_embeddings,
        freeze_layer_count=freeze_layer_count,
    )
    return model
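A minimal usage sketch, assuming a local base checkpoint (the path and input below are placeholders); `build_finetuning_model` returns an ordinary SentenceTransformer, so the usual training and `encode()` workflow applies:

model = build_finetuning_model('path/to/arm64bert-base', freeze_layer_count=10)  # placeholder path
embeddings = model.encode(['<assembly function text>'], normalize_embeddings=True)
print(embeddings.shape)  # (1, embedding_dim)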


def __getattr__(name):
    if name in {'ASMSentenceTransformer', 'ASMSTTransformer'}:
        # Temporary checkpoint compatibility bridge: published ST-format checkpoints
        # reference asmtransformers.models.asmsentencebert.ASMSTTransformer in modules.json.
        # Delete this after those checkpoints are converted to native embedder format.
        from .st_compat import ASMSentenceTransformer, ASMSTTransformer

        return {'ASMSentenceTransformer': ASMSentenceTransformer, 'ASMSTTransformer': ASMSTTransformer}[name]
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
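The bridge keeps published ST-format checkpoints loadable: the legacy class path they record in modules.json still resolves from this module. Purely as an illustration (assuming the st_compat module is importable):

# Resolved lazily through the module-level __getattr__, which forwards to .st_compat.
from asmtransformers.models.asmsentencebert import ASMSTTransformer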
58 changes: 58 additions & 0 deletions asmtransformers/asmtransformers/models/embedder.py
@@ -0,0 +1,58 @@
import numpy as np
import torch

from asmtransformers.models.asmbert import ARM64Tokenizer, ASMBertModel


class ASMEmbedder:
    """Native inference wrapper for ARM64BERT-style embedding checkpoints."""

    def __init__(self, model, tokenizer, *, device=None, normalize_embeddings=True):
        self.model = model
        self.tokenizer = tokenizer
        self.normalize_embeddings = normalize_embeddings
        self.device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
        self.model.to(self.device)
        self.model.eval()

    @classmethod
    def from_pretrained(cls, model_name_or_path, *, model_args=None, tokenizer_args=None, device=None):
        tokenizer = ARM64Tokenizer.from_pretrained(model_name_or_path, **(tokenizer_args or {}))
        model = ASMBertModel.from_pretrained(model_name_or_path, **(model_args or {}))
        return cls(model, tokenizer, device=device)

    @staticmethod
    def mean_pool(token_embeddings, attention_mask):
        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
        return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9)

    def encode(
        self,
        sentences,
        *,
        batch_size=32,
        normalize_embeddings=None,
        convert_to_numpy=True,
    ):
        single_input = isinstance(sentences, str)
        sentences = [sentences] if single_input else list(sentences)
        normalize_embeddings = self.normalize_embeddings if normalize_embeddings is None else normalize_embeddings

        embeddings = []
        with torch.no_grad():
            for start in range(0, len(sentences), batch_size):
                batch = sentences[start : start + batch_size]
                inputs = self.tokenizer(batch)
                inputs = {key: value.to(self.device) for key, value in inputs.items()}
                outputs = self.model(**inputs)
                pooled = self.mean_pool(outputs.last_hidden_state, inputs['attention_mask'])
                if normalize_embeddings:
                    pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
                embeddings.append(pooled.cpu())

        embeddings = torch.cat(embeddings, dim=0)
        if single_input:
            embeddings = embeddings[0]
        if convert_to_numpy:
            return embeddings.numpy().astype(np.float32, copy=False)
        return embeddings
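A minimal inference sketch, assuming a local checkpoint directory (the path and the two assembly snippets are placeholders); with the default normalization, the dot product of two embeddings is their cosine similarity:

embedder = ASMEmbedder.from_pretrained('path/to/arm64bert-checkpoint')  # placeholder path
vectors = embedder.encode(['<asm function A>', '<asm function B>'])
similarity = float(vectors[0] @ vectors[1])  # cosine similarity, since embeddings are L2-normalized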