From e33ac088368274dbcd13bee3f746c19915fd1ac7 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 6 May 2026 15:24:51 +0200 Subject: [PATCH 1/6] add static smiles tokenizer --- chebai/preprocessing/reader.py | 45 ++++ chebai/preprocessing/smiles_tokenizer.py | 259 +++++++++++++++++++++++ 2 files changed, 304 insertions(+) create mode 100644 chebai/preprocessing/smiles_tokenizer.py diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index cc70be7f..9b0198c3 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -257,6 +257,51 @@ def _back_to_smiles(self, smiles_encoded): return smiles_decoded +class StaticSMILESReader(DataReader): + """ + Data reader for SMILES tokens with a static token set. Atoms are split into 5 components: isotope, element, charge, hydrogens, stereo. + New tokens are not added to the token file, and unknown tokens are mapped to a special index. + """ + + COLLATOR = RaggedCollator + + def __init__(self, *args, **kwargs) -> None: + from chebai.preprocessing.smiles_tokenizer import BasicSmilesTokenizer + + super().__init__(*args, **kwargs) + self.tokenizer = BasicSmilesTokenizer() + + @classmethod + def name(cls) -> str: + """Returns the name of the data reader.""" + return "static_smiles" + + def _read_data(self, raw_data: str) -> Optional[List[int]]: + """Tokenize raw SMILES data using BasicSmilesTokenizer with static vocabulary.""" + try: + mol = Chem.MolFromSmiles(raw_data.strip()) + if mol is None: + raise ValueError(f"Invalid SMILES: {raw_data}") + except ValueError as e: + print(f"could not process {raw_data}") + print(f"\tError: {e}") + return None + + try: + smiles = Chem.MolToSmiles(mol, canonical=True) + except Exception as e: + print(f"RDKit failed to canonicalize the SMILES: {raw_data}") + print(f"\t{e}") + return None + + try: + return self.tokenizer.encode(smiles) + except Exception as e: + print(f"could not tokenize {raw_data}") + print(f"\tError: {e}") + return None + + class DeepChemDataReader(ChemDataReader): """ Data reader for chemical data using DeepSMILES tokens. diff --git a/chebai/preprocessing/smiles_tokenizer.py b/chebai/preprocessing/smiles_tokenizer.py new file mode 100644 index 00000000..c9449e9b --- /dev/null +++ b/chebai/preprocessing/smiles_tokenizer.py @@ -0,0 +1,259 @@ +from __future__ import annotations +import re +from typing import List + +from rdkit import Chem + + +def _build_bracket_atoms() -> List[str]: + """Enumerate chemically meaningful bracketed atoms.""" + pt = Chem.GetPeriodicTable() + elements = [pt.GetElementSymbol(i) for i in range(1, 119)] + # aromatic forms used in SMILES + elements += ["c", "n", "o", "s", "p", "b", "te", "se"] + charges = range(-5, 8) + hydrogens = range(9) + stereo = ["None", "@", "@@"] + isotopes = range( + 1, 250 + ) # 228Ra is the heaviest isotope in ChEBI (CHEBI:80505) - leave some space for future additions + + tokens = set() + for el in elements: + tokens.add(f"element_{el}") + for ch in charges: + tokens.add(f"charge_{ch}") + for h in hydrogens: + tokens.add(f"hydrogens_{h}") + for st in stereo: + tokens.add(f"stereo_{st}") + for iso in isotopes: + tokens.add(f"isotope_{iso}") + tokens.add("isotope_None") # for non-isotopic atoms + + return list(tokens) + + +NON_BRACKET_TOKENS = [ + # organic subset elements (unbracketed form) + # "B", "C", "N", "O", "S", "P", "F", "I", "Cl", "Br", + # "b", "c", "n", "o", "s", "p", + # bonds / structure + "(", + ")", + "=", + "#", + "->", + "<-", + ">>", + "-", + "+", + "/", + "\\", + ":", + ".", + "~", + "*", + "$", + "?", + "@", + "@@", + # ring closures: single-digit + *[str(d) for d in range(10)], + # ring closures: %10..%99 + *[f"%{n:02d}" for n in range(10, 100)], +] + + +def _build_default_vocab() -> List[str]: + """Special tokens + non-bracket symbols + bracketed atoms.""" + + brackets = _build_bracket_atoms() + + # de-duplicate while preserving order (specials first) + seen, vocab = set(), [] + for tok in NON_BRACKET_TOKENS + brackets: + if tok not in seen: + seen.add(tok) + vocab.append(tok) + return vocab + + +# change to original regex: added -> and <- for handling dative bonds (we use RDKit-normalised SMILES, this is not a standard SMILES feature) +SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|->|<-|>>?|-|\+|\\|\/|:|~|@@|@|\?|\*|\$|\%[0-9]{2}|[0-9])""" +EMBEDDING_OFFSET = 10 +UNKNOWN_TOKEN_IDX = 3 + + +class BasicSmilesTokenizer(object): + """ + Run basic SMILES tokenization using a regex pattern developed by Schwaller et. al. + This tokenizer is to be used when a tokenizer that does not require the transformers library by HuggingFace is required. + + Examples + -------- + >>> from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer + >>> tokenizer = BasicSmilesTokenizer() + >>> print(tokenizer.tokenize("CC(=O)OC1=CC=CC=C1C(=O)O")) + ['C', 'C', '(', '=', 'O', ')', 'O', 'C', '1', '=', 'C', 'C', '=', 'C', 'C', '=', 'C', '1', 'C', '(', '=', 'O', ')', 'O'] + + + References + ---------- + .. [1] Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee + ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction + 1572-1583 DOI: 10.1021/acscentsci.9b00576 + """ + + def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN): + """Constructs a BasicSMILESTokenizer. + + Parameters + ---------- + regex: string + SMILES token regex + """ + self.regex_pattern = regex_pattern + self.regex = re.compile(self.regex_pattern) + + self.vocab = _build_default_vocab() + self.vocab_dict = { + tok: idx + EMBEDDING_OFFSET for idx, tok in enumerate(self.vocab) + } + self.idx_to_token = {idx: tok for tok, idx in self.vocab_dict.items()} + + def _parse_bracket_atom(self, bracket_token: str) -> List[str]: + """ + Parse a bracketed atom token into 5 components: isotope, element, charge, hydrogens, stereo. + + E.g. "[85Kr]" -> ["isotope_85", "element_Kr", "charge_0", "hydrogens_0", "stereo_None"] + """ + atom_str = bracket_token[1:-1] # Remove brackets + + # special case: any atom containing a * is treated as a wildcard (e.g. [3*:0],[1*]) + if "*" in atom_str: + return ["*"] + + isotope = None + element = None + charge = 0 + hydrogens = 0 + stereo = None + + pos = 0 + + # Parse isotope (leading digits) + iso_str = "" + while pos < len(atom_str) and atom_str[pos].isdigit(): + iso_str += atom_str[pos] + pos += 1 + if iso_str: + isotope = iso_str + + # Parse element (1-2 letters) + if pos < len(atom_str) and (atom_str[pos].isupper() or atom_str[pos].islower()): + element = atom_str[pos] + pos += 1 + if pos < len(atom_str) and atom_str[pos].islower(): + element += atom_str[pos] + pos += 1 + + # Parse stereo (@ or @@) + if pos < len(atom_str) and atom_str[pos] == "@": + if pos + 1 < len(atom_str) and atom_str[pos + 1] == "@": + stereo = "@@" + pos += 2 + else: + stereo = "@" + pos += 1 + + # Parse hydrogens (H, H2, H3, H4) + if pos < len(atom_str) and atom_str[pos] == "H": + hydrogens = 1 + pos += 1 + if pos < len(atom_str) and atom_str[pos].isdigit(): + hydrogens = int(atom_str[pos]) + pos += 1 + + # Parse charge (+, -, +2, -2, +3, -3) + if pos < len(atom_str): + if atom_str[pos] == "+": + pos += 1 + if pos < len(atom_str) and atom_str[pos].isdigit(): + charge = int(atom_str[pos]) + pos += 1 + else: + charge = 1 + elif atom_str[pos] == "-": + pos += 1 + if pos < len(atom_str) and atom_str[pos].isdigit(): + charge = -int(atom_str[pos]) + pos += 1 + else: + charge = -1 + + # Format and return 5 component tokens + return [ + f"isotope_{isotope if isotope else 'None'}", + f"element_{element}", + f"charge_{charge}", + f"hydrogens_{hydrogens}", + f"stereo_{stereo if stereo else 'None'}", + ] + + def tokenize(self, text): + """Tokenize a SMILES string, breaking bracketed atoms into 5 components. + + Non-bracketed tokens are returned as-is. + Bracketed atoms are decomposed into: isotope, element, charge, hydrogens, stereo. + """ + raw_tokens = [token for token in self.regex.findall(text)] + tokens = [] + + for token in raw_tokens: + if token in NON_BRACKET_TOKENS: + tokens.append(token) + continue + if not (token.startswith("[") and token.endswith("]")): + token = ( + f"[{token}]" # Wrap non-bracket tokens in brackets for uniformity + ) + # Parse bracketed atom into 5 components + components = self._parse_bracket_atom(token) + tokens.extend(components) + + return tokens + + def encode(self, text): + tokens = self.tokenize(text) + return [self.vocab_dict.get(token, UNKNOWN_TOKEN_IDX) for token in tokens] + + def decode(self, token_ids, skip_special_tokens=False): + tokens = [self.idx_to_token.get(idx, "[UNK]") for idx in token_ids] + if skip_special_tokens: + tokens = [tok for tok in tokens if tok not in self.vocab_dict] + return "".join(tokens) + + +# ---- quick self-test ------------------------------------------------------ +if __name__ == "__main__": + # tok = SmilesTokenizer.build_default() + # print(f"Vocab size: {tok.vocab_size}") + tok = BasicSmilesTokenizer() + print(f"Vocab size: {len(tok.vocab)}") + examples = [ + "->[se]", + # "CC(=O)Oc1ccccc1C(=O)O", # aspirin + # "C[C@H](N)C(=O)O", # L-alanine + # "[13CH3]CO", # isotope + # "C1CC2(CCCCC2)CC1", # spiro + # "c1ccc2c(c1)[nH]cn2", # benzimidazole with [nH] + ] + + for s in examples: + tokens = tok.tokenize(s) + ids = tok.encode(s) + print(f"\n{s}") + print(f" tokens: {tokens}") + print(f" ids: {ids}") + print(f" decode: {tok.decode(ids)}") From 17021bfeb6365b00f38d14857f112d255fd4a2ca Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 6 May 2026 20:04:46 +0200 Subject: [PATCH 2/6] update for PubChem tokens --- chebai/preprocessing/smiles_tokenizer.py | 32 +++++++----------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/chebai/preprocessing/smiles_tokenizer.py b/chebai/preprocessing/smiles_tokenizer.py index c9449e9b..91a41999 100644 --- a/chebai/preprocessing/smiles_tokenizer.py +++ b/chebai/preprocessing/smiles_tokenizer.py @@ -10,13 +10,11 @@ def _build_bracket_atoms() -> List[str]: pt = Chem.GetPeriodicTable() elements = [pt.GetElementSymbol(i) for i in range(1, 119)] # aromatic forms used in SMILES - elements += ["c", "n", "o", "s", "p", "b", "te", "se"] - charges = range(-5, 8) + elements += ["c", "n", "o", "s", "p", "b", "te", "se", "si"] + charges = range(-5, 9) hydrogens = range(9) stereo = ["None", "@", "@@"] - isotopes = range( - 1, 250 - ) # 228Ra is the heaviest isotope in ChEBI (CHEBI:80505) - leave some space for future additions + isotopes = range(1, 300) # [295Og] is the heaviest isotope in PubChem tokens = set() for el in elements: @@ -66,7 +64,7 @@ def _build_bracket_atoms() -> List[str]: def _build_default_vocab() -> List[str]: - """Special tokens + non-bracket symbols + bracketed atoms.""" + """non-bracket symbols + bracketed atoms.""" brackets = _build_bracket_atoms() @@ -87,17 +85,6 @@ def _build_default_vocab() -> List[str]: class BasicSmilesTokenizer(object): """ - Run basic SMILES tokenization using a regex pattern developed by Schwaller et. al. - This tokenizer is to be used when a tokenizer that does not require the transformers library by HuggingFace is required. - - Examples - -------- - >>> from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer - >>> tokenizer = BasicSmilesTokenizer() - >>> print(tokenizer.tokenize("CC(=O)OC1=CC=CC=C1C(=O)O")) - ['C', 'C', '(', '=', 'O', ')', 'O', 'C', '1', '=', 'C', 'C', '=', 'C', 'C', '=', 'C', '1', 'C', '(', '=', 'O', ')', 'O'] - - References ---------- .. [1] Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee @@ -242,12 +229,11 @@ def decode(self, token_ids, skip_special_tokens=False): tok = BasicSmilesTokenizer() print(f"Vocab size: {len(tok.vocab)}") examples = [ - "->[se]", - # "CC(=O)Oc1ccccc1C(=O)O", # aspirin - # "C[C@H](N)C(=O)O", # L-alanine - # "[13CH3]CO", # isotope - # "C1CC2(CCCCC2)CC1", # spiro - # "c1ccc2c(c1)[nH]cn2", # benzimidazole with [nH] + "CC(=O)Oc1ccccc1C(=O)O", # aspirin + "C[C@H](N)C(=O)O", # L-alanine + "[13CH3]CO", # isotope + "C1CC2(CCCCC2)CC1", # spiro + "c1ccc2c(c1)[nH]cn2", # benzimidazole with [nH] ] for s in examples: From a219d45da0c6c0c82aaf5782b88bac4549465197 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 6 May 2026 20:13:19 +0200 Subject: [PATCH 3/6] correctly reassamble SMILES (as far as possible) --- chebai/preprocessing/smiles_tokenizer.py | 66 ++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/smiles_tokenizer.py b/chebai/preprocessing/smiles_tokenizer.py index 91a41999..d9330c33 100644 --- a/chebai/preprocessing/smiles_tokenizer.py +++ b/chebai/preprocessing/smiles_tokenizer.py @@ -33,9 +33,6 @@ def _build_bracket_atoms() -> List[str]: NON_BRACKET_TOKENS = [ - # organic subset elements (unbracketed form) - # "B", "C", "N", "O", "S", "P", "F", "I", "Cl", "Br", - # "b", "c", "n", "o", "s", "p", # bonds / structure "(", ")", @@ -81,6 +78,9 @@ def _build_default_vocab() -> List[str]: SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|->|<-|>>?|-|\+|\\|\/|:|~|@@|@|\?|\*|\$|\%[0-9]{2}|[0-9])""" EMBEDDING_OFFSET = 10 UNKNOWN_TOKEN_IDX = 3 +ORGANIC_SUBSET = frozenset( + {"B", "C", "N", "O", "P", "S", "F", "Cl", "Br", "I", "b", "c", "n", "o", "s", "p"} +) class BasicSmilesTokenizer(object): @@ -215,11 +215,69 @@ def encode(self, text): tokens = self.tokenize(text) return [self.vocab_dict.get(token, UNKNOWN_TOKEN_IDX) for token in tokens] + def _reassemble_bracket_atom( + self, isotope: str, element: str, charge: int, hydrogens: int, stereo: str + ) -> str: + if ( + isotope == "None" + and charge == 0 + and stereo == "None" + and hydrogens == 0 + and element in ORGANIC_SUBSET + ): + return element + inner = "" + if isotope != "None": + inner += isotope + inner += element + if stereo != "None": + inner += stereo + if hydrogens > 0: + inner += "H" + if hydrogens > 1: + inner += str(hydrogens) + if charge > 0: + inner += "+" + if charge > 1: + inner += str(charge) + elif charge < 0: + inner += "-" + if charge < -1: + inner += str(-charge) + return f"[{inner}]" + def decode(self, token_ids, skip_special_tokens=False): tokens = [self.idx_to_token.get(idx, "[UNK]") for idx in token_ids] if skip_special_tokens: tokens = [tok for tok in tokens if tok not in self.vocab_dict] - return "".join(tokens) + + result = [] + i = 0 + while i < len(tokens): + tok = tokens[i] + if ( + tok.startswith("isotope_") + and i + 4 < len(tokens) + and tokens[i + 1].startswith("element_") + and tokens[i + 2].startswith("charge_") + and tokens[i + 3].startswith("hydrogens_") + and tokens[i + 4].startswith("stereo_") + ): + result.append( + self._reassemble_bracket_atom( + isotope=tok[len("isotope_") :], + element=tokens[i + 1][len("element_") :], + charge=int(tokens[i + 2][len("charge_") :]), + hydrogens=int(tokens[i + 3][len("hydrogens_") :]), + stereo=tokens[i + 4][len("stereo_") :], + ) + ) + i += 5 + else: + result.append(tok) + i += 1 + + return "".join(result) # ---- quick self-test ------------------------------------------------------ From cfbf418ac8dc3a09e9d5dc83e6f9c37090d1548e Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 7 May 2026 10:44:31 +0200 Subject: [PATCH 4/6] make static reader default for chebi --- chebai/preprocessing/datasets/chebi.py | 11 ++++------- chebai/preprocessing/reader.py | 9 +++++---- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 698835fd..cfad18a7 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -526,7 +526,7 @@ class ChEBIFromList(_ChEBIDataExtractor): """ - READER = dr.ChemDataReader + READER = dr.StaticSMILESReader def __init__( self, @@ -572,7 +572,7 @@ class ChEBIOverX(_ChEBIDataExtractor): THRESHOLD (None): The threshold for selecting classes. """ - READER: dr.ChemDataReader = dr.ChemDataReader + READER = dr.StaticSMILESReader @property def _name(self) -> str: @@ -791,11 +791,8 @@ class ChEBIOver100Fingerprints(ChEBIOverXFingerprints, ChEBIOver100): if __name__ == "__main__": - dataset = ChEBIOver50Partial( - chebi_version=247, - subset="3_STAR", - top_class_id="36700", - external_data_ratio=0.5, + dataset = ChEBIOver50( + chebi_version=251, ) dataset.prepare_data() dataset.setup() diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 9b0198c3..0dae39cd 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -276,12 +276,13 @@ def name(cls) -> str: """Returns the name of the data reader.""" return "static_smiles" - def _read_data(self, raw_data: str) -> Optional[List[int]]: + def _read_data(self, raw_data: str | Chem.Mol) -> Optional[List[int]]: """Tokenize raw SMILES data using BasicSmilesTokenizer with static vocabulary.""" try: - mol = Chem.MolFromSmiles(raw_data.strip()) - if mol is None: - raise ValueError(f"Invalid SMILES: {raw_data}") + if isinstance(raw_data, str): + mol = Chem.MolFromSmiles(raw_data.strip()) + else: + mol = raw_data except ValueError as e: print(f"could not process {raw_data}") print(f"\tError: {e}") From aaf5c6e5d225179325864763b062ccdf3aa5163d Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 7 May 2026 10:44:55 +0200 Subject: [PATCH 5/6] change electra vocab size --- configs/model/electra.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/model/electra.yml b/configs/model/electra.yml index 4427715f..5240e3fb 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -4,7 +4,7 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - vocab_size: 4400 + vocab_size: 600 max_position_embeddings: 1800 num_attention_heads: 8 num_hidden_layers: 6 From 8ed618cab95aa3645063cabc5313cc070d166e44 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 7 May 2026 17:12:48 +0200 Subject: [PATCH 6/6] omit default tokens -> shortens overall representation massively --- chebai/preprocessing/smiles_tokenizer.py | 72 ++++++++++++++---------- 1 file changed, 43 insertions(+), 29 deletions(-) diff --git a/chebai/preprocessing/smiles_tokenizer.py b/chebai/preprocessing/smiles_tokenizer.py index d9330c33..58b043cd 100644 --- a/chebai/preprocessing/smiles_tokenizer.py +++ b/chebai/preprocessing/smiles_tokenizer.py @@ -11,9 +11,9 @@ def _build_bracket_atoms() -> List[str]: elements = [pt.GetElementSymbol(i) for i in range(1, 119)] # aromatic forms used in SMILES elements += ["c", "n", "o", "s", "p", "b", "te", "se", "si"] - charges = range(-5, 9) - hydrogens = range(9) - stereo = ["None", "@", "@@"] + charges = [i for i in range(-5, 9) if i != 0] + hydrogens = range(1, 9) + stereo = ["@", "@@"] isotopes = range(1, 300) # [295Og] is the heaviest isotope in PubChem tokens = set() @@ -27,7 +27,6 @@ def _build_bracket_atoms() -> List[str]: tokens.add(f"stereo_{st}") for iso in isotopes: tokens.add(f"isotope_{iso}") - tokens.add("isotope_None") # for non-isotopic atoms return list(tokens) @@ -111,9 +110,13 @@ def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN): def _parse_bracket_atom(self, bracket_token: str) -> List[str]: """ - Parse a bracketed atom token into 5 components: isotope, element, charge, hydrogens, stereo. + Parse a bracketed atom token into its components. - E.g. "[85Kr]" -> ["isotope_85", "element_Kr", "charge_0", "hydrogens_0", "stereo_None"] + When all attributes are at their defaults (charge=0, hydrogens=0, isotope=None, stereo=None), + only the element token is emitted. Otherwise all 5 tokens are emitted. + + E.g. "[N]" -> ["element_N"] + "[85Kr]" -> ["isotope_85", "element_Kr", "charge_0", "hydrogens_0", "stereo_None"] """ atom_str = bracket_token[1:-1] # Remove brackets @@ -179,14 +182,17 @@ def _parse_bracket_atom(self, bracket_token: str) -> List[str]: else: charge = -1 - # Format and return 5 component tokens - return [ - f"isotope_{isotope if isotope else 'None'}", - f"element_{element}", - f"charge_{charge}", - f"hydrogens_{hydrogens}", - f"stereo_{stereo if stereo else 'None'}", - ] + # return element token and optionally isotope, charge, hydrogens, stereo tokens + res = [f"element_{element}"] + if isotope is not None: + res.append(f"isotope_{isotope}") + if charge != 0: + res.append(f"charge_{charge}") + if hydrogens != 0: + res.append(f"hydrogens_{hydrogens}") + if stereo is not None: + res.append(f"stereo_{stereo}") + return res def tokenize(self, text): """Tokenize a SMILES string, breaking bracketed atoms into 5 components. @@ -216,7 +222,12 @@ def encode(self, text): return [self.vocab_dict.get(token, UNKNOWN_TOKEN_IDX) for token in tokens] def _reassemble_bracket_atom( - self, isotope: str, element: str, charge: int, hydrogens: int, stereo: str + self, + element: str, + isotope: str = "None", + charge: int = 0, + hydrogens: int = 0, + stereo: str = "None", ) -> str: if ( isotope == "None" @@ -255,24 +266,26 @@ def decode(self, token_ids, skip_special_tokens=False): i = 0 while i < len(tokens): tok = tokens[i] - if ( - tok.startswith("isotope_") - and i + 4 < len(tokens) - and tokens[i + 1].startswith("element_") - and tokens[i + 2].startswith("charge_") - and tokens[i + 3].startswith("hydrogens_") - and tokens[i + 4].startswith("stereo_") - ): + if tok.startswith("element_"): + i += 1 + add = ["None", 0, 0, "None"] + for idx, additional_token in enumerate( + ["isotope_", "charge_", "hydrogens_", "stereo_"] + ): + if i >= len(tokens): + break + if tokens[i].startswith(additional_token): + add[idx] = tokens[i][len(additional_token) :] + if additional_token in ["charge_", "hydrogens_"]: + add[idx] = int(add[idx]) + i += 1 + result.append( self._reassemble_bracket_atom( - isotope=tok[len("isotope_") :], - element=tokens[i + 1][len("element_") :], - charge=int(tokens[i + 2][len("charge_") :]), - hydrogens=int(tokens[i + 3][len("hydrogens_") :]), - stereo=tokens[i + 4][len("stereo_") :], + tok[len("element_") :], + *add, ) ) - i += 5 else: result.append(tok) i += 1 @@ -292,6 +305,7 @@ def decode(self, token_ids, skip_special_tokens=False): "[13CH3]CO", # isotope "C1CC2(CCCCC2)CC1", # spiro "c1ccc2c(c1)[nH]cn2", # benzimidazole with [nH] + "CC(=O)N[C@@H]1[C@H](O[C@H]2[C@H](O)[C@@H](NC(C)=O)[C@H](O)O[C@@H]2CO[C@@H]2O[C@@H](C)[C@@H](O)[C@@H](O)[C@@H]2O)O[C@H](CO)[C@@H](O[C@@H]2O[C@H](CO[C@H]3O[C@H](CO[C@@H]4O[C@H](CO)[C@@H](O[C@@H]5O[C@H](CO)[C@H](O)[C@H](O[C@@H]6O[C@H](CO)[C@@H](O[C@@H]7O[C@H](CO)[C@H](O)[C@H](O[C@]8(C(=O)O)C[C@H](O)[C@@H](NC(C)=O)[C@H]([C@H](O)[C@H](O)CO)O8)[C@H]7O)[C@H](O)[C@H]6NC(C)=O)[C@H]5O)[C@H](O)[C@H]4NC(C)=O)[C@@H](O)[C@H](O)[C@@H]3O[C@@H]3O[C@H](CO)[C@@H](O[C@@H]4O[C@H](CO)[C@H](O)[C@H](O[C@@H]5O[C@H](CO)[C@@H](O[C@@H]6O[C@H](CO)[C@H](O)[C@H](O[C@]7(C(=O)O)C[C@H](O)[C@@H](NC(C)=O)[C@H]([C@H](O)[C@H](O)CO)O7)[C@H]6O)[C@H](O)[C@H]5NC(C)=O)[C@H]4O)[C@H](O)[C@H]3NC(C)=O)[C@@H](O)[C@H](O[C@H]3O[C@H](CO)[C@@H](O[C@@H]4O[C@H](CO)[C@@H](O[C@@H]5O[C@H](CO)[C@H](O)[C@H](O[C@]6(C(=O)O)C[C@H](O)[C@@H](NC(C)=O)[C@H]([C@H](O)[C@H](O)CO)O6)[C@H]5O)[C@H](O)[C@H]4NC(C)=O)[C@H](O)[C@@H]3O[C@@H]3O[C@H](CO)[C@@H](O[C@@H]4O[C@H](CO)[C@H](O)[C@H](O)[C@H]4O)[C@H](O)[C@@H]3NC(C)=O)[C@@H]2O)[C@@H]1O", ] for s in examples: