Commit 009fd410 authored by Benjamin Vandersmissen's avatar Benjamin Vandersmissen
Browse files

Added a function to load weights and pruning masks from the numpy saved files....

Added a function to load weights and pruning masks from the numpy saved files. This is in case that we want to do further pruning.

Small fix for VGG, to prevent pruning the output layer.
parent f7ca0edf
......@@ -3,6 +3,7 @@ import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pruning import can_be_pruned
......@@ -28,3 +29,13 @@ class CustomModel(nn.Module):
if can_be_pruned(layer):
layer.init_weight = copy.deepcopy(layer.weight)
self.has_init_weights = True
def load_weights_masks(self, path, iteration):
path = f"{path}/{iteration}/"
for name, layer in self.named_modules():
if can_be_pruned(layer):
weight = np.load(f"{path}/{name}_weight.npy")
mask = np.load(f"{path}/{name}_mask.npy")
layer.weight = torch.nn.Parameter(torch.from_numpy(weight).to(layer.weight.device))
layer.pruning_mask = torch.nn.Parameter(torch.from_numpy(mask).to(layer.pruning_mask.device))
\ No newline at end of file
......@@ -25,7 +25,7 @@ class VGG11(CustomModel):
self.drop1 = nn.Dropout(0.5)
self.fc2 = add_pruning_info(nn.Linear(4096, 4096))
self.drop2 = nn.Dropout(0.5)
self.fc3 = add_pruning_info(nn.Linear(4096, self.nr_classes))
self.fc3 = nn.Linear(4096, self.nr_classes)
def forward(self, x):
x = self.conv1(x).relu()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment