# import libraries
from collections import defaultdict

from transformers import AutoTokenizer


class BPE():
    """Byte-Pair Encoding: Subword-based tokenization algorithm."""

    def __init__(self, corpus, vocab_size):
        """Initialize BPE tokenizer."""
        self.corpus = corpus
        self.vocab_size = vocab_size
        # pre-tokenize the corpus into words, BERT pre-tokenizer is used here
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.word_freqs = defaultdict(int)
        self.splits = {}
        self.merges = {}
    def train(self):
        """Train BPE tokenizer."""
        # compute the frequencies of each word in the corpus
        for text in self.corpus:
            words_with_offsets = self.tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
            new_words = [word for word, offset in words_with_offsets]
            for word in new_words:
                self.word_freqs[word] += 1

        # compute the base vocabulary of all characters in the corpus
        alphabet = []
        for word in self.word_freqs.keys():
            for letter in word:
                if letter not in alphabet:
                    alphabet.append(letter)
        alphabet.sort()

        # add the special token </w> at the beginning of the vocabulary
        vocab = ["</w>"] + alphabet.copy()

        # split each word into individual characters before training
        self.splits = {word: [c for c in word] for word in self.word_freqs.keys()}

        # merge the most frequent pair iteratively until the vocabulary size is reached
        while len(vocab) < self.vocab_size:
            # compute the frequency of each pair
            pair_freqs = self.compute_pair_freqs()
            # stop early if no pairs are left to merge
            if not pair_freqs:
                break
            # find the most frequent pair
            best_pair = None
            max_freq = None
            for pair, freq in pair_freqs.items():
                if max_freq is None or max_freq < freq:
                    best_pair = pair
                    max_freq = freq
            # merge the most frequent pair
            self.splits = self.merge_pair(*best_pair)
            self.merges[best_pair] = best_pair[0] + best_pair[1]
            vocab.append(best_pair[0] + best_pair[1])

        return self.merges
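
    # Example (illustrative, not from the original file): after training on a
    # toy corpus, self.merges might look like {("u", "g"): "ug", ("h", "ug"): "hug"},
    # mapping each learned pair, in learning order, to its merged token.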
    def compute_pair_freqs(self):
        """Compute the frequency of each pair."""
        pair_freqs = defaultdict(int)
        for word, freq in self.word_freqs.items():
            split = self.splits[word]
            if len(split) == 1:
                continue
            for i in range(len(split) - 1):
                pair = (split[i], split[i + 1])
                pair_freqs[pair] += freq
        return pair_freqs
    def merge_pair(self, a, b):
        """Merge the given pair."""
        for word in self.word_freqs:
            split = self.splits[word]
            if len(split) == 1:
                continue
            i = 0
            while i < len(split) - 1:
                if split[i] == a and split[i + 1] == b:
                    split = split[:i] + [a + b] + split[i + 2 :]
                else:
                    i += 1
            self.splits[word] = split
        return self.splits
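
    # Worked example (illustrative, not from the original file): with
    # word_freqs = {"hug": 10, "pug": 5} and splits = {"hug": ["h", "u", "g"],
    # "pug": ["p", "u", "g"]}, compute_pair_freqs() returns
    # {("h", "u"): 10, ("u", "g"): 15, ("p", "u"): 5}, so merge_pair("u", "g")
    # runs first and the splits become ["h", "ug"] and ["p", "ug"].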
    def tokenize(self, text):
        """Tokenize a given text with trained BPE tokenizer (including pre-tokenization, split, and merge)."""
        pre_tokenize_result = self.tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
        pre_tokenized_text = [word for word, offset in pre_tokenize_result]
        splits_text = [[l for l in word] for word in pre_tokenized_text]
        # apply the learned merges in the order they were learned
        for pair, merge in self.merges.items():
            for idx, split in enumerate(splits_text):
                i = 0
                while i < len(split) - 1:
                    if split[i] == pair[0] and split[i + 1] == pair[1]:
                        split = split[:i] + [merge] + split[i + 2 :]
                    else:
                        i += 1
                splits_text[idx] = split
        # flatten the per-word token lists into a single token sequence
        result = sum(splits_text, [])
        return result
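

# Minimal usage sketch (an assumption, not part of the original file): the corpus,
# vocab_size, and sample sentence below are hypothetical values for demonstration.
# Requires the `transformers` package to fetch the BERT pre-tokenizer.
if __name__ == "__main__":
    corpus = [
        "hugging face transformers library",
        "byte pair encoding merges frequent pairs",
        "tokenization splits text into subwords",
    ]
    bpe = BPE(corpus=corpus, vocab_size=50)
    merges = bpe.train()
    print(merges)
    print(bpe.tokenize("byte pair tokenization"))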