Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
samanhappy committed Aug 23, 2024
1 parent f1650f3 commit 2e2c51b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 61 deletions.
53 changes: 2 additions & 51 deletions monkey/monkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package monkey
import (
"fmt"
"reflect"
"sync"
"time"

"github.com/agiledragon/gomonkey/v2"
Expand All @@ -29,38 +28,25 @@ import (

var (
patchesMap = make(map[string]*gomonkey.Patches)
mu sync.RWMutex // 使用读写锁来保证线程安全
)

// Patch replaces a function with another
func Patch(target, replacement interface{}) *gomonkey.Patches {
key := fmt.Sprintf("%v", target)
mu.Lock()
defer mu.Unlock()

if existingPatches, ok := patchesMap[key]; ok {
log.Infof("reset existing patches for %v", key)
existingPatches.Reset()
delete(patchesMap, key)
time.Sleep(100 * time.Millisecond)
}

patches, err := safeApplyFunc(target, replacement)
if err != nil {
log.Errorf("failed to apply patch for %v: %v", key, err)
return nil
}

patches := gomonkey.ApplyFunc(target, replacement)
patchesMap[key] = patches
return patches
}

// Unpatch unpatch a patch
func Unpatch(target interface{}) bool {
key := fmt.Sprintf("%v", target)
mu.Lock()
defer mu.Unlock()

patches, ok := patchesMap[key]
if !ok {
return false
Expand All @@ -74,22 +60,13 @@ func Unpatch(target interface{}) bool {
// PatchInstanceMethod replaces an instance method methodName for the type target with replacement
func PatchInstanceMethod(target reflect.Type, methodName string, replacement interface{}) *gomonkey.Patches {
key := fmt.Sprintf("%v:%v", target, methodName)
mu.Lock()
defer mu.Unlock()

if existingPatches, ok := patchesMap[key]; ok {
log.Infof("reset existing patches %v for %v", existingPatches, key)
existingPatches.Reset()
delete(patchesMap, key)
time.Sleep(100 * time.Millisecond)
}

patches, err := safeApplyMethod(target, methodName, replacement)
if err != nil {
log.Errorf("failed to apply patch for %v: %v", key, err)
return nil
}

patches := gomonkey.ApplyMethod(target, methodName, replacement)
patchesMap[key] = patches
log.Infof("patchesMap: %v", patchesMap)
return patches
Expand All @@ -98,9 +75,6 @@ func PatchInstanceMethod(target reflect.Type, methodName string, replacement int
// UnpatchInstanceMethod unpatch a patch
func UnpatchInstanceMethod(target reflect.Type, methodName string) bool {
key := fmt.Sprintf("%v:%v", target, methodName)
mu.Lock()
defer mu.Unlock()

patches, ok := patchesMap[key]
if !ok {
return false
Expand All @@ -113,32 +87,9 @@ func UnpatchInstanceMethod(target reflect.Type, methodName string) bool {

// UnpatchAll unpatch all patches
func UnpatchAll() {
mu.Lock()
defer mu.Unlock()

for key, patches := range patchesMap {
patches.Reset()
delete(patchesMap, key)
}
time.Sleep(100 * time.Millisecond)
}

// safeApplyFunc safely applies a function patch and handles potential panics
func safeApplyFunc(target, replacement interface{}) (*gomonkey.Patches, error) {
defer func() {
if r := recover(); r != nil {
log.Errorf("panic while applying function patch: %v", r)
}
}()
return gomonkey.ApplyFunc(target, replacement), nil
}

// safeApplyMethod safely applies a method patch and handles potential panics
func safeApplyMethod(target reflect.Type, methodName string, replacement interface{}) (*gomonkey.Patches, error) {
defer func() {
if r := recover(); r != nil {
log.Errorf("panic while applying method patch: %v", r)
}
}()
return gomonkey.ApplyMethod(target, methodName, replacement), nil
}
21 changes: 11 additions & 10 deletions monkey/monkey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ func (s *MyStruct) Method() string {
return "original"
}

func TestUnpatch(t *testing.T) {
originalFunc := func() string { return "original" }
replacementFunc := func() string { return "replacement" }
func TestPatch(t *testing.T) {
original := func() string { return "original" }
replacement := func() string { return "replacement" }

Patch(originalFunc, replacementFunc)
Patch(original, replacement)
assert.Equal(t, "replacement", original())

Unpatch(originalFunc)
assert.Equal(t, "original", originalFunc())
Unpatch(original)
assert.Equal(t, "original", original())
}

func TestPatchInstanceMethod(t *testing.T) {
Expand All @@ -54,13 +55,13 @@ func TestUnpatchInstanceMethod(t *testing.T) {
}

func TestUnpatchAll(t *testing.T) {
originalFunc := func() string { return "original" }
replacementFunc := func() string { return "replacement" }
original := func() string { return "original" }
replacement := func() string { return "replacement" }

Patch(originalFunc, replacementFunc)
Patch(original, replacement)
PatchInstanceMethod(reflect.TypeOf(&MyStruct{}), "Method", func(*MyStruct) string { return "replacement" })
UnpatchAll()

assert.Equal(t, "original", originalFunc())
assert.Equal(t, "original", original())
assert.Equal(t, "original", (&MyStruct{}).Method())
}

0 comments on commit 2e2c51b

Please sign in to comment.