Commit 2b5536d2 authored by Benjamin Vandersmissen
Browse files

Added a new model to train and prune: VGG11

parent b173b74a
import torch
import torch.nn as nn
import torch.nn.functional as F
from pruning import can_be_pruned
class CustomModel(nn.Module):
    """Shared base for the project's models.

    Adds two conveniences on top of nn.Module:
      * prob      -- class probabilities via softmax over the forward logits.
      * keep_pruned_weights -- re-zeroes every weight a pruning mask marks
        as pruned (masks are attached by the `pruning` helpers).
    """

    def prob(self, x):
        """Return softmax probabilities (over dim 1) of the forward logits."""
        logits = self.forward(x)
        return F.softmax(logits, dim=1)

    def keep_pruned_weights(self):
        """Force pruned weights back to zero on every prunable sub-module.

        Done under no_grad so the zeroing is not recorded by autograd.
        """
        prunable = (m for m in self.modules() if can_be_pruned(m))
        with torch.no_grad():
            for layer in prunable:
                layer.weight[layer.pruning_mask] = 0
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from pruning import add_pruning_info, can_be_pruned
from pruning import add_pruning_info
from customModel import CustomModel
class LeNet5(nn.Module):
class LeNet5(CustomModel):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = add_pruning_info(nn.Conv2d(1, 6, 5, padding=2))
......@@ -26,13 +25,4 @@ class LeNet5(nn.Module):
x = torch.flatten(x, start_dim=1)
x = self.fc1(x).tanh()
x = self.fc2(x)
return x
def prob(self, x):
return F.softmax(self.forward(x), dim=1)
def keep_pruned_weights(self):
for layer in self.modules():
if can_be_pruned(layer):
with torch.no_grad():
layer.weight[layer.pruning_mask] = 0
return x
\ No newline at end of file
from HDF5Dataset import *
from lenet import LeNet5
from vgg import VGG11
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm
from pruning import *
import argparse
from datetime import datetime
from torchvision import transforms
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='fashion', type=str, help="The dataset to use")
......@@ -16,11 +18,24 @@ parser.add_argument('-b', '--batch', default=32, type=int, help="Batch size")
parser.add_argument('-t', '--trainepochs', default=10, type=int, help="How many epochs do we train")
parser.add_argument('-r', '--random', default=42, type=int, help="The random seed used")
parser.add_argument('--device', default='cuda:0', type=str, help="The device to run on")
parser.add_argument('-m', '--model', default='lenet', type=str, choices=['lenet', 'vgg'], help="The model to prune")
args = parser.parse_args()
model_dict = {
"lenet": LeNet5,
"vgg": VGG11
}
traindata = HDF5Dataset("hdf5/{}/train.hdf5".format(args.data))
testdata = HDF5Dataset("hdf5/{}/test.hdf5".format(args.data))
if args.model == 'vgg':
assert args.data == 'flowers'
train_trans = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()])
test_trans = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
else:
train_trans = None
test_trans = None
traindata = HDF5Dataset("hdf5/{}/train.hdf5".format(args.data), transform=train_trans)
testdata = HDF5Dataset("hdf5/{}/test.hdf5".format(args.data), transform=test_trans)
torch.manual_seed(args.random)
batch_size = args.batch
......@@ -28,7 +43,7 @@ lr = args.lr
device = args.device if torch.cuda.is_available() else 'cpu'
model = LeNet5().to(device)
model = model_dict[args.model]().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
......@@ -108,5 +123,5 @@ for i in range(args.pruningepochs):
prune_by_magnitude(model)
reset_init_weights(model)
save_weights_pruning_masks(model, 10, basedir)
save_weights_pruning_masks(model, args.pruningepochs, basedir)
import torch
import torch.nn as nn
from pruning import add_pruning_info
from customModel import CustomModel
class VGG11(CustomModel):
    """VGG-11 ("configuration A") classifier with pruning bookkeeping.

    Every Conv2d and Linear layer is wrapped by add_pruning_info so the
    pruning utilities can attach a mask to it; pooling and dropout layers
    are registered as-is.  The 512 * 7 * 7 flatten size hard-codes a
    3 x 224 x 224 input resolution.  Defaults to 17 output classes.
    """

    def __init__(self, nr_classes=17):
        super(VGG11, self).__init__()
        # (attribute name, module, prunable) in registration order.  The
        # order matches the original hand-written layout exactly, so
        # Module.modules() iteration (used by the pruning helpers) is
        # unchanged.
        plan = [
            ('conv1', nn.Conv2d(3, 64, kernel_size=3, padding=1), True),
            ('pool1', nn.MaxPool2d(2), False),
            ('conv2', nn.Conv2d(64, 128, kernel_size=3, padding=1), True),
            ('pool2', nn.MaxPool2d(2), False),
            ('conv3', nn.Conv2d(128, 256, kernel_size=3, padding=1), True),
            ('conv4', nn.Conv2d(256, 256, kernel_size=3, padding=1), True),
            ('pool3', nn.MaxPool2d(2), False),
            ('conv5', nn.Conv2d(256, 512, kernel_size=3, padding=1), True),
            ('conv6', nn.Conv2d(512, 512, kernel_size=3, padding=1), True),
            ('pool4', nn.MaxPool2d(2), False),
            ('conv7', nn.Conv2d(512, 512, kernel_size=3, padding=1), True),
            ('conv8', nn.Conv2d(512, 512, kernel_size=3, padding=1), True),
            ('pool5', nn.MaxPool2d(2), False),
            ('fc1', nn.Linear(512 * 7 * 7, 4096), True),
            ('drop1', nn.Dropout(0.5), False),
            ('fc2', nn.Linear(4096, 4096), True),
            ('drop2', nn.Dropout(0.5), False),
            ('fc3', nn.Linear(4096, nr_classes), True),
        ]
        for name, module, prunable in plan:
            setattr(self, name, add_pruning_info(module) if prunable else module)

    def forward(self, x):
        """Run the VGG-11 pipeline.

        Feature stages apply ReLU after every conv and max-pool between
        stages; the head flattens, then applies fc1/fc2 (each with ReLU
        and dropout) and the final fc3 logits layer.
        """
        feature_layers = ('conv1', 'pool1', 'conv2', 'pool2',
                          'conv3', 'conv4', 'pool3',
                          'conv5', 'conv6', 'pool4',
                          'conv7', 'conv8', 'pool5')
        for name in feature_layers:
            x = getattr(self, name)(x)
            if name.startswith('conv'):
                x = x.relu()
        x = torch.flatten(x, start_dim=1)
        x = self.drop1(self.fc1(x).relu())
        x = self.drop2(self.fc2(x).relu())
        return self.fc3(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment