Commit e395d75b authored by Benjamin Vandersmissen's avatar Benjamin Vandersmissen
Browse files

Added a function to reinitialize weights after pruning, rather than rewinding.

parent 52609ff8
......@@ -42,6 +42,11 @@ def reset_init_weights(model: torch.nn.Module):
layer.weight = torch.nn.Parameter(layer.init_weight * torch.logical_not(layer.pruning_mask)) # reset to init weights, EXCEPT FOR pruned weights.
def reinit_weights(model: torch.nn.Module):
for layer in model.modules():
if can_be_pruned(layer):
layer.reset_parameters()
def pruned_percentage(model: torch.nn.Module):
prunable_weights = 0
pruned_weights = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment