「Golang」sync.WaitGroup源码讲解
sync.WaitGroup
介绍
当我们在开发过程中,经常需要在开启多个goroutine后,等待全部的goroutine执行完毕后才进行下一步的业务逻辑执行。此时我们可能会采用轮询的方式去定时侦测已经开启的多个goroutine的业务是否执行完毕,但是这样性能很低,并且持续占用cpu时间片很消耗cpu的资源,此时我们就该使用sync.WaitGroup
来完成此次操作。举个🌰,下列代码是开了10个goroutine后等待其各睡眠5秒之后进行后续操作的sync.WaitGroup
方法实现。
func main() {
// 创建对象
var wait sync.WaitGroup
for i := 0; i < 10; i++ {
// 为需要等待结束的goroutine数量+1
wait.Add(1)
go func() {
time.Sleep(5*time.Second)
// 结束 使得需要等待的数量-1
// 等同于 wait.Add(-1)
wait.Done()
}()
}
// 等待所有执行完毕
wait.Wait()
fmt.Println("wait done")
}
上述代码的睡眠5秒可以替换为任何需要在goroutine中执行的业务逻辑,上述代码中出现了下述几个sync.WaitGroup
中提供的方法,提供方法很少很简洁,接下来就开始解析一下sync.WaitGroup
的相关信息。
func (wg *WaitGroup) Add(delta int)
func (wg *WaitGroup) Done()
func (wg *WaitGroup) Wait()
sync.WaitGroup
源代码解析
1:sync.WaitGroup
结构体的解析
type WaitGroup struct {
// 一个防止sync.WaitGroup被复制的标记结构体
noCopy noCopy
// 该数组在32为系统与64位系统中代表的用途不同
// 首先说64位系统:
// state1[0]代表当前sync.WaitGroup 调用Add方法增加了多少的couter
// state1[1]代表调用了Wait方法等待结束的Waiter的数量
// state1[2]代表Waiter的信号量
// 其中 state1[0]与state1[1]作者称为计数标记
// 在32位系统中:
// state1[0]代表Waiter的信号量
// state1[1]代表当前sync.WaitGroup 调用Add方法增加了多少的couter
// state1[2]代表调用了Wait方法等待结束的Waiter的数量
// 其中 state1[1]与state1[2]作者称为计数标记
state1 [3]uint32
}
Add(delta int)
方法源代码解析
// state 方法用于根据系统是32位还是64位返回对应的state1字段的对应的计数和信号量的地址
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
// 64位系统返回state1与state1[2],由于数组是连续内存所以可以通过首地址
// 取出state1[0]与state1[1]的所有二进制位
return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {
// 32位系统返回state1与state1[2],由于数组是连续内存所以可以通过首地址
// 取出state1[0]与state1[1]的所有二进制位
return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}
// Add 方法传入一个增量 可以为赋值则代表 Done
// 例如 调用Done方法就是对Add方法传入了-1
// 即Add(-1)
func (wg *WaitGroup) Add(delta int) {
// 根据系统位数返回计数标记和信号量标记
statep, semap := wg.state()
// 做了race检查和异常的检查
if race.Enabled {
_ = *statep // 如果信号量是个空指针,则报错
if delta < 0 {
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable()
defer race.Enable()
}
// 将计数标记的高32位的值+delta
// 如果是64位系统则代表state1[0]+delta
state := atomic.AddUint64(statep, uint64(delta)<<32)
// 转换成正常的数字
// 例如 以64为系统为例 当原state1[0]为0时
// atomic.AddUint64(statep, uint64(delta)<<32)
// 当delta==1时
// 结果为 2^32-1
// 该操作就是将此数值转变为1
v := int32(state >> 32)
// w代表需要等待的数量即 64位系统中的state1[1]为state
w := uint32(state)
// 做race判断,并判断delta是否为负数 如果是的话并且v与delta相等则做一些race的同步
// fixme 此处解释存疑
if race.Enabled && delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait.
// Need to model this as a read, because there can be
// several concurrent wg.counter transitions from 0.
race.Read(unsafe.Pointer(semap))
}
// 如果v<0则代表delta的传入的为负值,并且该负值与原couter相减后小于0
// 说白了一点就说Add传入的负值超出了原有couter的数量
if v < 0 {
panic("sync: negative WaitGroup counter")
}
// 如果等待数量不是0 并且delta>0 且v==delta 则代表出现了
// 同时并发调用了Add和Wait
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 正常情况 直接返回
if v > 0 || w == 0 {
return
}
// 也是同时调用Add和Wait
if *statep != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 如果计数值v为0并且waiter的数量w不为0
// 则代表delta传入的值使得couter变为了0,但是还是有waiter在等待的话
// 就把statep即state1[0]与state1[1]设置为0
// 并唤醒所有的正在等待的阻塞goroutine
*statep = 0
for ; w != 0; w-- {
runtime_Semrelease(semap, false, 0)
}
}
Add(delta int)
方法的总体执行过程大概如下:
- 根据64位还是32位系统获取对应的标记位
- 对计数标记位做delta的Add
- 判断Add之后的state是否为一些非法情况,比如v<0等
- 如果v和w分别为大于0和等于0则正常返回,代表Add成功
- 否则判断当v为0时则代表没有counter了,但是还有waiter那么就把state整体设置为0,随后唤醒调用了
Wait()
的阻塞的goroutine。
Done()
方法解析
// Done 没什么可说的 就是调用了一下Add(-1)
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
Wait()
方法解析
func (wg *WaitGroup) Wait() {
// 根据系统位数返回计数标记和信号量标记
statep, semap := wg.state()
// race检测
if race.Enabled {
_ = *statep // trigger nil deref early
race.Disable()
}
// 循环校验是否所有的goroutine都调用了Done
for {
//原子获取值
state := atomic.LoadUint64(statep)
// v 代表 couter数量 ,即高32位
v := int32(state >> 32)
// w 代表waiter数量 ,即低32位
w := uint32(state)
// 如果v==0 代表没有 couter了则不用等待了直接返回
if v == 0 {
// Counter is 0, no need to wait.
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// 如果statep的值与state相等 还有需要等待完成的goroutine 此时则waiter+1
if atomic.CompareAndSwapUint64(statep, state, state+1) {
// race检测
if race.Enabled && w == 0 {
race.Write(unsafe.Pointer(semap))
}
// 阻塞等待直到被Add唤醒
runtime_Semacquire(semap)
// 如果被唤醒了 但是发现地址中的值不是0 代表唤醒错误 panic
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
// race检测
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}
Wait()
方法的总体执行过程大概如下:
- 获取标记位的地址
- 获取值
- 判断v是否为0,如果是则无须等待。
- 如果v不是0并且statep的值与state相等,则代表还有需要等待完成的goroutine,此时则waiter+1,然后阻塞等待Add方法唤醒
总结
sync.WaitGroup
使用的规范
Add(delta int)
方法可以设置为负值,但是必须要确保这个负值delta加上当前计数器的数量的结果大于0。否则会panic。🌰如下
func main() {
var wait sync.WaitGroup
wait.Add(1) // 没问题 计数器为1
wait.Add(-1)// 没问题 计数器为0
wait.Add(-3) // panic 此时计数器为0-3 出错
}
func main() {
var wait sync.WaitGroup
wait.Add(1) // 没问题 计数器为1
wait.Add(1)// 没问题 计数器为2
wait.Done() // 没问题 计数器为1
wait.Done()// 没问题 计数器为0
wait.Done()// panic 此时计数器为-1
}
2.Add(delta int)
方法必须在 Wait()
方法调用之前全部调用完毕,否则会出现panic。举个🌰,本实例的设想是进入goroutine后Add(delta int)
但是Wait()
方法调用早于goroutine中Add(delta int)
,所以此时Wait()
计数器为0,则不等待直接跳过。代码如下:
func main() {
var wait sync.WaitGroup
go func() {
// 故意sleep 代表执行逻辑
time.Sleep(1*time.Millisecond)
fmt.Println("Add")
wait.Add(1)
wait.Done()
}()
go func() {
// 故意sleep 代表执行逻辑
time.Sleep(1*time.Millisecond)
fmt.Println("Add")
wait.Add(1)
wait.Done()
}()
wait.Wait()
fmt.Println("Done")
}
3.不可以在前一个Wait()
还未结束时,复用sync.WaitGroup
,举个例子代码:
func main() {
var wait sync.WaitGroup
wait.Add(1)
go func() {
// 故意sleep 代表执行逻辑
time.Sleep(1*time.Millisecond)
wait.Done() // 正常结束
wait.Add(1) // 此处panic 因为wait还未结束就再次复用
}()
wait.Wait()
fmt.Println("Done")
}
sync.WaitGroup
虽然可以重用,但是是有一个前提的,那就是必须等到上一轮的Wait()
完成之后,才能重用sync.WaitGroup
执行下一轮的Add(delta int)
/Wait()
,如果你在Wait()
还没执行完的时候就调用下一轮 Add 方法,就有可能出现 panic。
至此sync.WaitGroup
的源码解析就解析完了,可能有些地方有些理解上的错误,请各位谅解并且帮忙指出修改意见,如果这篇文章能帮到你,这是我的荣幸。