1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
|
import os from logging import getLogger from pathlib import Path from typing import ( AbstractSet, cast, Collection, Dict, Iterator, List, Literal, Sequence, TypedDict, Union, )
import tiktoken from tiktoken.load import load_tiktoken_bpe
# Module-level logger shared by the tokenizer classes below.
logger = getLogger(__name__)

# The speaker of a chat turn; only these three values are accepted.
Role = Literal["system", "user", "assistant"]

# One chat turn, as a plain dict with exactly two required keys:
#   role    -- who is speaking (a `Role` value)
#   content -- the text of the turn
Message = TypedDict("Message", {"role": Role, "content": str})

# A whole conversation: an ordered, read-only sequence of turns.
Dialog = Sequence[Message]
class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    # Total number of special-token ids appended after the base BPE vocabulary:
    # the 10 explicitly named tokens in __init__ plus
    # (num_reserved_special_tokens - 10) generic reserved tokens.
    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

    # Upper bound on the number of characters of input processed per slice
    # in `encode`; presumably chosen to stay within what a single tiktoken
    # encode call handles safely -- confirm before raising.
    TIKTOKEN_MAX_ENCODE_CHARS = 400_000

    # Upper bound on a run of consecutive whitespace (or consecutive
    # non-whitespace) characters passed to tiktoken in one piece; longer
    # runs are cut by `_split_whitespaces_or_nonwhitespaces`.
    MAX_NO_WHITESPACES_CHARS = 25_000

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.

        Raises:
            AssertionError: if `model_path` is not an existing file (note that
                assertions are stripped when Python runs with -O).
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        # 10 named tokens, then reserved tokens 5 .. N-6, giving
        # num_reserved_special_tokens entries in total.
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        # Special-token ids follow directly after the base vocabulary ids.
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {model_path}")

        self.n_words: int = self.model.n_vocab
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        # Sentinel id; -1 is never produced by the enumeration above, so it
        # presumably means "no padding token" -- confirm with callers.
        self.pad_id: int = -1
        # Either of these ids ends a turn / the text.
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens that are encoded
                as special tokens when their text appears in `s`.
            disallowed_special ("all"|Collection[str]): special tokens that
                raise an error when their text appears in `s`.

        Returns:
            list[int]: A list of token IDs.

        Raises:
            AssertionError: if `s` is not exactly a `str`.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        # The shared `set()` default is never mutated in this method.
        assert type(s) is str

        # Start with BOS instead of inserting it at index 0 afterwards, which
        # would shift every already-encoded token.
        t: List[int] = [self.bos_id] if bos else []

        # Encode in bounded pieces: cap each slice at TIKTOKEN_MAX_ENCODE_CHARS
        # characters, then split any over-long whitespace / non-whitespace run
        # inside that slice.
        step = self.TIKTOKEN_MAX_ENCODE_CHARS
        for start in range(0, len(s), step):
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[start : start + step], self.MAX_NO_WHITESPACES_CHARS
            ):
                t.extend(
                    self.model.encode(
                        substr,
                        allowed_special=allowed_special,
                        disallowed_special=disallowed_special,
                    )
                )

        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a sequence of token IDs into a string.

        Args:
            t (Sequence[int]): The token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # `cast` only narrows the static type; no conversion happens at runtime.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive
        non-whitespaces.

        Concatenating the yielded pieces reproduces `s` exactly. An empty `s`
        yields a single empty string.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i, ch in enumerate(s):
            is_now_space = ch.isspace()

            if current_slice_is_space != is_now_space:
                # Character class flipped: a new run starts with this char.
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    # Run is too long: emit everything before this character
                    # and begin the next piece here.
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        # Always emit the remainder of the string.
        yield s[slice_start:]
class ChatFormat:
    """
    Converts chat messages and dialogs into token id lists, framing each
    turn with the header and end-of-turn special tokens of a `Tokenizer`.
    """

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message: Message) -> List[int]:
        """
        Encode the header of `message`: the start-header token, the encoded
        role text, the end-header token, then the encoded two-newline
        separator. The message content is not used.
        """
        tok = self.tokenizer
        specials = tok.special_tokens
        # A list display evaluates left to right, so lookups and encode calls
        # happen in the same order as building the list step by step.
        return [
            specials["<|start_header_id|>"],
            *tok.encode(message["role"], bos=False, eos=False),
            specials["<|end_header_id|>"],
            *tok.encode("\n\n", bos=False, eos=False),
        ]

    def encode_message(self, message: Message) -> List[int]:
        """
        Encode a full turn: its header, the whitespace-stripped content,
        and a closing end-of-turn token.
        """
        tok = self.tokenizer
        return [
            *self.encode_header(message),
            *tok.encode(message["content"].strip(), bos=False, eos=False),
            tok.special_tokens["<|eot_id|>"],
        ]

    def encode_dialog_prompt(self, dialog: Dialog) -> List[int]:
        """
        Encode a whole dialog as a prompt: begin-of-text, every message in
        order, and finally an empty assistant header (presumably so that
        generation continues as the assistant's reply).
        """
        prompt = [self.tokenizer.special_tokens["<|begin_of_text|>"]]
        for turn in dialog:
            prompt += self.encode_message(turn)
        prompt += self.encode_header({"role": "assistant", "content": ""})
        return prompt
|