Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,6 @@ def distill_from_model(

# Create the vocabulary in the new tokenizer.
tokenizer_model = clean_and_create_vocabulary(tokenizer_model, vocabulary, token_remove_regex=token_remove_regex)
# Remove the post processor, this is not necessary.
tokenizer_model.post_processor = None
# Prune again now that the post processor is gone.
# We can't do this before because we need the post processor and associated
# tokens before to add eos/bos.
tokenizer_model = tokenizer_model.prune_added_tokens()

# All tokens in a single list.
all_tokens = tokenizer_model.sorted_vocabulary
Expand Down
19 changes: 17 additions & 2 deletions model2vec/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@ def clean_and_create_vocabulary(
vocabulary_to_add: list[str],
token_remove_regex: re.Pattern[str] | None,
) -> TokenizerModel:
"""Clean a vocabulary by removing duplicates and tokens that were already in the vocabulary."""
"""
Clean a vocabulary by removing duplicates and tokens that were already in the vocabulary.

This function drops duplicate tokens and tokens that are already in the model's vocabulary.
It also removes the tokenizer's post-processor, which we do not use anyway, and then prunes the added tokens.

:param model: The tokenizer model to clean.
:param vocabulary_to_add: The vocabulary to add to the model. Any tokens in this vocabulary that
are split according to the pretokenizer are added as multiword tokens.
:param token_remove_regex: A regex pattern to remove tokens from the vocabulary.
:return: The cleaned tokenizer model.
"""
seen_tokens = set()

n_duplicate = 0
Expand All @@ -39,7 +50,9 @@ def clean_and_create_vocabulary(
if len(preprocessed) > 1:
tokens_as_str = [f"'{subword}'" for subword in preprocessed]
split_into = ",".join(tokens_as_str)
logger.warning(f"Token '{token}' was split into multiple tokens after preprocessing: [{split_into}]")
logger.warning(
f"Token '{token}' was split into multiple tokens after preprocessing: [{split_into}], adding it as a multi-word token."
)
added_tokens_to_add.append(token)
continue
token = preprocessed[0]
Expand All @@ -54,6 +67,8 @@ def clean_and_create_vocabulary(
seen_tokens.add(token)
tokens_to_add.append(token)

model.post_processor = None
model = model.prune_added_tokens()
model = model.add_tokens_to_vocabulary(tokens_to_add, preprocess_tokens=True)
model = model.add_addedtokens(added_tokens_to_add, is_special=False, single_word=False, normalized=True)

Expand Down
2 changes: 1 addition & 1 deletion model2vec/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version_triple__ = (0, 8, 1)
__version_triple__ = (0, 8, 2)
__version__ = ".".join(map(str, __version_triple__))
7 changes: 7 additions & 0 deletions tests/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
(None, 1024, None), # Subword, PCA set high, SIF off
(None, None, 1e-4), # No PCA, SIF on
(None, 0.9, 1e-4), # PCA as float (variance), SIF on
(["star wars"], 8, None), # Multiword vocabulary
],
)
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
Expand Down Expand Up @@ -79,6 +80,12 @@ def test_distill_from_model(
assert json.loads(static_model.tokenizer.to_str()) == json.loads(static_model2.tokenizer.to_str())
assert static_model.base_model_name == static_model2.base_model_name

for token in vocabulary or []:
# Single-word tokens are stored in their normalized form.
# Multiword tokens are added as added tokens, exactly as given.
normalized = static_model.tokenizer.normalizer.normalize_str(token)
assert token in static_model.tokens or normalized in static_model.tokens


@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@patch("transformers.AutoModel.from_pretrained")
Expand Down
Loading