Commit a5c0ce99 authored by Benjamin Vandersmissen's avatar Benjamin Vandersmissen
Browse files

Made it so that the hyperparameters are saved to a file for each run.

parent fe521b06
......@@ -9,6 +9,9 @@ import argparse
from datetime import datetime
from torchvision import transforms
# TODO: lr scheduler support.
# TODO: save the gradients somehow (maybe interesting for only a small batch during training)
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='fashion', type=str, help="The dataset to use")
parser.add_argument('--lr', default=0.0002, type=float, help="Which learning rate to use")
......@@ -51,6 +54,10 @@ criterion = nn.CrossEntropyLoss()
basedir = "weights/"+str(int(
os.makedirs(basedir, exist_ok=True)
with open(f"{basedir}/settings", "w") as f:
for key, val in args.__dict__.items():
f.write(f"{key} : {val}\n")
def iterate_over_dataset(dataloader, model, criterion, optimizer, device, train=True):
model.train() if train else model.eval()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment