Commit f7ca0edf authored by Benjamin Vandersmissen's avatar Benjamin Vandersmissen
Browse files

Updated arguments in main.py

parent ca783c0e
......@@ -25,7 +25,7 @@ parser.add_argument('-r', '--random', default=42, type=int, help="The random see
parser.add_argument('--device', default='cuda:0', type=str, help="The device to run on")
parser.add_argument('-m', '--model', default='lenet', type=str, choices=['lenet', 'vgg'], help="The model to prune")
parser.add_argument('-o', '--original-epoch', default=0, type=int, help='The epoch to reset the weights to')
parser.add_argument('-i', '--importance', default='none', choices=['none', 'softmax', 'normalize', 'normalize_std'])
parser.add_argument('-i', '--importance', default='none', choices=['none', 'softmax', 'normalize', 'normalize_sum'])
args = parser.parse_args()
model_dict = {
......@@ -121,7 +121,7 @@ def training_loop(model, criterion, optimizer, train_loader, valid_loader, epoch
return model, optimizer, (train_losses, valid_losses), (train_accuracies, valid_accuracies)
transformations = {'none': identity, 'softmax': softmax, 'normalize': normalize, 'normalize_std': normalize_std}
transformations = {'none': identity, 'softmax': softmax, 'normalize': normalize, 'normalize_sum': normalize_sum}
for i in range(args.pruningepochs):
if i != 0: # Update the pruning mask if at least one iteration has been done
......@@ -141,4 +141,4 @@ for i in range(args.pruningepochs):
print("Pruning % : {}".format(pruned_percentage(model)))
prune_by_magnitude(model)
save_weights_pruning_masks(model, args.pruningepochs, basedir)
\ No newline at end of file
save_weights_pruning_masks(model, args.pruningepochs, basedir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment