Commit 8f8ab011 authored by Benjamin Vandersmissen

Fixed two bugs in the implementation:

- The devices for the pruning masks were set at the wrong time.
- Sparse initialization had a bug where the fan could be 0, which led to a division by zero.
parent 9fdc006f
@@ -71,6 +71,7 @@ def sparse_initialization_layer(layer: torch.nn.Module, mode, gain=1.0, fan_mode
             if fan_mode == 'both':
                 fan = fan_in[i] + fan_out[j]
+            fan = max(fan, 1)  # Just in case fan = 0: we don't want an error; if fan = 0, the weight is unused anyway.
             std = gain / math.sqrt(float(fan))
             bound = math.sqrt(3.0) * std
             if uniform:
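To make the second bug concrete, here is a minimal, self-contained sketch of the failure mode. The mask values and the per-unit fan computation below are illustrative assumptions, not the repository's exact code: a weight whose row and column are both fully pruned ends up with fan_in + fan_out = 0, and gain / math.sqrt(float(fan)) then raises ZeroDivisionError. The max(fan, 1) clamp avoids this, and the resulting std only affects weights that the mask zeroes out anyway.

import math
import torch

# Hypothetical pruning mask: output unit 1 (middle row) and input unit 1
# (middle column) are fully pruned.
mask = torch.tensor([[1., 0., 1.],
                     [0., 0., 0.],
                     [1., 0., 1.]])
fan_in = mask.sum(dim=1)   # unpruned inputs feeding each output unit
fan_out = mask.sum(dim=0)  # unpruned outputs fed by each input unit

gain = 1.0
for i in range(mask.shape[0]):
    for j in range(mask.shape[1]):
        fan = float(fan_in[i] + fan_out[j])  # 'both' fan mode; 0.0 at (1, 1)
        fan = max(fan, 1.0)                  # the fix: without this clamp, the
        std = gain / math.sqrt(fan)          # next line raises ZeroDivisionError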
@@ -52,8 +52,8 @@ device = args.device if torch.cuda.is_available() else 'cpu'
 in_channels = traindata[0][0].shape[0]
 model = model_dict[args.model](nr_classes=traindata.classes, in_channels=in_channels).to(device)
-initialize(model, args.initialization, False)
 update_pruning_device(model)
+initialize(model, args.initialization, False)
 optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 criterion = nn.CrossEntropyLoss()
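The ordering matters because initialize() consults the pruning masks, so the masks must already sit on the model's device by the time it runs. A hedged sketch of the invariant update_pruning_device has to establish; the pruning_mask attribute name and the loop below are assumptions for illustration, not the repository's implementation:

import torch

def update_pruning_device(model: torch.nn.Module) -> None:
    # Move every pruning-mask buffer onto the device of the weight it masks,
    # so a subsequent initialize() can combine mask and weight directly.
    # 'pruning_mask' is a hypothetical attribute name.
    for module in model.modules():
        mask = getattr(module, 'pruning_mask', None)
        if mask is not None and hasattr(module, 'weight'):
            module.pruning_mask = mask.to(module.weight.device)

With this ordering, initialize(model, args.initialization, False) sees masks and weights on the same device.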