From 282ca925d37cd4f7433aa68c831d92bfe9aaa73b Mon Sep 17 00:00:00 2001
From: libi <7769922+libi@users.noreply.github.com>
Date: Thu, 1 Dec 2022 14:43:42 +0800
Subject: [PATCH] fix start/stop dcron data race

---
 dcron.go      | 86 ++++++++++++++++++++++++++++++++-------------------
 dcron_test.go | 15 +++++++--
 node_pool.go  |  7 +++--
 3 files changed, 71 insertions(+), 37 deletions(-)

diff --git a/dcron.go b/dcron.go
index 2139924..8408b7e 100644
--- a/dcron.go
+++ b/dcron.go
@@ -6,18 +6,26 @@ import (
 	"github.com/robfig/cron/v3"
 	"log"
 	"os"
+	"sync/atomic"
 	"time"
 )
 
-const defaultReplicas = 50
-const defaultDuration = time.Second
+const (
+	defaultReplicas = 50
+	defaultDuration = time.Second
+)
+
+const (
+	dcronRunning = 1
+	dcronStopped = 0
+)
 
-//Dcron is main struct
+// Dcron is the main struct; its running field is read and written only through sync/atomic.
 type Dcron struct {
 	jobs       map[string]*JobWarpper
 	ServerName string
 	nodePool   *NodePool
-	isRun      bool
+	running    int32
 
 	logger interface{ Printf(string, ...interface{}) }
 
@@ -28,16 +36,17 @@ type Dcron struct {
 	crOptions []cron.Option
 }
 
-//NewDcron create a Dcron
+// NewDcron creates a Dcron in the stopped state.
 func NewDcron(serverName string, driver driver.Driver, cronOpts ...cron.Option) *Dcron {
 	dcron := newDcron(serverName)
 	dcron.crOptions = cronOpts
 	dcron.cr = cron.New(cronOpts...)
+	dcron.running = dcronStopped
 	dcron.nodePool = newNodePool(serverName, driver, dcron, dcron.nodeUpdateDuration, dcron.hashReplicas)
 	return dcron
 }
 
-//NewDcronWithOption create a Dcron with Dcron Option
+// NewDcronWithOption creates a Dcron configured by the given Dcron options.
 func NewDcronWithOption(serverName string, driver driver.Driver, dcronOpts ...Option) *Dcron {
 	dcron := newDcron(serverName)
 	for _, opt := range dcronOpts {
@@ -60,12 +69,12 @@ func newDcron(serverName string) *Dcron {
 	}
 }
 
-//SetLogger set dcron logger
+// SetLogger sets the dcron logger.
 func (d *Dcron) SetLogger(logger *log.Logger) {
 	d.logger = logger
 }
 
-//GetLogger get dcron logger
+// GetLogger returns the dcron logger.
 func (d *Dcron) GetLogger() interface{ Printf(string, ...interface{}) } {
 	return d.logger
 }
@@ -77,12 +86,12 @@ func (d *Dcron) err(format string, v ...interface{}) {
 	d.logger.Printf("ERR: "+format, v...)
 }
 
-//AddJob add a job
+// AddJob adds a job.
 func (d *Dcron) AddJob(jobName, cronStr string, job Job) (err error) {
 	return d.addJob(jobName, cronStr, nil, job)
 }
 
-//AddFunc add a cron func
+// AddFunc adds a cron func.
 func (d *Dcron) AddFunc(jobName, cronStr string, cmd func()) (err error) {
 	return d.addJob(jobName, cronStr, cmd, nil)
 }
@@ -126,36 +135,51 @@ func (d *Dcron) allowThisNodeRun(jobName string) bool {
 	return d.nodePool.NodeID == allowRunNode
 }
 
-//Start start job
+// Start marks the instance running, starts the node pool and then the cron scheduler; if already running it only logs.
 func (d *Dcron) Start() {
-	d.isRun = true
-	err := d.nodePool.StartPool()
-	if err != nil {
-		d.isRun = false
-		d.err("dcron start node pool error %+v", err)
-		return
+	if atomic.CompareAndSwapInt32(&d.running, dcronStopped, dcronRunning) {
+		if err := d.startNodePool(); err != nil {
+			atomic.StoreInt32(&d.running, dcronStopped)
+			return
+		}
+		d.cr.Start()
+		d.info("dcron started , nodeID is %s", d.nodePool.NodeID)
+	} else {
+		d.info("dcron have started")
 	}
-	d.cr.Start()
-	d.info("dcron started , nodeID is %s", d.nodePool.NodeID)
 }
 
 // Run Job
 func (d *Dcron) Run() {
-	d.isRun = true
-	err := d.nodePool.StartPool()
-	if err != nil {
-		d.isRun = false
-		d.err("dcron start node pool error %+v", err)
-		return
+	if atomic.CompareAndSwapInt32(&d.running, dcronStopped, dcronRunning) {
+		if err := d.startNodePool(); err != nil {
+			atomic.StoreInt32(&d.running, dcronStopped)
+			return
+		}
+
+		d.info("dcron running nodeID is %s", d.nodePool.NodeID)
+		d.cr.Run()
+	} else {
+		d.info("dcron already running")
 	}
-	d.info("dcron running nodeID is %s", d.nodePool.NodeID)
-	d.cr.Run()
+}
 
+func (d *Dcron) startNodePool() error {
+	if err := d.nodePool.StartPool(); err != nil {
+		d.err("dcron start node pool error %+v", err)
+		return err
+	}
+	return nil
 }
 
-//Stop stop job
+// Stop stops the cron scheduler. It retries every millisecond until the instance is running, so it never returns if Start/Run did not succeed.
 func (d *Dcron) Stop() {
-	d.isRun = false
-	d.cr.Stop()
-	d.info("dcron stopped")
+	for {
+		if atomic.CompareAndSwapInt32(&d.running, dcronRunning, dcronStopped) {
+			d.cr.Stop()
+			d.info("dcron stopped")
+			return
+		}
+		time.Sleep(time.Millisecond)
+	}
 }
diff --git a/dcron_test.go b/dcron_test.go
index f17d4a5..4c1f837 100644
--- a/dcron_test.go
+++ b/dcron_test.go
@@ -23,11 +23,15 @@ var testData = make(map[string]struct{})
 
 func Test(t *testing.T) {
 
-	drv, _ := dredis.NewDriver(&dredis.Conf{
+	drv, err := dredis.NewDriver(&dredis.Conf{
 		Host: "127.0.0.1",
 		Port: 6379,
 	}, redis.DialConnectTimeout(time.Second*10))
 
+	if err != nil {
+		t.Fatal(err)
+	}
+
 	go runNode(t, drv)
 	// 间隔1秒启动测试节点刷新逻辑
 	time.Sleep(time.Second)
@@ -37,9 +41,11 @@ func Test(t *testing.T) {
 
 	//add recover
 	dcron2 := NewDcron("server2", drv, cron.WithChain(cron.Recover(cron.DefaultLogger)))
+	dcron2.Start()
+	dcron2.Stop()
 
 	//panic recover test
-	err := dcron2.AddFunc("s2 test1", "* * * * *", func() {
+	err = dcron2.AddFunc("s2 test1", "* * * * *", func() {
 		panic("panic test")
 	})
 	if err != nil {
@@ -75,6 +81,7 @@ func Test(t *testing.T) {
 	if err != nil {
 		t.Fatal("add func error")
 	}
+
 	err = dcron3.AddFunc("s3 test2", "* * * * *", func() {
 		t.Log("执行 server3 test2 任务", time.Now().Format("15:04:05"))
 	})
@@ -92,6 +99,8 @@ func Test(t *testing.T) {
 	//测试120秒后退出
 	time.Sleep(120 * time.Second)
 	t.Log("testData", testData)
+	dcron2.Stop()
+	dcron3.Stop()
 }
 
 func runNode(t *testing.T, drv *dredis.RedisDriver) {
@@ -100,7 +109,7 @@ func runNode(t *testing.T, drv *dredis.RedisDriver) {
 
 	err := dcron.AddFunc("s1 test1", "* * * * *", func() {
 		// 同时启动3个节点 但是一个 job 同一时间只会执行一次 通过 map 判重
-		key := "s1 test1" + time.Now().Format("15:04:05")
+		key := "s1 test1 : " + time.Now().Format("15:04")
 		if _, ok := testData[key]; ok {
 			t.Error("job have running in other node")
 		}
diff --git a/node_pool.go b/node_pool.go
index 8b4be67..5db2392 100644
--- a/node_pool.go
+++ b/node_pool.go
@@ -4,10 +4,11 @@ import (
 	"github.com/libi/dcron/consistenthash"
 	"github.com/libi/dcron/driver"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
-//NodePool is a node pool
+// NodePool is a node pool
 type NodePool struct {
 	serviceName string
 	NodeID      string
@@ -75,7 +76,7 @@ func (np *NodePool) updatePool() error {
 func (np *NodePool) tickerUpdatePool() {
 	tickers := time.NewTicker(np.updateDuration)
 	for range tickers.C {
-		if np.dcron.isRun {
+		if atomic.LoadInt32(&np.dcron.running) == dcronRunning {
 			err := np.updatePool()
 			if err != nil {
 				np.dcron.err("update node pool error %+v", err)
@@ -87,7 +88,7 @@ func (np *NodePool) tickerUpdatePool() {
 	}
 }
 
-//PickNodeByJobName : 使用一致性hash算法根据任务名获取一个执行节点
+// PickNodeByJobName uses the consistent hash algorithm to pick an execution node for the given job name.
 func (np *NodePool) PickNodeByJobName(jobName string) string {
 	np.mu.Lock()
 	defer np.mu.Unlock()