Commit 182b01e9 authored by Benjamin Vandersmissen


Added additional logging in a new file, and different options to determine the local importance (softmax, identity = Global pruning, normalize, normalize by std)
parent a5c0ce99
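The commit message describes four ways of transforming the per-layer weight magnitudes before they are compared against one shared pruning threshold. A rough standalone sketch of the effect (not part of the commit; the two toy layers and the 50% threshold are made up for illustration): with identity the threshold acts globally, so only the small-scale layer is pruned, while normalize rescales each layer's scores first and the pruning is spread across both.

import torch

def identity(inp):
    return inp

def normalize(inp):
    return (inp - inp.min()) / (inp.max() - inp.min())

# two toy "layers" with very different weight scales
layer_a = torch.tensor([0.9, 0.5, 0.1, 0.7])
layer_b = torch.tensor([0.09, 0.05, 0.01, 0.07])

for name, f in [("identity", identity), ("normalize", normalize), ("softmax", lambda t: torch.softmax(t, dim=-1))]:
    scores = torch.cat([f(layer_a.abs()), f(layer_b.abs())])
    threshold = torch.quantile(scores, 0.5)            # prune the lowest 50% of scores
    print(name, (scores < threshold).int().tolist())   # 1 = would be pruned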
def log_arguments(basedir, args):
    # Write every parsed command-line argument as a "key : value" line.
    with open(f"{basedir}/settings", "w") as f:
        for key, val in args.__dict__.items():
            f.write(f"{key} : {val}\n")


def log_accuracies(basedir, accuracies):
    # Append one line with the accuracies of the current pruning iteration.
    with open(f"{basedir}/accuracies", "a") as f:
        f.write(f"{accuracies}\n")
import argparse
import os
import torch.nn as nn
from datetime import datetime
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from HDF5Dataset import *
from lenet import LeNet5
from vgg import VGG11
from pruning import *
from log import *
# TODO: lr scheduler support.
# TODO: save the gradients somehow (maybe interesting for only a small batch during training)
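# (Sketch, not implemented in this commit) one way to handle the gradient TODO above would be
# to run loss.backward() on a single small probe batch and dump a copy of the gradients, e.g.:
#     grads = {name: p.grad.detach().cpu().clone()
#              for name, p in model.named_parameters() if p.grad is not None}
#     torch.save(grads, f"{basedir}/gradients_iter{i}.pt")
# The file name and the single-batch choice are assumptions, not part of the existing code.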
@@ -23,6 +25,7 @@ parser.add_argument('-r', '--random', default=42, type=int, help="The random seed")
parser.add_argument('--device', default='cuda:0', type=str, help="The device to run on")
parser.add_argument('-m', '--model', default='lenet', type=str, choices=['lenet', 'vgg'], help="The model to prune")
parser.add_argument('-o', '--original-epoch', default=0, type=int, help='The epoch to reset the weights to')
parser.add_argument('-i', '--importance', default='none', choices=['none', 'softmax', 'normalize', 'normalize_std'])
args = parser.parse_args()
model_dict = {
@@ -54,9 +57,7 @@ criterion = nn.CrossEntropyLoss()
basedir = "weights/"+str(int(datetime.now().timestamp()))
os.makedirs(basedir, exist_ok=True)
with open(f"{basedir}/settings", "w") as f:
for key, val in args.__dict__.items():
f.write(f"{key} : {val}\n")
log_arguments(basedir, args)
def iterate_over_dataset(dataloader, model, criterion, optimizer, device, train=True):
@@ -120,18 +121,24 @@ def training_loop(model, criterion, optimizer, train_loader, valid_loader, epoch
return model, optimizer, (train_losses, valid_losses), (train_accuracies, valid_accuracies)
transformations = {'none': identity, 'softmax': softmax, 'normalize': normalize, 'normalize_std': normalize_std}
for i in range(args.pruningepochs):
if i != 0: # Update the pruning mask if at least one iteration has been done
prune_by_magnitude(model, args.pruningpercentage, transformations[args.importance])
save_weights_pruning_masks(model, i, basedir) # Save the weights at the start of each iteration -> weights at start of iteration 1 have been pruned once
if i != 0: # Reset to the initial weights if at least one iteration has been done
reset_init_weights(model)
trainloader = DataLoader(traindata, batch_size, shuffle=True)
testloader = DataLoader(testdata, batch_size, shuffle=False)
    epochs = args.trainepochs if i == 0 else args.trainepochs - args.original_epoch  # If we late reset, we don't need to do those initial epochs
model, _, losses, accuracies = training_loop(model, criterion, optimizer, trainloader, testloader, epochs, device)
print("Losses: {}".format(losses))
print("Accuracies: {}".format(accuracies))
log_accuracies(basedir, accuracies)
print("Pruning % : {}".format(pruned_percentage(model)))
prune_by_magnitude(model)  # final prune after the last round of training (default percentage and identity transformation)
save_weights_pruning_masks(model, args.pruningepochs, basedir)
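A small standalone snippet (not part of the commit) for reading back the files that log_arguments and log_accuracies write; the run directory below is only an example of the weights/<timestamp> layout created above.

import os

basedir = "weights/1650000000"  # example run directory, assumed to exist

# settings: one "key : value" line per argparse argument
with open(os.path.join(basedir, "settings")) as f:
    settings = dict(line.strip().split(" : ", 1) for line in f if " : " in line)

# accuracies: one line per pruning iteration (the repr of the (train, valid) accuracy lists)
with open(os.path.join(basedir, "accuracies")) as f:
    runs = [line.strip() for line in f if line.strip()]

print(settings.get("importance"), "-", len(runs), "pruning iterations logged")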
import torch
import copy
import math
import os
import numpy as np
@@ -14,11 +15,34 @@ def can_be_pruned(module):
return hasattr(module, 'pruning_mask')
def identity(inp):
return inp
def normalize(inp):
return (inp-inp.min()) / (inp.max()-inp.min())
def softmax(inp):
return torch.softmax(inp, dim=-1)
def normalize_std(inp):
"""
Divides by std of the initialization,
but because the std is only variable in the term math.sqrt(1/float(fan_in + fan_out)),
we divide by that instead
"""
# TODO: this is only for GLOROT UNIFORM, implement the others when needed
fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(inp)
return inp / math.sqrt(1 / float(fan_in + fan_out))
def prune_by_magnitude(model: torch.nn.Module, percentage=0.2, transformation=identity):
weights = None
for layer in model.modules():
if can_be_pruned(layer):
            reshapen_weight = transformation(torch.abs(layer.weight[torch.logical_not(layer.pruning_mask)])).reshape(-1)
if weights is None:
weights = reshapen_weight
else:
@@ -29,9 +53,8 @@ def prune_by_magnitude(model: torch.nn.Module, percentage=0.2):
for layer in model.modules():
if can_be_pruned(layer):
device = layer.weight.device
            # pruning_mask: 0/False means the weight is active, 1/True means it has been pruned
            layer.pruning_mask[torch.logical_not(layer.pruning_mask)] = \
                torch.lt(transformation(torch.abs(layer.weight[torch.logical_not(layer.pruning_mask)])), quantile)
def reset_init_weights(model: torch.nn.Module):
{
"name": "LTH test VGG11",
"name": "LTH VGG11",
"deploymentEnvironment": "production",
"request": {
"resources": {
@@ -12,12 +12,12 @@
"docker": {
"image": "gitlab+deploy-token-535:19ESJYmR2qaQn5toyB8s@gitlab.ilabt.imec.be:4567/sparse-representation-learning/lth:latest",
"environment":{
"LTH_DATA": "flowers",
"LTH_LR": 0.00025,
"LTH_N": 25,
"LTH_DATA": "imagenette",
"LTH_LR": 0.0005,
"LTH_N": 20,
"LTH_P": 0.2,
"LTH_B": 32,
"LTH_T": 15,
"LTH_B": 64,
"LTH_T": 2,
"LTH_R": 42,
"LTH_DEVICE":"cuda:0",
"LTH_M": "vgg"
@@ -27,7 +27,7 @@
"hostPath": "/project_antwerp"
}
],
"command": "bash -c 'apt-get update; apt-get install -y rsync vim; sleep 3600'"
"command": "bash -c 'apt-get update; apt-get install -y rsync vim; sleep 36000'"
}
},
"description": "transfering files to the LTH project"