mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Make quote style consistent (#891)
This commit is contained in:
committed by
GitHub
parent
9276edbc37
commit
7ca7c47e4a
@@ -70,7 +70,7 @@ def get_pairs(word):
|
||||
|
||||
|
||||
class Encoder:
|
||||
def __init__(self, encoder, bpe_merges, errors='replace'):
|
||||
def __init__(self, encoder, bpe_merges, errors="replace"):
|
||||
self.encoder = encoder
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
@@ -92,7 +92,7 @@ class Encoder:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
@@ -119,43 +119,43 @@ class Encoder:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
word = " ".join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||
text = "".join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
||||
return text
|
||||
|
||||
|
||||
def get_encoder(model_name, models_dir):
|
||||
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
|
||||
with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f:
|
||||
encoder = json.load(f)
|
||||
with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
|
||||
with open(os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8") as f:
|
||||
bpe_data = f.read()
|
||||
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
|
||||
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
|
||||
return Encoder(encoder=encoder, bpe_merges=bpe_merges)
|
||||
|
||||
|
||||
def download_vocab():
|
||||
# Modified code from
|
||||
subdir = 'gpt2_model'
|
||||
subdir = "gpt2_model"
|
||||
if not os.path.exists(subdir):
|
||||
os.makedirs(subdir)
|
||||
subdir = subdir.replace('\\', '/') # needed for Windows
|
||||
subdir = subdir.replace("\\", "/") # needed for Windows
|
||||
|
||||
for filename in ['encoder.json', 'vocab.bpe']:
|
||||
for filename in ["encoder.json", "vocab.bpe"]:
|
||||
r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/models/117M/" + filename, stream=True)
|
||||
|
||||
with open(os.path.join(subdir, filename), 'wb') as f:
|
||||
with open(os.path.join(subdir, filename), "wb") as f:
|
||||
file_size = int(r.headers["content-length"])
|
||||
chunk_size = 1000
|
||||
with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
|
||||
|
||||
Reference in New Issue
Block a user