Tokenizer's default behavior on unknown token (#375)
* done

* support from_pretrained

* switch to default value of 'warn' instead of 'None'

---------

Co-authored-by: Sagi Polaczek <[email protected]>
SagiPolaczek and Sagi Polaczek authored Oct 13, 2024
1 parent 957f37f commit b654810
Showing 2 changed files with 17 additions and 15 deletions.
11 changes: 0 additions & 11 deletions fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py
@@ -868,19 +868,12 @@ def add_single_tokenizer(
        # we update the special tokens but do not save here. remember to save yourself.
        self.update_special_tokens(
            special_tokens=new_tokenize_special_tokens,
            # save_tokenizer_path=self.cfg_raw["data"]["tokenizer"]["out_path"],
        )

    def add_tokenizers(
        self,
    ) -> None:
        raise Exception("Not implemented")
        # self.build_inner_decoder()
        # if self._max_possible_token_id is not None:
        #     if self._get_max_mapped_id() > self._max_possible_token_id:
        #         raise Exception(
        #             f"tokenizer remapping resulted in IDs greater (max_id={self._get_max_mapped_id()}) than max_possible_id ({self._max_possible_token_id}). Reinitialize the modular tokenizer with larger max_possible_id"
        #         )

    def _encode_single_type(
        self,
@@ -1059,10 +1052,6 @@ def encode_list(
        merged_encoding = Encoding.merge(encoded_list)

        max_len = self.get_expected_max_len(override_max_len=max_len)
        # if max_len is None:
        #     if self.max_len is not None:
        #         max_len = self.max_len

        if max_len is not None:
            if len(merged_encoding) > max_len:
                overflow_info += f"OVERALL:{len(merged_encoding)}=>{max_len}|"
21 changes: 17 additions & 4 deletions fuse/data/tokenizers/modular_tokenizer/op.py
@@ -30,6 +30,7 @@ def __init__(
        validate_ends_with_eos: Optional[bool] = True,
        eos: Optional[str] = "<EOS>",
        verbose: Optional[bool] = False,
        on_unknown_default_value: str = "warn",
        **kwargs: Any,
    ) -> None:
        """
@@ -41,6 +42,7 @@
            validate_ends_with_eos: during an encode request (a __call__ to the op), makes sure that the input ends with the provided eos token, and raises an exception otherwise.
                having an eos (end of sentence) token at the end is useful in multiple scenarios, for example in a generative transformer (like a T5 encoder-decoder)
            verbose:
            on_unknown_default_value: sets the instance-wide default for handling unknown tokens: 'warn' or 'raise'. Can be overridden per call via the on_unknown argument of __call__.
        """
        super().__init__(**kwargs)

@@ -60,6 +62,10 @@ def __init__(

        self._validate_ends_with_eos = validate_ends_with_eos
        self._eos = eos
        self._on_unknown_default_value = on_unknown_default_value

        if on_unknown_default_value not in ["warn", "raise"]:
            raise ValueError(f"Unsupported {on_unknown_default_value=}! Expected 'warn' or 'raise'.")

        if self._validate_ends_with_eos:
            eos_id = self._tokenizer.token_to_id(self._eos)
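
A minimal usage sketch of the new constructor argument (assuming the op class is ModularTokenizerOp and using a placeholder tokenizer path; neither is taken verbatim from this diff):

from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp

# "raise" makes every encode fail fast on tokens that map to <UNK>;
# the built-in default remains "warn".
tokenizer_op = ModularTokenizerOp(
    tokenizer_path="/path/to/tokenizer",  # placeholder path
    on_unknown_default_value="raise",
)

Any other value (e.g. "ignore") is rejected with a ValueError at construction time, so a typo surfaces immediately rather than at encode time.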
@@ -211,7 +217,7 @@ def __call__(
        key_out_attention_mask: Optional[str] = None,
        convert_attention_mask_to_bool: Optional[bool] = True,
        max_seq_len: Optional[int] = None,
        on_unknown: Optional[str] = "warn",
        on_unknown: Optional[str] = None,
        verbose: Optional[int] = 1,
        validate_ends_with_eos: Optional[bool] = None,
        additional_caller_info_text: Optional[str] = "",
@@ -230,7 +236,7 @@
            key_out_attention_mask (Optional[str], optional): sample_dict key under which to store the attention mask. Defaults to None.
            convert_attention_mask_to_bool (Optional[bool], optional): whether to convert the attention mask values to booleans. Defaults to True.
            max_seq_len (Optional[int], optional): sets the maximum sequence length dynamically, used for both padding and truncation. Defaults to None.
            on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'. Defaults to "warn".
            on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'. Defaults to None, which falls back to the default set in the constructor via on_unknown_default_value.
            verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
                with full data. Defaults to 1.
            validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos
@@ -243,7 +249,6 @@
        Returns:
            NDict: the sample_dict, updated with the encoding results
        """

        data = sample_dict[key_in]
        if not isinstance(data, (list, str)):
            # data is a list of named tuples of type collections.namedtuple("TypedInput", ["input_type", "input_string", "max_len"])
@@ -263,6 +268,10 @@ def __call__(
f"validate_ends_with_eos was set to {validate_ends_with_eos}, but about to encode a string that does not end with {self._eos}. The str end was: {last_seq}"
)

if on_unknown is None:
# Use tokenizer instance default value
on_unknown = self._on_unknown_default_value

if isinstance(data, str):
_ans = self._tokenizer.encode(
data,
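
The resolution order is: an explicit on_unknown argument to __call__ wins, otherwise the constructor default applies. A self-contained sketch of that fallback logic (illustrative plain Python, not code from this diff):

from typing import Optional

def resolve_on_unknown(call_value: Optional[str], constructor_default: str = "warn") -> str:
    # Per-call value wins; None means "defer to the instance default".
    return constructor_default if call_value is None else call_value

assert resolve_on_unknown(None) == "warn"                                # built-in default
assert resolve_on_unknown(None, constructor_default="raise") == "raise"  # instance default
assert resolve_on_unknown("warn", constructor_default="raise") == "warn" # per-call override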
@@ -510,6 +519,7 @@ def from_pretrained(
        identifier: str,
        pad_token: str = "<PAD>",
        max_size: Optional[int] = None,
        on_unknown_default_value: str = "warn",
        force_download: bool = False,
        resume_download: Optional[bool] = None,
        proxies: Optional[Dict] = None,
@@ -549,7 +559,10 @@
            ) from e

        tokenizer_op = cls(
            tokenizer_path=identifier, pad_token=pad_token, max_size=max_size
            tokenizer_path=identifier,
            pad_token=pad_token,
            max_size=max_size,
            on_unknown_default_value=on_unknown_default_value,
        )
        return tokenizer_op
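
Since from_pretrained now forwards the argument, the instance default can also be set when loading a pretrained tokenizer (a sketch assuming the class name ModularTokenizerOp; the identifier below is a hypothetical placeholder):

tokenizer_op = ModularTokenizerOp.from_pretrained(
    "some-org/some-modular-tokenizer",  # hypothetical identifier
    on_unknown_default_value="raise",
)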
