Commit 62a0bade authored by Benjamin Vandersmissen's avatar Benjamin Vandersmissen
Browse files

Added code to make models independent of the initial number of channels in an...

Added code to make models independent of the initial number of channels in an image and automatically determine the number of channels needed.
parent 948674b7
......@@ -6,9 +6,10 @@ from pruning import can_be_pruned
class CustomModel(nn.Module):
def __init__(self, nr_classes=10):
def __init__(self, nr_classes=10, in_channels=3):
super(CustomModel, self).__init__()
self.nr_classes = nr_classes
self.in_channels = in_channels
def prob(self, x):
return F.softmax(self.forward(x), dim=1)
......
......@@ -8,7 +8,7 @@ from customModel import CustomModel
class LeNet5(CustomModel):
def __init__(self, **kwargs):
super(LeNet5, self).__init__(**kwargs)
self.conv1 = add_pruning_info(nn.Conv2d(1, 6, 5, padding=2))
self.conv1 = add_pruning_info(nn.Conv2d(self.in_channels, 6, 5, padding=2))
self.pool1 = nn.AvgPool2d(2)
self.conv2 = add_pruning_info(nn.Conv2d(6, 16, 5))
self.pool2 = nn.AvgPool2d(2)
......
......@@ -27,7 +27,6 @@ model_dict = {
}
if args.model == 'vgg':
assert args.data == 'flowers'
train_trans = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()])
test_trans = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
else:
......@@ -43,7 +42,8 @@ lr = args.lr
device = args.device if torch.cuda.is_available() else 'cpu'
model = model_dict[args.model](nr_classes=traindata.classes).to(device)
in_channels = traindata[0][0].shape[0]
model = model_dict[args.model](nr_classes=traindata.classes, in_channels=in_channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
......
......@@ -8,7 +8,7 @@ from customModel import CustomModel
class VGG11(CustomModel):
def __init__(self, **kwargs):
super(VGG11, self).__init__(**kwargs)
self.conv1 = add_pruning_info(nn.Conv2d(3, 64, kernel_size=3, padding=1))
self.conv1 = add_pruning_info(nn.Conv2d(self.in_channels, 64, kernel_size=3, padding=1))
self.pool1 = nn.MaxPool2d(2)
self.conv2 = add_pruning_info(nn.Conv2d(64, 128, kernel_size=3, padding=1))
self.pool2 = nn.MaxPool2d(2)
......@@ -20,7 +20,7 @@ class VGG11(CustomModel):
self.pool4 = nn.MaxPool2d(2)
self.conv7 = add_pruning_info(nn.Conv2d(512, 512, kernel_size=3, padding=1))
self.conv8 = add_pruning_info(nn.Conv2d(512, 512, kernel_size=3, padding=1))
self.pool5 = nn.MaxPool2d(2)
self.pool5 = nn.AdaptiveMaxPool2d((7, 7))
self.fc1 = add_pruning_info(nn.Linear(512*7*7, 4096))
self.drop1 = nn.Dropout(0.5)
self.fc2 = add_pruning_info(nn.Linear(4096, 4096))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment