Commit 87c962a1 authored by Benjamin Vandersmissen

Added a function to save the initial weights specifically.

Added a fix to ensure that the pruning masks are on the same device as the weights
Cleanup of the main loop
Added a way to select the initialization of layers, including the sparse-aware initialization from [Evci2022]
parent 009fd410
import math
import torch
from pruning import can_be_pruned


def initialize(model: torch.nn.Module, mode: str = "default", reinit=False):
    """Initialize all layers of the model according to a mode string (e.g. 'kaiming_uniform', 'sparse_xavier_normal')."""
    sparse = 'sparse' in mode
    kaiming = 'kaiming' in mode
    xavier = 'xavier' in mode
    normal = 'normal' in mode
    uniform = 'uniform' in mode
    for layer in model.modules():
        if not hasattr(layer, 'weight'):
            continue
        with torch.no_grad():
            if sparse and reinit and can_be_pruned(layer):  # sparse initialization is only needed when re-initializing
                sparse_initialization_layer(layer, mode)
            else:
                if kaiming and normal:
                    torch.nn.init.kaiming_normal_(layer.weight)
                elif kaiming and uniform:
                    torch.nn.init.kaiming_uniform_(layer.weight)
                elif xavier and normal:
                    torch.nn.init.xavier_normal_(layer.weight)
                elif xavier and uniform:
                    torch.nn.init.xavier_uniform_(layer.weight)
                else:
                    layer.reset_parameters()  # default behaviour, shouldn't be used
            if can_be_pruned(layer):
                # keep already-pruned weights at zero (pruning_mask: 1 = pruned)
                layer.weight = torch.nn.Parameter(layer.weight * torch.logical_not(layer.pruning_mask))
            if hasattr(layer, 'bias') and layer.bias is not None:
                layer.bias = torch.nn.Parameter(torch.zeros_like(layer.bias))


def calculate_sparse_fanin_fanout(weights, mask):
    """
    Calculate the fan-in and fan-out of each unit, counting only the unpruned connections.

    @param weights: the weight tensor to calculate the fan-in & fan-out for
    @param mask: a pruning mask for the weight tensor (1 = pruned, 0 = active)
    """
    nr_dim = len(weights.shape)
    dim1 = [i for i in range(nr_dim) if i != 0]  # general case for Linear and Conv2d (or Conv3d if ever needed)
    dim0 = [i for i in range(nr_dim) if i != 1]
    mask = torch.logical_not(mask)  # 1 = active connection
    fan_in = torch.sum(mask, dim=dim1)   # indexed by output unit
    fan_out = torch.sum(mask, dim=dim0)  # indexed by input unit
    return fan_in, fan_out
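

# A small worked example: for a 2x3 Linear weight with pruning mask
#   [[0, 1, 0],
#    [0, 0, 0]]   (1 = pruned),
# calculate_sparse_fanin_fanout returns fan_in = [2, 3] (active inputs per output unit)
# and fan_out = [2, 1, 2] (active outputs per input unit).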


def sparse_initialization_layer(layer: torch.nn.Module, mode, gain=1.0, fan_mode="in"):
    xavier = 'xavier' in mode
    normal = 'normal' in mode
    uniform = 'uniform' in mode
    if xavier:
        # Kaiming: gain * sqrt(1/fan)
        # Xavier: gain * sqrt(2/(fan_in+fan_out))
        # Transforming between them can be done by setting fan = fan_in + fan_out and gain = gain * sqrt(2)
        fan_mode = 'both'
        gain = gain * math.sqrt(2)
    fan_in, fan_out = calculate_sparse_fanin_fanout(layer.weight, layer.pruning_mask)
    init_weights = torch.empty_like(layer.weight)
    for i in range(layer.weight.shape[0]):
        for j in range(layer.weight.shape[1]):
            fan = fan_in[i]
            if fan_mode == 'out':
                fan = fan_out[j]
            if fan_mode == 'both':
                fan = fan_in[i] + fan_out[j]
            std = gain / math.sqrt(float(fan))
            bound = math.sqrt(3.0) * std
            if uniform:
                init_weights[i, j] = init_weights[i, j].uniform_(-bound, bound)
            elif normal:
                init_weights[i, j] = init_weights[i, j].normal_(0, std)
            else:
                raise Exception(f"Undefined mode: {mode}")
    layer.weight = torch.nn.Parameter(init_weights)
\ No newline at end of file
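
# A minimal usage sketch of the initialization selection above, assuming the pruning masks have
# already been attached to the prunable layers (cf. add_pruning_info / update_pruning_device in
# pruning.py) and using LeNet5 with example constructor arguments:
from lenet import LeNet5
from init import initialize

model = LeNet5(nr_classes=10, in_channels=1)
initialize(model, "kaiming_uniform", reinit=False)          # dense Kaiming-uniform init before training
# ... train, prune by magnitude, then re-initialize ...
initialize(model, "sparse_kaiming_uniform", reinit=True)    # sparsity-aware re-init ([Evci2022])
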
@@ -10,9 +10,9 @@ from lenet import LeNet5
from vgg import VGG11
from pruning import *
from log import *
from init import *
# TODO: lr scheduler support.
# TODO: save the gradients somehow (maybe interesting for only a small batch during training)
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='fashion', type=str, help="The dataset to use")
@@ -25,6 +25,7 @@ parser.add_argument('-r', '--random', default=42, type=int, help="The random see
parser.add_argument('--device', default='cuda:0', type=str, help="The device to run on")
parser.add_argument('-m', '--model', default='lenet', type=str, choices=['lenet', 'vgg'], help="The model to prune")
parser.add_argument('-o', '--original-epoch', default=0, type=int, help='The epoch to reset the weights to')
parser.add_argument('--initialization', default='kaiming_uniform', type=str, help='The initialization method to use')
parser.add_argument('-i', '--importance', default='none', choices=['none', 'softmax', 'normalize', 'normalize_sum'])
args = parser.parse_args()
@@ -51,6 +52,9 @@ device = args.device if torch.cuda.is_available() else 'cpu'
in_channels = traindata[0][0].shape[0]
model = model_dict[args.model](nr_classes=traindata.classes, in_channels=in_channels).to(device)
initialize(model, args.initialization, False)
update_pruning_device(model)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
@@ -123,22 +127,24 @@ def training_loop(model, criterion, optimizer, train_loader, valid_loader, epoch
transformations = {'none': identity, 'softmax': softmax, 'normalize': normalize, 'normalize_sum': normalize_sum}

save_init_weights(model, basedir)
for i in range(args.pruningepochs):
    if i != 0:  # Update the pruning mask if at least one iteration has been done
        prune_by_magnitude(model, args.pruningpercentage, transformations[args.importance])
    save_weights_pruning_masks(model, i, basedir)  # Save the weights at the start of each iteration -> weights at start of iteration 1 have been pruned once
    if i != 0:  # Reset to the initial weights if at least one iteration has been done
        reset_init_weights(model)
    save_weights_pruning_masks(model, i, basedir)
    if model.has_init_weights:
        reset_init_weights(model)
    else:
        initialize(model, args.initialization, True)

    trainloader = DataLoader(traindata, batch_size, shuffle=True)
    testloader = DataLoader(testdata, batch_size, shuffle=False)
    epochs = args.trainepochs if i == 0 else args.trainepochs - args.original_epoch  # If we late reset, we don't need to do those initial epochs
    epochs = args.trainepochs if i == 0 else args.trainepochs - max(args.original_epoch, 0)  # If we late reset, we don't need to do those initial epochs
    model, _, losses, accuracies = training_loop(model, criterion, optimizer, trainloader, testloader, epochs, device)
    print("Losses: {}".format(losses))
    print("Accuracies: {}".format(accuracies))
    log_accuracies(basedir, accuracies)
    print("Pruning % : {}".format(pruned_percentage(model)))
    print("Pruning % : {} ({}/{})".format(pruned_percentage(model), i+1, args.pruningepochs))

prune_by_magnitude(model)
save_weights_pruning_masks(model, args.pruningepochs, basedir)
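
# Note (a back-of-the-envelope check, assuming the default percentage of 0.2 in prune_by_magnitude):
# each pruning round removes 20% of the still-active weights, so after i rounds roughly
# 0.8**i of the prunable weights remain, e.g. about 32.8% after 5 rounds.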
@@ -9,6 +9,12 @@ def add_pruning_info(module):
    return module


def update_pruning_device(model):
    for layer in model.modules():
        if can_be_pruned(layer):
            layer.pruning_mask = layer.pruning_mask.to(layer.weight.device)


def can_be_pruned(module):
    return hasattr(module, 'pruning_mask')
@@ -31,6 +37,7 @@ def normalize_sum(inp):

def prune_by_magnitude(model: torch.nn.Module, percentage=0.2, transformation=identity):
    weights = None
    device = next(model.parameters()).device
    for layer in model.modules():
        if can_be_pruned(layer):
            reshapen_weight = transformation(torch.abs(layer.weight[torch.logical_not(layer.pruning_mask)])).reshape(-1)
@@ -45,7 +52,7 @@ def prune_by_magnitude(model: torch.nn.Module, percentage=0.2, transformation=id
        if can_be_pruned(layer):
            layer.pruning_mask[torch.logical_not(layer.pruning_mask)] = \
                torch.lt(transformation(torch.abs(layer.weight[torch.logical_not(layer.pruning_mask)])), quantile)
                torch.lt(transformation(torch.abs(layer.weight[torch.logical_not(layer.pruning_mask)])), torch.FloatTensor([quantile]).to(device))


def reset_init_weights(model: torch.nn.Module):
@@ -54,12 +61,6 @@ def reset_init_weights(model: torch.nn.Module):
            layer.weight = torch.nn.Parameter(layer.init_weight * torch.logical_not(layer.pruning_mask))  # reset to init weights, EXCEPT FOR pruned weights


def reinit_weights(model: torch.nn.Module):
    for layer in model.modules():
        if can_be_pruned(layer):
            layer.reset_parameters()


def pruned_percentage(model: torch.nn.Module):
    prunable_weights = 0
    pruned_weights = 0
@@ -84,6 +85,16 @@ def pruned_overall_percentage(model: torch.nn.Module):
    return pruned_weights/all_weights


def save_init_weights(model: torch.nn.Module, prefix="weights"):
    path = prefix + "/initial/"
    os.makedirs(path, exist_ok=True)
    for name, layer in model.named_modules():
        if can_be_pruned(layer):
            weight = layer.weight.detach().cpu().numpy()
            np.save("{}/{}_weight.npy".format(path, name), weight)
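
# Illustrative only: the arrays written above can be loaded back with e.g.
# np.load("weights/initial/conv1_weight.npy"), where "conv1" stands for whatever name
# model.named_modules() yields for the layer (hypothetical name, depends on the model).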


def save_weights_pruning_masks(model: torch.nn.Module, iteration=0, prefix="weights"):
    path = prefix + "/" + str(iteration) + "/"
    os.makedirs(path, exist_ok=True)
......