-
Notifications
You must be signed in to change notification settings - Fork 334
/
metapruner.py
624 lines (544 loc) · 35.9 KB
/
metapruner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
import torch
import torch.nn as nn
import typing, warnings
from torch_pruning.pruner.importance import OBDCImportance
from .scheduler import linear_scheduler
from ..import function
from ... import ops, dependency
class MetaPruner:
"""
Meta pruner for structural pruning.
Args:
# Basic
* model (nn.Module): A to-be-pruned model
* example_inputs (torch.Tensor or List): dummy inputs for graph tracing.
* importance (Callable): importance estimator.
* global_pruning (bool): enable global pruning. Default: False.
* pruning_ratio (float): global channel sparisty. Also known as pruning ratio. Default: 0.5.
* pruning_ratio_dict (Dict[nn.Module|Tuple[nn.Module], float]): layer-specific pruning ratio. Default: None. The key of the dict can be a single module or a tuple of modules. The pruning ratio will be shared by all modules in the tuple.
* max_pruning_ratio (float): the maximum pruning ratio. Default: 1.0.
* iterative_steps (int): number of steps for iterative pruning. Default: 1.
* iterative_pruning_ratio_scheduler (Callable): scheduler for iterative pruning. Default: linear_scheduler.
* ignored_layers (List[nn.Module | typing.Type]): ignored modules. Default: None.
* round_to (int): round channels to the nearest multiple of round_to. E.g., round_to=8 means channels will be rounded to 8x. Default: None.
* isomorphic (bool): enable isomorphic pruning. Default: False. https://arxiv.org/abs/2407.04616
# Adavanced
* in_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer input. Default: dict().
* out_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer output. Default: dict().
* num_heads (Dict[nn.Module, int]): The number of heads for multi-head attention. Default: dict().
* prune_num_heads (bool): remove entire heads in multi-head attention. Default: False.
* prune_head_dims (bool): remove head dimensions in multi-head attention. Default: True.
* head_pruning_ratio (float): head pruning ratio. Default: 0.0.
* head_pruning_ratio_dict (Dict[nn.Module, float]): layer-specific head pruning ratio. Default: None.
* customized_pruners (dict): a dict containing module-pruner pairs. Default: None.
* unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None.
* root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM].
* forward_fn (Callable): A function to execute model.forward. Default: None.
* output_transform (Callable): A function to transform network outputs. Default: None.
# Deprecated
* channel_groups (Dict[nn.Module, int]): output channel grouping. Default: dict().
* ch_sparsity (float): the same as pruning_ratio. Default: None.
* ch_sparsity_dict (Dict[nn.Module, float]): the same as pruning_ratio_dict. Default: None.
"""
def __init__(
self,
# Basic
model: nn.Module, # a simple pytorch model
example_inputs: torch.Tensor, # a dummy input for graph tracing. Should be on the same
importance: typing.Callable, # tp.importance.Importance for group importance estimation
global_pruning: bool = False, # https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#global-pruning.
pruning_ratio: float = 0.5, # channel/dim pruning ratio, also known as pruning ratio
pruning_ratio_dict: typing.Dict[typing.Union[nn.Module, typing.Tuple[nn.Module]], float] = None, # layer-specific pruning ratio. Will cover pruning_ratio if specified. The key of the dict can be a single module or a tuple of modules. The pruning ratio will be shared by all modules in the tuple.
max_pruning_ratio: float = 1.0, # maximum pruning ratio. useful if over-pruning happens.
iterative_steps: int = 1, # for iterative pruning
iterative_pruning_ratio_scheduler: typing.Callable = linear_scheduler, # scheduler for iterative pruning.
ignored_layers: typing.List[nn.Module] = None, # ignored layers
round_to: int = None, # round channels to the nearest multiple of round_to
isomorphic: bool = False, # enable isomorphic pruning (ECCV 2024, https://arxiv.org/abs/2407.04616) if global_pruning=True.
# Advanced
in_channel_groups: typing.Dict[nn.Module, int] = dict(), # The number of channel groups for layer input
out_channel_groups: typing.Dict[nn.Module, int] = dict(), # The number of channel groups for layer output
num_heads: typing.Dict[nn.Module, int] = dict(), # The number of heads for multi-head attention
prune_num_heads: bool = False, # remove entire heads in multi-head attention
prune_head_dims: bool = True, # remove head dimensions in multi-head attention
head_pruning_ratio: float = 0.0, # head pruning ratio
head_pruning_ratio_dict: typing.Dict[nn.Module, float] = None, # layer-specific head pruning ratio
customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None, # pruners for customized layers. E.g., {nn.Linear: my_linear_pruner}
unwrapped_parameters: typing.Dict[nn.Parameter, int] = None, # unwrapped nn.Parameters & pruning_dims. For example, {ViT.pos_emb: 0}
root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM], # root module for each group
forward_fn: typing.Callable = None, # a function to execute model.forward
output_transform: typing.Callable = None, # a function to transform network outputs
# deprecated
channel_groups: typing.Dict[nn.Module, int] = dict(), # channel grouping
ch_sparsity: float = None,
ch_sparsity_dict: typing.Dict[nn.Module, float] = None,
):
self.model = model
self.importance = importance
if ch_sparsity is not None:
warnings.warn("ch_sparsity is deprecated in v1.3.0. Please use pruning_ratio.")
pruning_ratio = ch_sparsity
if ch_sparsity_dict is not None:
warnings.warn("ch_sparsity_dict is deprecated in v1.3.0. Please use pruning_ratio_dict instead.")
pruning_ratio_dict = ch_sparsity_dict
self.pruning_ratio = pruning_ratio
self.pruning_ratio_dict = pruning_ratio_dict if pruning_ratio_dict is not None else {}
self.max_pruning_ratio = max_pruning_ratio
self.global_pruning = global_pruning
self.isomorphic = isomorphic
if len(channel_groups) > 0:
warnings.warn("channel_groups is deprecated. Please use in_channel_groups and out_channel_groups instead.")
out_channel_groups.update(channel_groups)
if len(num_heads) > 0:
out_channel_groups.update(num_heads)
self.in_channel_groups = in_channel_groups
self.out_channel_groups = out_channel_groups
self.root_module_types = root_module_types
self.round_to = round_to
# MHA
self.num_heads = num_heads
self.prune_num_heads = prune_num_heads
self.prune_head_dims = prune_head_dims
self.head_pruning_ratio = head_pruning_ratio
###############################################
# Ignored layers and submodules
self.ignored_layers = []
self.ignored_params = []
if ignored_layers is not None:
for layer in ignored_layers:
if isinstance(layer, nn.Module):
self.ignored_layers.extend(list(layer.modules()))
elif isinstance(layer, nn.Parameter):
self.ignored_params.append(layer)
###############################################
# Build dependency graph
self.DG = dependency.DependencyGraph().build_dependency(
model,
example_inputs=example_inputs,
forward_fn=forward_fn,
output_transform=output_transform,
unwrapped_parameters=unwrapped_parameters,
customized_pruners=customized_pruners,
ignored_params=self.ignored_params,
)
###############################################
# Iterative pruning
# The pruner will prune the model iteratively for several steps to achieve the target pruning ratio
# E.g., if iterative_steps=5, pruning_ratio=0.5, the pruning ratio of each step will be [0.1, 0.2, 0.3, 0.4, 0.5]
self.iterative_steps = iterative_steps
self.iterative_pruning_ratio_scheduler = iterative_pruning_ratio_scheduler
self.current_step = 0
# channel pruning ratio for each iterative step
self.per_step_pruning_ratio = self.iterative_pruning_ratio_scheduler(
self.pruning_ratio, self.iterative_steps
)
self.per_step_head_pruning_ratio = self.iterative_pruning_ratio_scheduler(
self.head_pruning_ratio, self.iterative_steps
)
###############################################
# Ranking Scopes
# We will perform ranking within each scope.
# If a scope only contains one layer, then we do local pruning
# If a scope contains multiple layers, then global ranking will be applied to the entire scope
# To manually specify the ranking scope, you can use pass a key-value pair to the pruning_ratio_dict, with a tuple of modules as the key.
self._layer_to_scope = {}
self._scope_initial_channels = {} # initial channels for different scope. It will be filled during the first pruning step.
###############################################
# Layer-specific pruning ratios. Will cover the global ratio if specified
# The key of the dict can be a single module or a tuple of modules. The pruning ratio will be shared by all modules in the tuple.
self.pruning_ratio_dict = {}
user_defined_scope_id = 0
if pruning_ratio_dict is not None:
for modules in pruning_ratio_dict:
ratio = pruning_ratio_dict[modules]
if isinstance(modules, tuple):
scope = modules # will scan all modules sequentially
else:
scope = [modules] # only one model, do local pruning for this module
scope_name = f"_User_Defined_Scope_{user_defined_scope_id}"
local_pruning_scope_postfix = 0
for m in scope:
for submodule in m.modules():
prunable_types = tuple([ops.type2class(
prunable_type) for prunable_type in self.DG.REGISTERED_PRUNERS.keys()])
if isinstance(submodule, prunable_types):
if isinstance(submodule, nn.Module):
if not self.global_pruning:
self._layer_to_scope[submodule] = (scope_name+f"_{local_pruning_scope_postfix}", scope)
local_pruning_scope_postfix+=1 # assign each layer to a unique scope if local pruning
else:
self._layer_to_scope[submodule] = (scope_name, scope) # assign all layers to this scope
self.pruning_ratio_dict[submodule] = self.iterative_pruning_ratio_scheduler(
ratio, self.iterative_steps
)
user_defined_scope_id+=1
# Head pruning ratio
self.head_pruning_ratio_dict = {}
if head_pruning_ratio_dict is not None:
for module in head_pruning_ratio_dict:
ratio = head_pruning_ratio_dict[module]
for submodule in module.modules():
prunable_types = tuple([ops.type2class(
prunable_type) for prunable_type in self.DG.REGISTERED_PRUNERS.keys()])
if isinstance(submodule, prunable_types):
self.head_pruning_ratio_dict[submodule] = self.iterative_pruning_ratio_scheduler(
ratio, self.iterative_steps
)
###############################################
# Detect group convs & group norms
for m in self.model.modules():
layer_pruner = self.DG.get_pruner_of_module(m)
in_ch_group = layer_pruner.get_in_channel_groups(m)
out_ch_group = layer_pruner.get_out_channel_groups(m)
if isinstance(m, ops.TORCH_CONV) and m.groups == m.out_channels:
continue
if in_ch_group > 1:
self.in_channel_groups[m] = in_ch_group
if out_ch_group > 1:
self.out_channel_groups[m] = out_ch_group
###############################################
# Initial channels/dims of each layer
self.layer_init_out_ch = {}
self.layer_init_in_ch = {}
self.init_num_heads = {}
for m in self.DG.module2node.keys():
if ops.module2type(m) in self.DG.REGISTERED_PRUNERS:
self.layer_init_out_ch[m] = self.DG.get_out_channels(m)
self.layer_init_in_ch[m] = self.DG.get_in_channels(m)
if m in self.num_heads:
self.init_num_heads[m] = self.num_heads[m]
###############################################
# Count the number of total channels at initialization
#if self.global_pruning:
initial_total_channels = 0
initial_total_heads = 0
for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
_is_atten, qkv_layers = self._is_atten_group(group)
if _is_atten:
group = self._downstream_node_as_root_if_attention(group)
if group is None: continue
initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) )
for dep, _ in group:
if dep.target.module in self.num_heads and self.DG.is_out_channel_pruning_fn(dep.handler):
initial_total_heads += self.num_heads[dep.target.module]
break # only count heads once
self.initial_total_channels = initial_total_channels
self.initial_total_heads = initial_total_heads
def step(self, interactive=False)-> typing.Union[typing.Generator, None]:
self.current_step += 1
if interactive: # yield groups for interactive pruning
return self._prune()
else:
for group in self._prune():
group.prune()
def manual_prune(self, layer, pruning_fn, pruning_ratios_or_idxs):
if isinstance(pruning_ratios_or_idxs, float):
if self.DG.is_out_channel_pruning_fn(pruning_fn):
prunable_channels = self.DG.get_out_channels(layer)
else:
prunable_channels = self.DG.get_in_channels(layer)
full_group = self.DG.get_pruning_group(layer, pruning_fn, list(range(prunable_channels)))
imp = self.estimate_importance(full_group)
imp_argsort = torch.argsort(imp)
n_pruned = int(prunable_channels * (1 - pruning_ratios_or_idxs))
pruning_idxs = imp_argsort[:n_pruned]
group = self.DG.get_pruning_group(layer, pruning_fn, pruning_idxs)
group.prune()
def estimate_importance(self, group) -> torch.Tensor:
return self.importance(group)
def pruning_history(self) -> typing.List[typing.Tuple[str, bool, typing.Union[list, tuple]]]:
return self.DG.pruning_history()
def load_pruning_history(self, pruning_history) -> None:
self.DG.load_pruning_history(pruning_history)
def get_target_pruning_ratio(self, module, step=-1) -> float:
if step<0: step = self.current_step
s = self.pruning_ratio_dict.get(module, self.per_step_pruning_ratio)[step]
return min(s, self.max_pruning_ratio)
def get_target_head_pruning_ratio(self, module) -> float:
s = self.head_pruning_ratio_dict.get(module, self.per_step_head_pruning_ratio)[self.current_step]
return min(s, 1)
def reset(self) -> None:
self.current_step = 0
def update_regularizer(self) -> None:
pass
def regularize(self, model, loss) -> typing.Any:
""" Model regularizer for sparse training
"""
pass
def _check_pruning_ratio(self, group) -> bool:
for dep, _ in group:
module = dep.target.module
pruning_fn = dep.handler
if dep.target.type == ops.OPTYPE.PARAMETER:
continue
if self.DG.is_out_channel_pruning_fn(pruning_fn):
layer_out_ch = self.DG.get_out_channels(module)
if layer_out_ch is None: continue
if layer_out_ch < self.layer_init_out_ch[module] * (
1 - self.max_pruning_ratio
) or layer_out_ch == 1:
return False
elif self.DG.is_in_channel_pruning_fn(pruning_fn):
layer_in_ch = self.DG.get_in_channels(module)
if layer_in_ch is None: continue
if layer_in_ch < self.layer_init_in_ch[module] * (
1 - self.max_pruning_ratio
) or layer_in_ch == 1:
return False
return True
def _is_atten_group(self, group) -> bool:
is_attn = False
qkv_layers = []
for dep, _ in group:
module = dep.target.module
pruning_fn = dep.handler
if self.DG.is_out_channel_pruning_fn(pruning_fn) and module in self.num_heads:
qkv_layers.append(module)
is_attn = True
return is_attn, qkv_layers
def _get_channel_groups(self, group) -> int:
ch_groups = []
#has_unbind = False
#unbind_node = None
for dep, _ in group:
module = dep.target.module
pruning_fn = dep.handler
channel_groups = self.out_channel_groups if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.in_channel_groups
if module in channel_groups:
ch_groups.append(channel_groups[module])
#if dep.source.type==ops.OPTYPE.UNBIND:
# has_unbind = True
# unbind_node = dep.source
#if has_unbind and ch_groups>1:
# ch_groups = ch_groups // len(unbind_node.outputs)
if len(ch_groups) == 0:
return 1
return max(ch_groups) # no channel grouping
def _downstream_node_as_root_if_attention(self, group):
# Use a downstream node as the root if torch.unbind exists. TODO: find a general way to handle torch.unbind in timm
is_attention = False
downstream_dep = None
for _dep, _idxs in group:
if _dep.source.module in self.num_heads and self.DG.is_out_channel_pruning_fn(_dep.handler):
is_attention = True
if isinstance(_dep.target.module, tuple(self.root_module_types)) and self.DG.is_in_channel_pruning_fn(_dep.handler):
downstream_dep = _dep
idxs = _idxs
if is_attention and downstream_dep is not None: # use a downstream node as the root node for attention layers
group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, idxs)
return group
return None
def _round_to(self, n_pruned, current_channels, round_to):
rounded_channels = current_channels - n_pruned
rounded_channels = rounded_channels - rounded_channels % round_to
n_pruned = current_channels - rounded_channels
return max(n_pruned, 0)
@torch.no_grad()
def _prune(self) -> typing.Generator:
if self.current_step > self.iterative_steps:
warnings.warn("Pruning exceed the maximum iterative steps, no pruning will be performed.")
return
##############################################
# Initialize ranking scopes
# A scope is a set of layers that will be ranked together to determine their relative importance.
# This feature is useful for implementing ranking strategies such as local pruning, global pruning, customized pruning ratios or isomorphic pruning (ECCV 2024): https://arxiv.org/abs/2407.04616
# There are two pre-defined scopes: DEFAULT_SCOPE and ATTN_HEAD_SCOPE
# - DEFAULT_SCOPE: a group will be assigned to this scope for global ranking if not specified
# - ATTN_HEAD_SCOPE: for multi-head attention pruning
##############################################
DEFAULT_SCOPE = "DEFAULT_SCOPE"
ATTN_HEAD_SCOPE = "ATTN_HEAD_SCOPE"
ranking_scope = {DEFAULT_SCOPE: [], ATTN_HEAD_SCOPE: {}} # ATTN_HEAD_SCOPE will be a dict, because we need to index these groups later
##############################################
# 1. Pre-compute importance for each group and assign them to different scopes
##############################################
for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
if self._check_pruning_ratio(group):
# Re-order the group and use a downstream node as the root node for attention layers.
# This will not change the group structure, but make index mapping easier for attention layers.
_is_atten, qkv_layers = self._is_atten_group(group)
if _is_atten:
group = self._downstream_node_as_root_if_attention(group)
if group is None: continue
ch_groups = self._get_channel_groups(group)
imp = self.estimate_importance(group) # raw importance score
group_size = len(imp) // ch_groups
if imp is None: continue
if ch_groups > 1: # layers with dimension grouping, such as GroupConv, GroupNorm, Multi-head attention, etc.
# We average importance across groups here. For example:
# imp = [1, 2, 3, 4, 5, 6] with ch_groups=2.
# We have two groups [1,2,3] and [4,5,6].
# The average importance should be [(1+4)/2, (2+5)/2, (3+6)/2] = [2.5, 3.5, 4.5]
dim_imp = imp.view(ch_groups, -1).mean(dim=0).cpu()
else:
# no grouping
dim_imp = imp.cpu()
# Importance scores for Attention Heads
_is_atten, qkv_layers = self._is_atten_group(group)
if _is_atten and self.prune_num_heads and self.get_target_head_pruning_ratio(qkv_layers[0])>0:
# average importance over heads
# Example: if we have the importance score:
# imp = [1, 2, 3, 4, 5, 6] with num_heads=2
# Note: head1 = [1, 2, 3], head2 = [4, 5, 6]
# the average importance is [(1+2+3)/3, (4+5+6)/3] = [2, 5]
# GQA: the number of heads for KV might be different from Q
num_heads = max([self.num_heads[qkv_layer] for qkv_layer in qkv_layers]) # get the maximum number of heads
head_imp = imp.view(num_heads, -1).mean(1).cpu() # average importance by head.
ranking_scope[ATTN_HEAD_SCOPE][group] = (qkv_layers, head_imp)
# Scope 1: User-defined scope, such as layer-wise pruning_ratios
is_user_defined_scope = False
for dep, _ in group:
for module, pruning_fn in zip([dep.source.module, dep.target.module], [dep.trigger, dep.handler]):
if module in self._layer_to_scope and self.DG.is_out_channel_pruning_fn(pruning_fn):
scope_name, scope = self._layer_to_scope[module]
if len(scope)>0:
pruning_ratio = self.get_target_pruning_ratio(module, step=self.current_step)
record = (group, ch_groups, group_size, pruning_ratio, dim_imp)
if scope_name not in ranking_scope:
ranking_scope[scope_name] = []
ranking_scope[scope_name].append(record)
is_user_defined_scope = True
# A bit messy here. Will refactor in the future.
if is_user_defined_scope: break
if is_user_defined_scope: break
if is_user_defined_scope:
continue
record = (group, ch_groups, group_size, self.per_step_pruning_ratio[self.current_step], dim_imp) # otherwise, use the default pruning ratio
# Scope 2: Isomorphic Pruning
if self.isomorphic:
scope_name = "Isomorphic_" # we transform the graph structure into a string tag for easy comparison
for dep, _ in group: # if isomorphic, the source and target modules should have the same **layer type** and **pruning function**
source = "%s_%s"%(type(dep.source.module), "out" if self.DG.is_out_channel_pruning_fn(dep.handler) else "in")
target = "%s_%s"%(type(dep.target.module), "out" if self.DG.is_out_channel_pruning_fn(dep.handler) else "in")
scope_name += "%s_%s"%(source, target)
if scope_name not in ranking_scope:
# New isomorphic group
ranking_scope[scope_name] = []
ranking_scope[scope_name].append(record)
elif self.global_pruning: # Scope 3: use the default scope for global pruning
ranking_scope[DEFAULT_SCOPE].append(record)
else: # Scope 4: always create a new scope if local pruning
module_name = self.DG._module2name[group[0][0].source.module]
ranking_scope[module_name] = [ record ]
if len(ranking_scope[DEFAULT_SCOPE]) == 0 and len(ranking_scope[ATTN_HEAD_SCOPE])==0 and len(ranking_scope)<=2:
return
##############################################
# 2. Thresholding by ranking all importance scores within each scope
##############################################
# Find the threshold for the Multi-head attention scope if global pruning is enabled
if len(ranking_scope[ATTN_HEAD_SCOPE])>0 and self.global_pruning:
concat_head_imp = torch.cat([local_imp[-1] for local_imp in ranking_scope[ATTN_HEAD_SCOPE].values()], dim=0)
target_head_pruning_ratio = self.per_step_head_pruning_ratio[self.current_step]
n_heads_removed = len(concat_head_imp) - int(
self.initial_total_heads *
(1 - target_head_pruning_ratio)
)
if n_heads_removed>0:
topk_head_imp, _ = torch.topk(concat_head_imp, k=n_heads_removed, largest=False)
head_thres = topk_head_imp[-1]
# Width pruning
width_pruning_scope_names = [ k for k in ranking_scope.keys() if k!=ATTN_HEAD_SCOPE]
for scope_id, scope_name in enumerate(width_pruning_scope_names):
if not self.global_pruning:
assert len(ranking_scope[scope_name])<=1, "Internal Error: local pruning should only contain less than one layer per scope."
records = ranking_scope[scope_name] # records[i] -> (group, ch_groups, group_size, pruning_ratio, dim_imp)_i
# Find the threshold for pruning
if len(records)>0:
concat_imp = torch.cat([local_imp[-1] for local_imp in records], dim=0) # concatenate importance scores in this scope
target_pruning_ratio = records[0][-2] # records[i] -> (group, ch_groups, group_size, pruning_ratio, dim_imp)_i
if scope_name not in self._scope_initial_channels:
self._scope_initial_channels[scope_name] = len(concat_imp)
n_pruned = len(concat_imp) - int(
self._scope_initial_channels[scope_name] *
(1 - target_pruning_ratio)
)
if n_pruned>0:
topk_imp, topk_indices = torch.topk(concat_imp, k=n_pruned, largest=False)
thres = topk_imp[-1]
##############################################
# 3. Pruning in each scope
##############################################
for group, ch_groups, group_size, target_pruning_ratio, imp in records:
module = group[0].dep.target.module
pruning_fn = group[0].dep.handler
get_channel_fn = self.DG.get_out_channels if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.DG.get_in_channels
_is_atten, qkv_layers = self._is_atten_group(group)
# Prune dims/channels
pruning_indices = []
if not _is_atten or self.prune_head_dims:
if self.global_pruning:
_pruning_indices = (imp <= thres).nonzero().view(-1)
else:
_pruning_indices = topk_indices
imp_argsort = torch.argsort(imp)
if len(_pruning_indices)>0 and self.round_to: # recompute the number of pruned channels if round_to is enabled
n_pruned = len(_pruning_indices)
current_channels = get_channel_fn(module)
n_pruned = self._round_to(n_pruned, current_channels, self.round_to)
_pruning_indices = imp_argsort[:n_pruned]
if ch_groups>1: # if channel grouping is enabled, we repeat the pruning indices for each channel group
for g_id in range(ch_groups):
pruning_indices.append(_pruning_indices+g_id*group_size)
else:
pruning_indices.append(_pruning_indices)
# Prune Attention Heads
if len(ranking_scope[ATTN_HEAD_SCOPE])>0:
if group in ranking_scope[ATTN_HEAD_SCOPE]:
qkv_layers, head_imp = ranking_scope[ATTN_HEAD_SCOPE][group]
num_heads = max([self.num_heads[qkv_layer] for qkv_layer in qkv_layers])
_is_gqa = not all([self.num_heads[qkv_layer]==num_heads for qkv_layer in qkv_layers])
if not self.global_pruning: # local pruning
n_heads_removed_per_group = int(self.get_target_head_pruning_ratio(qkv_layers[0]) * len(head_imp))
if not _is_gqa:
head_pruning_indices = torch.topk(head_imp, k=n_heads_removed_per_group, largest=False)[1] # local ranking
else: # chunk the head imp
num_kv_heads = min([self.num_heads[qkv_layer] for qkv_layer in qkv_layers])
num_heads = max([self.num_heads[qkv_layer] for qkv_layer in qkv_layers])
n_heads_removed_per_group = n_heads_removed_per_group // num_kv_heads
head_pruning_indices = []
for kv_head_id in range(num_kv_heads):
head_imp_kv = head_imp[kv_head_id * num_heads//num_kv_heads: (kv_head_id+1) * num_heads//num_kv_heads]
head_pruning_indices_kv = torch.topk(head_imp_kv, k=n_heads_removed_per_group, largest=False)[1]
head_pruning_indices.append(head_pruning_indices_kv + kv_head_id*num_heads//num_kv_heads)
head_pruning_indices = torch.cat(head_pruning_indices, 0)
else: # global pruning
head_pruning_indices = (head_imp <= head_thres).nonzero().view(-1) # global ranking
if _is_gqa:
num_kv_heads = min([self.num_heads[qkv_layer] for qkv_layer in qkv_layers])
n_heads_removed_per_group = len(head_pruning_indices) // num_kv_heads
head_pruning_indices = []
for kv_head_id in range(num_kv_heads):
head_imp_kv = head_imp[kv_head_id * len(head_imp)//num_kv_heads: (kv_head_id+1) * len(head_imp)//num_kv_heads]
head_pruning_indices_kv = torch.topk(head_imp_kv, k=n_heads_removed_per_group, largest=False)[1]
head_pruning_indices.append(head_pruning_indices_kv + kv_head_id*num_kv_heads)
head_pruning_indices = torch.cat(head_pruning_indices, 0)
if len(head_pruning_indices)>0:
if len(qkv_layers)==1:
head_dim = qkv_layers[0].out_features // (self.num_heads[qkv_layers[0]]*3)
else:
head_dim = qkv_layers[0].out_features // self.num_heads[qkv_layers[0]]
for head_id in head_pruning_indices:
pruning_indices.append( torch.arange(head_id*head_dim, (head_id+1)*head_dim, device=head_imp.device) )
num_heads = max([self.num_heads[qkv_layer] for qkv_layer in qkv_layers])
for qkv_layer in qkv_layers:
if self.num_heads[qkv_layer] == num_heads:
self.num_heads[qkv_layer] -= len(head_pruning_indices) # update num heads after pruning
self.out_channel_groups[qkv_layer] = self.num_heads[qkv_layer] # update out_channel_groups
if len(pruning_indices)==0: continue
pruning_indices = torch.unique(torch.cat(pruning_indices, 0)).tolist()
if isinstance(self.importance, OBDCImportance):
self.importance.adjust_fisher(group, pruning_indices)
# create pruning group
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_indices)
if _is_atten:
_is_gqa = not all([self.num_heads[qkv_layer]==self.num_heads[qkv_layers[0]] for qkv_layer in qkv_layers])
if _is_gqa and self.prune_num_heads:
num_kv_heads = min([self.num_heads[qkv_layer] for qkv_layer in qkv_layers])
kv_layers = [qkv_layer for qkv_layer in qkv_layers if self.num_heads[qkv_layer]==num_kv_heads]
for i in range(len(group)):
dep, idxs = group[i]
if dep.target.module in kv_layers:
group[i] = (dep, []) # disable pruning for the kv layers if GQA is enabled
if self.DG.check_pruning_group(group):
yield group # yield the group for interactive pruning