Add weight sizes to the GPT-2 model names

This commit is contained in:
rasbt
2024-03-31 08:45:14 -05:00
parent 1c173e4f44
commit 83adc4a2ac
3 changed files with 90 additions and 95 deletions

View File

@@ -219,7 +219,7 @@ if __name__ == "__main__":
torch.manual_seed(123)
CHOOSE_MODEL = "gpt2-small"
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"
BASE_CONFIG = {
@@ -230,19 +230,14 @@ if __name__ == "__main__":
}
model_configs = {
"gpt2-small": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
"gpt2-medium": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
"gpt2-large": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
"gpt2-xl": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
model_sizes = {
"gpt2-small": "124M",
"gpt2-medium": "355M",
"gpt2-large": "774M",
"gpt2-xl": "1558"
}
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
main(BASE_CONFIG, INPUT_PROMPT, model_sizes[CHOOSE_MODEL])
main(BASE_CONFIG, INPUT_PROMPT, model_size)