Go中用缓冲通道作为信号量限制goroutine

当我们需要管理 有多少goroutine可以同时访问资源,使用信号量是一种可靠的方法。

可以使用缓冲通道创建一个信号量,其中通道的大小决定了可以同时运行多少个goroutine:

  • 一个goroutine发送一个值到通道中,占用一个槽。
  •  在完成任务后,它会删除该值,从而为另一个goroutine释放该插槽。

信号量 代码:

// nonbinary/counting semaphores
// 用于有多个读者和一个写者的读者-写者锁。
type Semaphore struct {
    sem     chan struct{ id int }
    timeout time.Duration
// how long to wait to acquire the semaphore before giving up
}

// 在获取时推送给 Chan,表示我们正在使用可用资源
func (s *Semaphore) semaAcquire(id int) error {
    timer := time.NewTimer(s.timeout)

    select {
    case s.sem <- struct{ id int }{id: id}:
        timer.Stop()
        return nil
    case <-timer.C:
        fmt.Println(
"deadline exceeded in acquiring the semaphore!")
        return ErrNoTickets
    }
}

Go中semaphore是一个加权的信号量。 加权信号量允许一个goroutine吃多个槽,这在任务资源消耗不同的场景中很有用。 例如,管理一个数据库连接池,其中某些操作可能需要同时使用多个连接。

下面是完整代码:

package main

import (
    "errors"
    
"fmt"
    
"time"
)

var ErrNoTickets = errors.New(
"semaWait deadline exceeded!")

// nonbinary/counting semaphores
// used in reader-writer locks where we have multiple readers and a single writer.
type Semaphore struct {
    sem     chan struct{ id int }
    timeout time.Duration
// how long to wait to acquire the semaphore before giving up
}

// push to the chan on acquire denoting we are utilising the available resource
func (s *Semaphore) semaAcquire(id int) error {
    timer := time.NewTimer(s.timeout)

    select {
    case s.sem <- struct{ id int }{id: id}:
        timer.Stop()
        return nil
    case <-timer.C:
        fmt.Println(
"deadline exceeded in acquiring the semaphore!")
        return ErrNoTickets
    }
}

// just consume from the channel
func (s *Semaphore) semaRelease() {
    ID := <-s.sem
    fmt.Printf(
"releasing the lock held by :%d\n", ID.id)
}

func (s *Semaphore) IsEmpty() bool {
    return len(s.sem) == 0
}

func semaInit(count int, timeout time.Duration) *Semaphore {
    sema := &Semaphore{
        sem:     make(chan struct{ id int }, count),
        timeout: timeout,
    }
    return sema
}

测试代码:

package main

import (
    "sync"
    
"testing"
    
"time"
)

func TestSemaDeadlineExceeded(t *testing.T) {
    
// 用 4 张可用票启动 5 个 goroutines
   
//每个 goroutines 只休眠 1 秒钟。第 5 个程序应在超过截止时间后出错
    var wg sync.WaitGroup
    n := 4
    sema := semaInit(n, 50*time.Millisecond)

    for i := 0; i < n; i++ {
        wg.Add(1)
        go func(id int) {
            
//只要有 n 张门票可用,就可以同时访问资源
           
// 例如,在多读者、单写入器中,只要没有写入器处于活动状态,就有 N 个读者可以读取资源。
            defer wg.Done()

            
// acquire the resource
            if err := sema.semaAcquire(id); err != nil {
                t.Error(err)
                return
            }

            
// do the work
            time.Sleep(2 * time.Second)

            
// release the semaphore
            sema.semaRelease()

        }(i + 1)
    }

    time.Sleep(1 * time.Second)
    
//最后一张开始得有点晚,理想情况下应该超时,因为所有门票都已消费完毕,正在进行结算。

    wg.Add(1)
    go func() {
        defer wg.Done()
        if err := sema.semaAcquire(5); err != ErrNoTickets {
            t.Error(err)
            return
        }
    }()

    wg.Wait()

}