Revert "[Feature] Add GPTQ and uniform interfaces (#538)"
This reverts commit dcf7bfa.
humu789 committed May 25, 2023
1 parent 56c49e6 commit fb113d2
Showing 32 changed files with 139 additions and 2,418 deletions.
4 changes: 2 additions & 2 deletions mmrazor/implementations/pruning/sparse_gpt/__init__.py
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .compressor import SparseGptCompressor
from .mutator import SparseGptMutator
from .ops import SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops

__all__ = [
'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops',
'SparseGptCompressor'
'SparseGptMutator'
]
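For downstream users, the visible effect of this first hunk is the package's public entry point: SparseGptCompressor is dropped from the exports and the older SparseGptMutator name is restored. A minimal import sketch (the package path comes from the file header above; everything else is illustrative):

```python
# Interface removed by this revert:
# from mmrazor.implementations.pruning.sparse_gpt import SparseGptCompressor
# compressor = SparseGptCompressor()

# Interface restored by this revert:
from mmrazor.implementations.pruning.sparse_gpt import SparseGptMutator

mutator = SparseGptMutator()
```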
@@ -8,25 +8,24 @@


def to_static_model(model: nn.Module):
"""Replace dynamicops with torch modules."""
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
fix_subnet = export_fix_subnet(model)[0]
load_fix_subnet(model, fix_subnet)
return model


class SparseGptCompressor():
"""The compressor with SparseGPT."""
class SparseGptMutator():

# init

def __init__(self) -> None:
self.model: nn.Module = None

def prepare(self,
model: nn.Module,
prune_conv=True,
prune_linear=True) -> None:
"""Prepare for compressing model."""
def prepare_from_supernet(self,
model: nn.Module,
prune_conv=True,
prune_linear=True) -> None:
self.model = model
prune_modules: dict = {}
if prune_conv:
@@ -37,23 +36,19 @@ def prepare(self,

@classmethod
def to_static_model(cls, model):
"""Convert replaced op with the original torch model."""
return to_static_model(model)

# hessian

def register_hessian_hooks(self):
"""Register updating hessian hooks for specified ops."""
def start_init_hessian(self):
for module in self.sparse_ops:
module.register_hessian_hook()
module.start_init_hessian()

def remove_hessian_hooks(self):
"""Remove updating hessian hooks for specified ops."""
def end_init_hessian(self):
for module in self.sparse_ops:
module.remove_hessian_hook()
module.end_init_hessian()

def init_hessian(self, device=None):
"""Init hessian."""
for op in self.sparse_ops:
op.init_hessian(device=device)

@@ -65,7 +60,6 @@ def prune(self,
blocksize=128,
percdamp=.01,
device=torch.device('cuda')):
"""Apply the compression algorithm to the model."""
for name, module in self.named_sparse_ops:
try:
original_device = next(module.parameters()).device
@@ -84,23 +78,19 @@
print_log(f'prune {name} failed as {e}')

def prune_24(self, device=torch.device('cuda:0')):
"""Apply the compression algorithm to the model with the specified
setting."""
self.prune(0.5, prunen=2, prunem=4, device=device)

# ops

@property
def sparse_ops(self):
"""The ops to be applied the algorithm."""
assert self.model is not None
for module in self.model.modules():
if isinstance(module, SparseGptMixIn):
yield module

@property
def named_sparse_ops(self):
"""The named ops to be applied the algorithm."""
for name, module in self.model.named_modules():
if isinstance(module, SparseGptMixIn):
yield name, module
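Taken together, the restored methods above imply a calibrate-then-prune workflow. The sketch below is assembled only from the names visible in this diff (prepare_from_supernet, init_hessian, start_init_hessian, end_init_hessian, prune_24, to_static_model); the toy model and calibration batches are placeholders, not part of the commit, and a CUDA device is assumed.

```python
import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt import SparseGptMutator

# Placeholder model and calibration data, for illustration only.
model = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512)).cuda()
calib_batches = [torch.randn(8, 512).cuda() for _ in range(16)]

mutator = SparseGptMutator()
mutator.prepare_from_supernet(model, prune_conv=True, prune_linear=True)

mutator.init_hessian(device='cuda')   # allocate per-op Hessian buffers
mutator.start_init_hessian()          # register forward pre-hooks that accumulate Hessians
with torch.no_grad():
    for batch in calib_batches:       # plain forward passes collect the statistics
        model(batch)
mutator.end_init_hessian()            # remove the hooks

mutator.prune_24(device=torch.device('cuda:0'))   # 2:4 semi-structured sparsity
model = SparseGptMutator.to_static_model(model)   # fold dynamic ops back into torch modules
```

Note that SparseGptLinear.convert_from (later in this diff) returns the original module untouched when out_features < in_features, so not every linear layer is necessarily replaced during prepare_from_supernet.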
30 changes: 6 additions & 24 deletions mmrazor/implementations/pruning/sparse_gpt/ops.py
@@ -1,10 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys

if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
from typing import Protocol

import torch
import torch.distributed as dist
@@ -17,10 +12,10 @@


class SparseGptMixIn(ModuleProtocol):
"""The core algorithm implementation for SparseGpt."""

# init

def _sparse_gpt_mix_in_init(self):
"""Init mixin."""
self.sparse_gpt_handles = []
self.rows = self.weight_matrix.shape[0]
self.columns = self.weight_matrix.shape[1]
@@ -37,7 +32,6 @@ def weight_matrix(self):

@weight_matrix.setter
def weight_matrix(self, value: torch.Tensor):
"""Set weight."""
with torch.no_grad():
value = value.reshape(self.weight.shape).to(self.weight.device).to(
self.weight.dtype)
@@ -70,7 +64,6 @@ def hessian(self):

@hessian.setter
def hessian(self, value: torch.Tensor):
"""Set hessian."""
with torch.no_grad():
if dist.is_initialized():
if dist.get_rank() == 0:
@@ -84,7 +77,6 @@ def hessian(self, value: torch.Tensor):

@torch.no_grad()
def update_hessian(self, input: torch.Tensor):
"""Update hessian."""
input = self.format_input(input).float()
H_save = self.hessian
H_save = H_save.to(input.device)
@@ -102,8 +94,7 @@ def update_hessian(self, input: torch.Tensor):
self.hessian = H_save
self.hessian_batch = self.hessian_batch + B

def register_hessian_hook(self):
"""Register updating hessian hook."""
def start_init_hessian(self):

@torch.no_grad()
def forward_pre_hook(module: Protocol, input: tuple):
@@ -113,13 +104,11 @@ def forward_pre_hook(module: Protocol, input: tuple):
handle = self.register_forward_pre_hook(forward_pre_hook)
self.sparse_gpt_handles.append(handle)

def remove_hessian_hook(self):
"""Remove updating hessian hook."""
def end_init_hessian(self):
for h in self.sparse_gpt_handles:
h.remove()

def init_hessian(self, device=None):
"""Init hessian."""
if dist.is_initialized():
if dist.get_rank() == 0:
self._hessian = torch.zeros([self.columns, self.columns],
@@ -136,7 +125,6 @@ def init_hessian(self, device=None):

@torch.no_grad()
def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):
"""The implementation for SparseGPT."""
with torch_setting(dtype=torch.float):
# Converted from https://github.com/ist-daslab/sparsegpt

@@ -211,8 +199,7 @@ def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):

W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

if W.device.type == 'cuda':
torch.cuda.synchronize()
torch.cuda.synchronize()
from .sparse24_utils import is_weight_sparse_24
if prunen == 2 and prunem == 4:
assert is_weight_sparse_24(
@@ -231,15 +218,13 @@ def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):


class SparseGptLinear(DynamicLinear, SparseGptMixIn):
"""Custom Linear for SparseGpt."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._sparse_gpt_mix_in_init()

@classmethod
def convert_from(cls, module: nn.Linear) -> 'DynamicConv2d':
"""Convert to cls from torch's module."""
if module.out_features < module.in_features:
return module
new_module = super().convert_from(module)
@@ -252,15 +237,13 @@ def convert_from(cls, module: nn.Linear) -> 'DynamicConv2d':


class SparseGptConv2d(DynamicConv2d, SparseGptMixIn):
"""Custom Conv2d for SparseGpt."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._sparse_gpt_mix_in_init()

@classmethod
def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
"""Convert to cls from torch's module."""
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)

@@ -270,7 +253,6 @@ def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
return new_module

def format_input(self, input: torch.Tensor):
"""Format input shape."""
# input B C H W
input = F.unfold(
input, self.kernel_size, padding=self.padding,
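The hunks above show only fragments of update_hessian and format_input, but the bookkeeping they implement is the usual SparseGPT scheme: each wrapped op keeps a running estimate of H ≈ 2·E[x xᵀ] over the calibration inputs, and Conv2d layers first unfold their input so it has the same two-dimensional layout a Linear layer sees. The sketch below is an assumption-laden illustration of that idea, not the exact code elided from this diff; in particular the normalization is inferred from hessian_batch being tracked alongside the Hessian.

```python
import torch
import torch.nn.functional as F


def accumulate_hessian(H: torch.Tensor, n_seen: int, x: torch.Tensor):
    """Running average of 2 * x^T x over calibration samples.

    x: (batch, columns) inputs to one layer; H: (columns, columns).
    """
    batch = x.shape[0]
    update = 2 * x.float().T @ x.float()
    H = (H * n_seen + update) / (n_seen + batch)
    return H, n_seen + batch


# Conv2d inputs are unfolded into per-position patches so the same Hessian
# machinery applies; kernel_size/padding/stride mirror the conv's own settings.
imgs = torch.randn(2, 16, 8, 8)                        # B C H W
patches = F.unfold(imgs, kernel_size=3, padding=1)     # (B, C*3*3, H*W)
patches = patches.transpose(1, 2).reshape(-1, 16 * 9)  # (B*H*W, columns)

H = torch.zeros(16 * 9, 16 * 9)
H, n_seen = accumulate_hessian(H, 0, patches)
```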
26 changes: 3 additions & 23 deletions mmrazor/implementations/pruning/sparse_gpt/utils.py
@@ -1,11 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from typing import Dict, Type

if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
from typing import Dict, Protocol, Type

import torch
import torch.nn as nn
@@ -15,27 +9,21 @@


class ModuleProtocol(Protocol):
"""Custom module protocol for algorithm mixin."""
weight: torch.Tensor

def forward(self, x):
"""The abstract method."""
pass

def register_forward_hook(self, hook):
"""The abstract method."""
pass

def register_backward_hook(self, hook):
"""The abstract method."""
pass

def register_forward_pre_hook(self, hook):
"""The abstract method."""
pass

def register_buffer(self, name, tensor):
"""The abstract method."""
pass


@@ -59,7 +47,6 @@ def replace_op(model: nn.Module, name: str, module: nn.Module):

def register_efficient_forward_hook(module: nn.Module,
device=torch.device('cuda:0')):
"""Register efficient forward hook."""

def forward_pre_hook(module: nn.Module, input):
module.to(device)
@@ -76,7 +63,6 @@ def forward_hook(module: nn.Module, input, output):
def enable_efficient_forward(model: nn.Module,
device=torch.device('cuda:0'),
wrap_modules=[]):
"""Enable efficient forward."""
handles = []
blocks = []
for name, module in model.named_children():
@@ -93,7 +79,6 @@ def enable_efficient_forward(model: nn.Module,


class memory_efficient_forward:
"""The class for Memory efficient forward."""

def __init__(self,
model: nn.Module,
@@ -110,31 +95,26 @@ def __init__(self,
model.to(device)

def __enter__(self, ):
"""Enter."""
if self.enabled:
handles, blocks = enable_efficient_forward(self.model, self.device,
self.wrap_modules)
print_log(f'enable memory efficient forward for {blocks}')
self.handlers = handles

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit."""
for h in self.handlers:
h.remove()


class torch_setting():
"""Set the default torch dtype setting."""

def __init__(self, dtype=None) -> None:
self.original_dtype = torch.get_default_dtype()
self.origianl_dtype = torch.get_default_dtype()
self.dtype = dtype

def __enter__(self):
"""Enter."""
if self.dtype is not None:
torch.set_default_dtype(self.dtype)

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit."""
torch.set_default_dtype(self.original_dtype)
torch.set_default_dtype(self.origianl_dtype)
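The two context managers at the end of utils.py are small but easy to misuse, so a hedged usage sketch follows. The keyword names enabled, device and wrap_modules are inferred from the attributes referenced in __enter__ (the full __init__ signature is collapsed above), and the move-back-to-CPU behaviour of the forward hook is assumed from the function names; treat both as assumptions rather than documented API.

```python
import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt.utils import (
    memory_efficient_forward, torch_setting)

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

# torch_setting temporarily changes the default dtype and restores it on exit.
with torch_setting(dtype=torch.float):
    assert torch.get_default_dtype() == torch.float32

# memory_efficient_forward keeps wrapped children on the CPU and moves each one
# to `device` only while its own forward runs; inputs must already be on `device`.
if torch.cuda.is_available():
    with memory_efficient_forward(model, enabled=True,
                                  device=torch.device('cuda:0'),
                                  wrap_modules=[nn.Linear]):
        _ = model(torch.randn(4, 16).cuda())
```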
14 changes: 0 additions & 14 deletions mmrazor/implementations/quantization/gptq/__init__.py

This file was deleted.
