New experiment without causal mask

This commit is contained in:
rasbt
2024-05-18 17:03:36 -05:00
parent 57634f2045
commit 5ef4edf2b5
3 changed files with 30 additions and 11 deletions

View File

@@ -153,7 +153,7 @@ def instantiate_model(choose_model, load_weights):
if not load_weights:
torch.manual_seed(123)
model = GPTModel(BASE_CONFIG)
model = GPTModel(BASE_CONFIG, disable_causal_mask=args.disable_causal_mask)
if load_weights:
model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
@@ -386,6 +386,15 @@ if __name__ == "__main__":
)
)
parser.add_argument(
"--disable_causal_mask",
action='store_true',
default=False,
help=(
"Disables the causal attention mask."
)
)
args = parser.parse_args()
if args.trainable_token == "first":