您的位置  > 互联网

训练网络权重和缩放因子的通道直接移除,目标函数

其中(x,y)代表训练数据和标签,是网络的可训练参数。 第一项是CNN的训练损失函数。 是缩放因子的乘法项,并且是两项的平衡因子。 论文的实验过程选择为正则化,这在稀疏化中也被广泛使用。 次梯度下降法是针对不光滑(不可微)L1惩罚项的优化方法。 另一个建议是使用平滑的 L1 正则项而不是 L1 惩罚项,并尽量避免在不平滑点使用次梯度。

这里的缩放因子是BN层的gamma参数。

train.py 的实现支持稀疏训练。 下面两行代码添加了稀疏训练的稀疏系数。 请注意,它们作用于 BN 层的缩放系数:

parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true',
                        help='train with channel sparsity regularization')
parser.add_argument('--s', type=float, default=0.01, help='scale sparse rate') 


class BNOptimizer():
    @staticmethod
    def updateBN(sr_flag, module_list, s, prune_idx):
        if sr_flag:
            for idx in prune_idx:
                # Squential(Conv, BN, Lrelu)
                bn_module = module_list[idx][1]
                bn_module.weight.grad.data.add_(s * torch.sign(bn_module.weight.data))  # L1

关联