Commit 3c13034b authored by Benjamin Vandersmissen's avatar Benjamin Vandersmissen

Added code to select the weights used for reinitialization, rather than using the weights from the initialization.

The weights are selected with the parameter -o / --original-epoch. Passing the value 0 still uses the original initialization weights.

The first pruning iteration trains for trainepochs epochs, while all subsequent iterations train for trainepochs - original-epoch epochs.
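In pseudocode, the intended flow is roughly the following (a sketch for illustration; train and reset_to_init_weights are placeholder names for this repo's training loop and rewind step, not part of this commit):

# late resetting with --original-epoch = k
for i in range(pruning_iterations):
    epochs = trainepochs if i == 0 else trainepochs - k
    train(model, epochs)               # iteration 0 snapshots the weights at the start of epoch k
    prune_by_magnitude(model)          # remove the smallest-magnitude surviving weights
    reset_to_init_weights(model)       # rewind surviving weights to the epoch-k snapshot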
parent 62a0bade
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -10,6 +12,7 @@ class CustomModel(nn.Module):
        super(CustomModel, self).__init__()
        self.nr_classes = nr_classes
        self.in_channels = in_channels
        self.has_init_weights = False

    def prob(self, x):
        return F.softmax(self.forward(x), dim=1)
@@ -19,3 +22,9 @@ class CustomModel(nn.Module):
            if can_be_pruned(layer):
                with torch.no_grad():
                    layer.weight[layer.pruning_mask] = 0

    def save_init_weights(self):
        for layer in self.modules():
            if can_be_pruned(layer):
                layer.init_weight = copy.deepcopy(layer.weight)
        self.has_init_weights = True
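The counterpart that copies the snapshot back into the live weights is not visible in this diff; a minimal sketch of what it could look like (the name reset_to_init_weights and the mask re-application are assumptions):

    def reset_to_init_weights(self):
        # assumed counterpart: restore the snapshot and keep pruned weights at zero
        assert self.has_init_weights
        for layer in self.modules():
            if can_be_pruned(layer):
                with torch.no_grad():
                    layer.weight.copy_(layer.init_weight)
                    layer.weight[layer.pruning_mask] = 0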
@@ -19,6 +19,7 @@ parser.add_argument('-t', '--trainepochs', default=10, type=int, help="How many
parser.add_argument('-r', '--random', default=42, type=int, help="The random seed used")
parser.add_argument('--device', default='cuda:0', type=str, help="The device to run on")
parser.add_argument('-m', '--model', default='lenet', type=str, choices=['lenet', 'vgg'], help="The model to prune")
parser.add_argument('-o', '--original-epoch', default=0, type=int, help='The epoch to reset the weights to')
args = parser.parse_args()
model_dict = {
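For example, a run that rewinds to the weights from the start of epoch 2 could be launched like this (the script name is an assumption; it is not visible in this diff):

python main.py --model lenet --trainepochs 10 --original-epoch 2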
@@ -93,6 +94,8 @@ def training_loop(model, criterion, optimizer, train_loader, valid_loader, epoch
    # Train model
    for epoch in range(0, epochs):
        if not model.has_init_weights and args.original_epoch == epoch:
            model.save_init_weights()
        print("Training: epoch {} / {}".format(epoch+1, epochs))
        # training
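Note that the check sits at the top of the epoch loop, so the snapshot is taken before epoch original_epoch is trained; with the default value 0 it is taken before any training step, i.e. it is exactly the original initialization.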
@@ -113,7 +116,8 @@
for i in range(args.pruningepochs):
    trainloader = DataLoader(traindata, batch_size, shuffle=True)
    testloader = DataLoader(testdata, batch_size, shuffle=False)
    model, _, losses, accuracies = training_loop(model, criterion, optimizer, trainloader, testloader, args.trainepochs, device)
    epochs = args.trainepochs if i == 0 else args.trainepochs - args.original_epoch  # with late resetting, later iterations resume from epoch original_epoch, so they train fewer epochs
    model, _, losses, accuracies = training_loop(model, criterion, optimizer, trainloader, testloader, epochs, device)
    print("Losses: {}".format(losses))
    print("Accuracies: {}".format(accuracies))
    print("Pruning % : {}".format(pruned_percentage(model)))
@@ -7,7 +7,6 @@ import numpy as np
def add_pruning_info(module):
    assert hasattr(module, 'weight')
    module.pruning_mask = torch.zeros_like(module.weight, dtype=torch.bool)
    module.init_weight = copy.deepcopy(module.weight)
    return module
@@ -16,7 +15,6 @@ def can_be_pruned(module):
def prune_by_magnitude(model: torch.nn.Module, percentage=0.2):
    # TODO: prune on magnitude (absolute), calculate quantile without pruned values
    weights = None
    for layer in model.modules():
        if can_be_pruned(layer):
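The excerpt cuts off before the pruning logic itself. One way to complete the TODO, taking the magnitude quantile over the still-unpruned weights only, might look like this (a sketch, not the commit's actual code; apply_pruning_mask is an assumed name for the mask-zeroing method shown above):

def prune_by_magnitude(model: torch.nn.Module, percentage=0.2):
    # gather the magnitudes of all weights that survived previous rounds
    weights = []
    for layer in model.modules():
        if can_be_pruned(layer):
            weights.append(layer.weight[~layer.pruning_mask].abs().flatten())
    # threshold below which the smallest `percentage` of surviving weights fall
    threshold = torch.quantile(torch.cat(weights), percentage)
    # extend each layer's mask to cover the newly pruned weights
    for layer in model.modules():
        if can_be_pruned(layer):
            layer.pruning_mask |= layer.weight.abs() < threshold
    model.apply_pruning_mask()  # assumed helper that zeroes masked weights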