Tokenizer Series Part 2: From BPE Training to a Working Tokenizer Class

In our previous blog, we learned about the fundamentals of tokenization and how byte pair encoding works.

Now that we have all the building blocks, we are ready to write our tokenizer class to merge, encode, and decode byte pairs as tokens. First, though, we need a couple more helper functions: one to escape control characters such as \n and \t, and another to render our tokens as readable strings. Along with our earlier functions, the code will look as follows. We encourage you to run the functions yourself in a notebook or Colab before proceeding, to get a feel for where each part belongs in the whole picture.

import unicodedata

# same as before
def get_stats(ids, counts=None):
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    '''
    replace recurring pairs, e.g. in [1, 2, 3, 1, 2], replacing pair (1, 2) -> 4 gives [4, 3, 4]
    '''
    newids = []
    i = 0
    while i < len(ids):
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def replace_control_characters(s: str) -> str:
    # escape control characters that may distort output, like \n, \t, etc.
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch)  # non-control characters pass through
        else:
            chars.append(f"\\u{ord(ch):04x}")  # escape control characters
    return "".join(chars)

def render_token(t: bytes) -> str:
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s
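
To get a feel for these helpers before moving on, here is a quick sanity check you can run in a notebook, assuming the functions above are defined in the same session (the toy input and its counts are just an illustration):

# count adjacent pairs in a toy sequence
ids = [1, 2, 3, 1, 2]
print(get_stats(ids))          # {(1, 2): 2, (2, 3): 1, (3, 1): 1}

# merge the most frequent pair (1, 2) into a new token id 4
print(merge(ids, (1, 2), 4))   # [4, 3, 4]

# render bytes containing a control character as a readable string
print(render_token("hi\nthere".encode("utf-8")))  # hi\u000athere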

Building the Tokenizer Class

Now we can finally write our tokenizer class:

class Tokenizer:

    def __init__(self):
        self.merges = {}
        self.pattern = {}
        self.special_tokens = {} # eg <|endoftext|>
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        text_bytes = text.encode("utf-8")
        ids = list(text_bytes)

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}

        for i in range(num_merges):
            stats = get_stats(ids)
            pair = max(stats, key=stats.get)  # most frequent pair
            idx = 256 + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges # used in encode
        self.vocab = vocab   # used in decode

    def decode(self, ids):
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        text_bytes = text.encode("utf-8")
        ids = list(text_bytes)

        while len(ids) >= 2:
            stats = get_stats(ids)
            # pick the pair that was merged earliest during training (lowest merge index)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break # no more merging is possible
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
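
Before training on real data, it helps to see the class working end to end. Here is a minimal sketch on a tiny toy string; the exact ids you see will depend on the merges learned, and vocab_size=260 is just an arbitrary small choice:

toy = Tokenizer()
toy.train("aaabdaaabac" * 10, vocab_size=260)

ids = toy.encode("aaabdaaabac")
print(ids)               # shorter than the 11 raw bytes
print(toy.decode(ids))   # aaabdaaabac -- encode/decode round-trips exactly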

With all of this in place, the verbose flag in train lets us watch the merges as they happen. Now download a text file of considerable length, save it in the same directory as your notebook/script, load it, and train your tokenizer on it. Alternatively, you can use the text file we used, which you can find along with our code here.
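
If you do not already have a suitable file, one option is to fetch a small public-domain corpus and save it as data.txt. The URL below is just an example (the "tiny Shakespeare" file commonly used in tutorials), not the file used in this post:

# optional: fetch a sample public-domain text if you don't have one handy
import urllib.request

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
urllib.request.urlretrieve(url, "data.txt")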

Training


text = open("data.txt","r",encoding="utf-8").read()
tokenizer = Tokenizer()
tokenizer.train(text,512,verbose=True)

Here we opted for a vocab size of 512; you can adjust it as you wish. Setting verbose=True lets us see the merges as they happen. After running the final line, you should expect output similar to this:

Byte Pair Encoding tokenizer showing merge steps and token frequency counts during BPE training
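
Once training finishes, a quick sanity check is to round-trip some text through the tokenizer and look at the rough compression it achieves (the exact ratio depends on your training text and vocab size; anything above 1 means the merges are compressing the byte stream):

sample = text[:1000]
ids = tokenizer.encode(sample)

print(tokenizer.decode(ids) == sample)         # True: encode/decode round-trips
print(len(sample.encode("utf-8")) / len(ids))  # average bytes per token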

And that’s it for our simple tokenizer. Next we will show you how to save and load the model files. This step is fairly routine, as it is in almost all ML applications, and since it is required we recommend you try writing it on your own first.

Saving and Loading

Keep in mind that the code for saving and loading should be added as methods of the Tokenizer class (see the note after the code below if you are working in a notebook).

def save(self, file_prefix):
    '''
    Two files are saved: file_prefix.model and file_prefix.vocab

    The model file is the one that actually matters;
    the vocab file is only for human inspection.
    '''
    model_file = file_prefix + '.model'
    with open(model_file, 'w') as f:
        f.write('bpe v1\n')
        f.write(f"{self.pattern}\n")

        # write special tokens
        f.write(f"{len(self.special_tokens)}\n")
        for special, idx in self.special_tokens.items():
            f.write(f"{special} {idx}\n")

        # write the merges in training order
        for idx1, idx2 in self.merges:
            f.write(f"{idx1} {idx2}\n")

    # write the vocab file (for inspection only)
    vocab_file = file_prefix + ".vocab"
    inverted_merges = {idx: pair for pair, idx in self.merges.items()}
    with open(vocab_file, "w", encoding="utf-8") as f:
        for idx, token in self.vocab.items():
            s = render_token(token)
            if idx in inverted_merges:
                idx0, idx1 = inverted_merges[idx]
                s0 = render_token(self.vocab[idx0])
                s1 = render_token(self.vocab[idx1])
                f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
            else:
                f.write(f"[{s}] {idx}\n")

def load(self, model_file):
    assert model_file.endswith(".model")
    # read the model file
    merges = {}
    special_tokens = {}
    idx = 256
    with open(model_file, 'r', encoding="utf-8") as f:
        # read the version
        version = f.readline().strip()
        assert version == "bpe v1"
        self.pattern = f.readline().strip()
        # read the special tokens
        num_special = int(f.readline().strip())
        for _ in range(num_special):
            special, special_idx = f.readline().strip().split()
            special_tokens[special] = int(special_idx)
        # read the merges
        for line in f:
            idx1, idx2 = map(int, line.split())
            merges[(idx1, idx2)] = idx
            idx += 1
    self.merges = merges
    self.special_tokens = special_tokens
    self.vocab = self._build_vocab()
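
If you are working in a notebook and the Tokenizer class is already defined in an earlier cell, one convenient way to "append" these methods is to define them at module level (as above, with self as the first argument) and attach them to the class afterwards. This is just a convenience sketch; pasting the methods directly into the class definition, as in the complete listing below, works exactly the same way:

# attach the top-level save and load functions as methods of Tokenizer
Tokenizer.save = save
Tokenizer.load = load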

Complete Tokenizer Class

Putting it all together, our complete Tokenizer class looks like this:

class Tokenizer:

    def __init__(self):
        self.merges = {}
        self.pattern = {}
        self.special_tokens = {} # eg <|endoftext|>
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        vocab = {idx: bytes([idx]) for idx in range(256)}

        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]

        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self,text,vocab_size,verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        text_bytes = text.encode("utf-8")
        ids = list(text_bytes)

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}

        for i in range(num_merges):
            stats = get_stats(ids)
            pair = max(stats,key=stats.get)
            idx = 256 + i
            ids = merge(ids,pair,idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]]+ vocab[pair[1]]

            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges #used in encode
        self.vocab = vocab #used in decode

    def decode(self,ids):
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8",errors="replace")
        return text

    def encode(self,text):
        text_bytes = text.encode("utf-8")
        ids = list(text_bytes)

        while len(ids) >= 2:
            stats = get_stats(ids)
            # pick the pair that was merged earliest during training (lowest merge index)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))

            if pair not in self.merges:
                break # no more merging is possible
            idx = self.merges[pair]
            ids = merge(ids,pair,idx)
        return ids

    def save(self,file_prefix):
        '''
        Two files are saved: file_prefix.model and file_prefix.vocab

        The model file is the one that actually matters;
        the vocab file is only for human inspection.
        '''
        model_file = file_prefix +'.model'
        with open(model_file,'w') as f:
            f.write('bpe v1\n')
            f.write(f"{self.pattern}\n")

            #write special tokens
            f.write(f"{len(self.special_tokens)}\n")
            for special,idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")

            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")

        #writing vocab file(inspection)
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                s = render_token(token)
                if idx in inverted_merges:
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "bpe v1"
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()

Now you may save your model as:

tokenizer.save("model1")
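
If you want to peek at the model file itself, its first few lines follow the format written by save: a version header, the (currently empty) pattern, the special-token count, and then one merge pair per line. Something like this will show it (the exact merge pairs depend on your training text):

with open("model1.model", "r", encoding="utf-8") as f:
    for _ in range(5):
        print(f.readline().rstrip())
# bpe v1
# {}
# 0
# <first merge pair, e.g. two byte ids like "101 32">
# <second merge pair>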

Then open your vocab file in a text editor of your choice; it will look something like this:

Byte Pair Encoding merge rules showing how character pairs are combined into new tokens during tokenizer training

Finally let’s load it:

t1 = Tokenizer()
t1.load("model1.model")
encode = t1.encode("hi")
decode = t1.decode(encode)
print(encode)
print(decode) # should return the original text
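
As an optional extra check, you can confirm that the loaded tokenizer reproduces the original one's behaviour, since both should hold identical merge tables (this assumes text and tokenizer from the training step are still in scope):

assert t1.merges == tokenizer.merges                            # same learned merges
assert t1.encode(text[:500]) == tokenizer.encode(text[:500])    # identical tokenization
print("save/load round-trip OK")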

That’s it, we have completed our basic tokenizer. In the next tutorial we will add a final touch by creating a regex layer to approximate how tokenization is actually done in large AI models.


Explore More Resources

Main Website: biterdevs.com

Blog Home: blog.biterdevs.com
