mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
rename hparams to settings
This commit is contained in:
@@ -37,7 +37,7 @@ def download_and_load_gpt2(model_size, models_dir):
|
||||
model_dir = os.path.join(models_dir, model_size)
|
||||
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
|
||||
filenames = [
|
||||
"checkpoint", "encoder.json", "hparams.json",
|
||||
"checkpoint", "encoder.json", "settings.json",
|
||||
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
|
||||
"model.ckpt.meta", "vocab.bpe"
|
||||
]
|
||||
@@ -49,12 +49,12 @@ def download_and_load_gpt2(model_size, models_dir):
|
||||
file_path = os.path.join(model_dir, filename)
|
||||
download_file(file_url, file_path)
|
||||
|
||||
# Load hparams and params
|
||||
# Load settings and params
|
||||
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
|
||||
hparams = json.load(open(os.path.join(model_dir, "hparams.json")))
|
||||
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams)
|
||||
settings = json.load(open(os.path.join(model_dir, "settings.json")))
|
||||
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
|
||||
|
||||
return hparams, params
|
||||
return settings, params
|
||||
|
||||
|
||||
def download_file(url, destination):
|
||||
@@ -85,9 +85,9 @@ def download_file(url, destination):
|
||||
file.write(chunk) # Write the chunk to the file
|
||||
|
||||
|
||||
def load_gpt2_params_from_tf_ckpt(ckpt_path, hparams):
|
||||
def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
|
||||
# Initialize parameters dictionary with empty blocks for each layer
|
||||
params = {"blocks": [{} for _ in range(hparams["n_layer"])]}
|
||||
params = {"blocks": [{} for _ in range(settings["n_layer"])]}
|
||||
|
||||
# Iterate over each variable in the checkpoint
|
||||
for name, _ in tf.train.list_variables(ckpt_path):
|
||||
@@ -221,7 +221,7 @@ def main(gpt_config, input_prompt, model_size):
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
hparams, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
||||
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
||||
|
||||
gpt = GPTModel(gpt_config)
|
||||
load_weights_into_gpt(gpt, params)
|
||||
|
||||
Reference in New Issue
Block a user