Skip to content

Commit

Permalink
support new TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT (#387)
Browse files Browse the repository at this point in the history
* support new TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT

* add default sub tokenizer and improve logic

* simplified ifs

* adding comments for the default sub-tokenizer

---------

Co-authored-by: Ben Shapira <[email protected]>
  • Loading branch information
bensha6757 and Ben Shapira authored Dec 15, 2024
1 parent 788b374 commit f00a714
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 23 deletions.
65 changes: 44 additions & 21 deletions fuse/data/tokenizers/modular_tokenizer/inject_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,35 @@ class InjectorToModularTokenizerLib:
for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict
for example: "blah.boo.banana" or "data.input.encoder_input"
for text following <@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT> is expected to be a key to the sample NDict
for example: "blah.boo.banana" or "data.input.encoder_input"
example usage:
encoder_input:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
labels:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
for embeddings from dict:
encoder_input:
<@TOKENIZER-TYPE=AA><BIOT5_TASK_ID><1><8><SENTINEL_ID_0><@TOKENIZER-TYPE=AA><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN><@TOKENIZER-TYPE=EXTERNAL_EMBEDDINGS_FROM_DICT>{protein1_key}<@TOKENIZER-TYPE=AA@MAX-LEN={max_len_1}><SEQUENCE_NATURAL_START>{protein_seq_1}<SEQUENCE_NATURAL_END><EOS>
"""

@staticmethod
def build_placeholder_meta_tokenization(
*,
sequence: Union[str, list, tuple],
sample_dict: Optional[NDict] = None,
default_sub_tokenizer_name: str = "AA",
) -> Tuple[str, List[str]]:
"""
In order to avoid modifying and rewriting the logic in modular tokenizer, especially regarding padding, limitation of max length of certain sub-parts,
we put placeholders to make sure that the total size is known/fixed and respects the meta instructions to the modular tokenizer
default_sub_tokenizer_name: Specifies the name of the default sub-tokenizer. This tokenizer is used for handling special tokens, such as <SCALAR> and <EMBEDDINGS>.
Returns: a tuple with 2 elements
(
a single string with the full query containing placeholder tokens for FLOAT and VECTOR meta tokenizer parts,
Expand Down Expand Up @@ -88,8 +98,8 @@ def build_placeholder_meta_tokenization(
):
if tokenizer_type.startswith("SCALARS_"):
with_placeholders.append(
"<@TOKENIZER-TYPE=AA>"
) # AA tokenizer selection is arbitrary, we only take the special token <SCALAR> from it
f"<@TOKENIZER-TYPE={default_sub_tokenizer_name}>"
) # tokenizer selection is arbitrary, we only take the special token <SCALAR> from it

if tokenizer_type == "SCALARS_LITERALS":
values = subseq.split(",")
Expand All @@ -113,7 +123,11 @@ def build_placeholder_meta_tokenization(
raise Exception(f"tokenizer_type={tokenizer_type} is not supported")

with_placeholders.append(seq)

elif tokenizer_type.startswith("EXTERNAL_EMBEDDINGS_"):
with_placeholders.append(
f"<@TOKENIZER-TYPE={default_sub_tokenizer_name}>"
) # tokenizer selection is arbitrary, we only take the special token <EMBEDDINGS> from it
with_placeholders.append("<EMBEDDINGS>")
elif tokenizer_type.startswith("VECTORS_"):
raise Exception("VECTOR_* are not supported yet")
else:
Expand All @@ -123,7 +137,7 @@ def build_placeholder_meta_tokenization(
return "".join(with_placeholders), hints_and_subseq

@staticmethod
def build_scalars(
def build_scalars_and_embeddings(
*,
per_meta_tokenizer_data: List[str],
per_meta_encoding_including_placeholders: List[Encoding],
Expand Down Expand Up @@ -155,6 +169,9 @@ def build_scalars(
# for each element, whether it's a scalar or not
all_scalars_valid_mask = []
scalar_default_unfound_value = -1000.0
external_embeddings_info = dict() # a dict mapping location -> embedding input
num_tokens_token_so_far = 0
num_inputs_needing_embeddings = 0

for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
per_meta_tokenizer_data[::2],
Expand All @@ -173,35 +190,39 @@ def build_scalars(
curr_scalar_values = torch.tensor(
curr_scalar_values, dtype=torch.float32
)
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)
elif "SCALARS_FROM_DICT" == tokenizer_name:
if sample_dict is None:
raise Exception(
"SCALARS_FROM_DICT used but the provided sample_dict is None"
)
curr_scalar_values = sample_dict[curr_str_data]
assert len(curr_scalar_values.shape) == 1
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)

else:
raise Exception(
"Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
)

elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)
num_tokens_token_so_far += len(curr_scalar_values)
else:
# prev_index_end += len(curr_placeholder_encoding.ids)
if tokenizer_name == "EXTERNAL_EMBEDDINGS_FROM_DICT":
if sample_dict is None:
raise Exception(
"EXTERNAL_EMBEDDINGS_FROM_DICT used but the provided sample_dict is None"
)
embedding_input = sample_dict[curr_str_data]
external_embeddings_info[num_inputs_needing_embeddings] = (
num_tokens_token_so_far,
embedding_input,
)
num_inputs_needing_embeddings += 1
elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError

curr_scalar_values = torch.full(
(len(curr_placeholder_encoding.ids),),
fill_value=scalar_default_unfound_value,
Expand All @@ -212,6 +233,7 @@ def build_scalars(
curr_scalar_values, fill_value=False, dtype=torch.bool
)
)
num_tokens_token_so_far += len(curr_placeholder_encoding.ids)

all_scalars_values = torch.concat(all_scalars_values)
all_scalars_valid_mask = torch.concat(all_scalars_valid_mask)
Expand Down Expand Up @@ -255,4 +277,5 @@ def build_scalars(
return {
"scalars_values": all_scalars_values, # 1d - its length is the number of actual scalars (provided) found
"scalars_valid_mask": all_scalars_valid_mask, # 1d - values of provided scalars
"external_embeddings_info": external_embeddings_info, # dict - number of input needing embeddings -> (location in the query, embeddings input)
}
15 changes: 13 additions & 2 deletions fuse/data/tokenizers/modular_tokenizer/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ def __init__(
verbose=verbose,
**kwargs,
)
# default_sub_tokenizer_name is used as the tokenizer of special tokens such as scalars and external_embeddings tokens.
self.default_sub_tokenizer_name = next(
iter(self._tokenizer.tokenizers_info.values())
)["name"]

def __call__(
self,
Expand All @@ -447,6 +451,7 @@ def __call__(
verbose: Optional[int] = 1,
validate_ends_with_eos: Optional[bool] = None,
key_out_scalars: Optional[str] = None,
key_out_external_embeddings_info: Optional[str] = None,
additional_caller_info_text: Optional[str] = "",
) -> NDict:
"""_summary_
Expand Down Expand Up @@ -480,7 +485,9 @@ def __call__(
with_placeholders_str,
per_meta_orig,
) = InjectorToModularTokenizerLib.build_placeholder_meta_tokenization(
sequence=sample_dict[key_in], sample_dict=sample_dict
sequence=sample_dict[key_in],
sample_dict=sample_dict,
default_sub_tokenizer_name=self.default_sub_tokenizer_name,
)
sample_dict[key_in + ".with_placeholders"] = with_placeholders_str

Expand All @@ -500,7 +507,7 @@ def __call__(
+ ".per_meta_part_encoding", # using the key_in as base for the name because key_out_* are optional
)

prepared_data = InjectorToModularTokenizerLib.build_scalars(
prepared_data = InjectorToModularTokenizerLib.build_scalars_and_embeddings(
per_meta_tokenizer_data=per_meta_orig,
per_meta_encoding_including_placeholders=sample_dict[
key_in + ".per_meta_part_encoding"
Expand All @@ -514,6 +521,10 @@ def __call__(
sample_dict[key_out_scalars + ".valid_mask"] = prepared_data[
"scalars_valid_mask"
]
if key_out_external_embeddings_info is not None:
sample_dict[key_out_external_embeddings_info] = prepared_data[
"external_embeddings_info"
]

return sample_dict

Expand Down

0 comments on commit f00a714

Please sign in to comment.