diff --git a/monkey/monkey.go b/monkey/monkey.go index a396d053..e1319339 100644 --- a/monkey/monkey.go +++ b/monkey/monkey.go @@ -20,6 +20,7 @@ package monkey import ( "fmt" "reflect" + "sync" "time" "github.com/agiledragon/gomonkey/v2" @@ -28,17 +29,22 @@ import ( var ( patchesMap = make(map[string]*gomonkey.Patches) + mu sync.RWMutex // 使用读写锁来保证线程安全 ) // Patch replaces a function with another func Patch(target, replacement interface{}) *gomonkey.Patches { key := fmt.Sprintf("%v", target) + mu.Lock() + defer mu.Unlock() + if existingPatches, ok := patchesMap[key]; ok { log.Infof("reset existing patches for %v", key) existingPatches.Reset() delete(patchesMap, key) time.Sleep(100 * time.Millisecond) } + patches := gomonkey.ApplyFunc(target, replacement) patchesMap[key] = patches return patches @@ -47,6 +53,9 @@ func Patch(target, replacement interface{}) *gomonkey.Patches { // Unpatch unpatch a patch func Unpatch(target interface{}) bool { key := fmt.Sprintf("%v", target) + mu.Lock() + defer mu.Unlock() + patches, ok := patchesMap[key] if !ok { return false @@ -60,12 +69,16 @@ func Unpatch(target interface{}) bool { // PatchInstanceMethod replaces an instance method methodName for the type target with replacement func PatchInstanceMethod(target reflect.Type, methodName string, replacement interface{}) *gomonkey.Patches { key := fmt.Sprintf("%v:%v", target, methodName) + mu.Lock() + defer mu.Unlock() + if existingPatches, ok := patchesMap[key]; ok { log.Infof("reset existing patches %v for %v", existingPatches, key) existingPatches.Reset() delete(patchesMap, key) time.Sleep(100 * time.Millisecond) } + patches := gomonkey.ApplyMethod(target, methodName, replacement) patchesMap[key] = patches log.Infof("patchesMap: %v", patchesMap) @@ -75,6 +88,9 @@ func PatchInstanceMethod(target reflect.Type, methodName string, replacement int // UnpatchInstanceMethod unpatch a patch func UnpatchInstanceMethod(target reflect.Type, methodName string) bool { key := fmt.Sprintf("%v:%v", target, methodName) + mu.Lock() + defer mu.Unlock() + patches, ok := patchesMap[key] if !ok { return false @@ -87,6 +103,9 @@ func UnpatchInstanceMethod(target reflect.Type, methodName string) bool { // UnpatchAll unpatch all patches func UnpatchAll() { + mu.Lock() + defer mu.Unlock() + for key, patches := range patchesMap { patches.Reset() delete(patchesMap, key)