gotosocial/vendor/codeberg.org/gruf/go-mempool/pool.go

140 lines
2.7 KiB
Go
Raw Normal View History

package mempool
import (
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/cpu"
)
// Pool provides a form of SimplePool
// with the addition of concurrency safety.
//
// It is a typed wrapper around the embedded
// UnsafePool: Put stores a pointer to its copy
// of the value, and Get dereferences a stored
// pointer to hand the value back out.
type Pool[T any] struct {
	UnsafePool
	// New is an optionally provided
	// allocator used when no value
	// is available for use in pool.
	// When nil, Get returns the zero
	// value of T in that case.
	New func() T
	// Reset is an optionally provided
	// value resetting function called
	// on passed value to Put(). When it
	// returns false the value is dropped
	// instead of being stored in the pool.
	Reset func(T) bool
}
// NewPool returns a Pool using new as its allocator and reset as
// its Put-time reset hook (either may be nil), with the backing
// UnsafePool built from check via NewUnsafePool.
func NewPool[T any](new func() T, reset func(T) bool, check func(current, victim int) bool) Pool[T] {
	// Kept as a composite literal: Pool contains a sync.Mutex, and
	// returning a literal (rather than a local variable) avoids a
	// lock copy as far as `go vet` is concerned.
	return Pool[T]{
		UnsafePool: NewUnsafePool(check),
		New:        new,
		Reset:      reset,
	}
}
// Get returns a value from the pool. When nothing is stored, it
// returns the result of New if New is set, or else the zero value
// of T.
func (p *Pool[T]) Get() T {
	ptr := p.UnsafePool.Get()
	if ptr == nil {
		// Nothing pooled: allocate only if an allocator was given.
		if p.New == nil {
			var zero T
			return zero
		}
		return p.New()
	}
	// Stored entries are *T (see Put); copy the value out.
	return *(*T)(ptr)
}
// Put hands t back to the pool for later reuse. If Reset is set, it
// is called on t first, and a false result causes t to be dropped
// rather than pooled.
func (p *Pool[T]) Put(t T) {
	if reset := p.Reset; reset != nil {
		if keep := reset(t); !keep {
			return
		}
	}
	// The underlying pool stores raw pointers, so pass the address
	// of this call's copy of t.
	p.UnsafePool.Put(unsafe.Pointer(&t))
}
// UnsafePool provides a form of UnsafeSimplePool
// with the addition of concurrency safety.
//
// Get/Put first try a small ring buffer that is only
// touched with atomic operations, and fall back to a
// mutex-protected UnsafeSimplePool otherwise.
type UnsafePool struct {
	internal
	// Trailing padding, intended to reduce false sharing with
	// whatever is laid out after an UnsafePool. It rounds the
	// struct size up to a multiple of cache_line_size. Note that
	// when Sizeof(internal{}) is already an exact multiple, the
	// expression yields a whole extra cache_line_size of padding
	// rather than zero.
	_ [cache_line_size - unsafe.Sizeof(internal{})%cache_line_size]byte
}
// NewUnsafePool returns an UnsafePool whose underlying
// UnsafeSimplePool has its Check field set to check. The
// ring buffer, index and mutex all start at their zero values.
func NewUnsafePool(check func(current, victim int) bool) UnsafePool {
	return UnsafePool{
		internal: internal{
			pool: UnsafeSimplePool{
				Check: check,
			},
		},
	}
}
// int_size is the bit width of int/uint on the current platform:
// (^uint(0) >> 63) is 1 on 64-bit targets and 0 on 32-bit ones,
// so the shift evaluates to 64 or 32 respectively.
const int_size = 32 << (^uint(0) >> 63)

// cache_line_size is the platform CPU cache line size, taken from
// the size of cpu.CacheLinePad; it is used for padding to avoid
// false sharing.
const cache_line_size = unsafe.Sizeof(cpu.CacheLinePad{})
// internal holds the shared state behind UnsafePool: a small ring
// buffer accessed only through atomic operations (the fast path),
// backed by an UnsafeSimplePool guarded by a mutex (the slow path).
type internal struct {
	// fast-access ring-buffer of
	// pointers accessible by index.
	// It has int_size/4 slots, i.e.
	// 16 on 64-bit and 8 on 32-bit.
	//
	// if Go ever exposes goroutine IDs
	// to us we can make this a lot faster.
	ring [int_size / 4]unsafe.Pointer
	// index is the ring cursor. Every access
	// reduces it modulo the ring length, and
	// since that length is a power of two the
	// uint64 is free to wrap (including the
	// decrement-from-zero done in Get).
	index atomic.Uint64
	// underlying pool and
	// slow mutex protection.
	pool UnsafeSimplePool
	mutex sync.Mutex
}
// Check reads or replaces the check function of the underlying
// pool, under the mutex. A non-nil fn is installed and returned.
// A nil fn leaves the pool unchanged and returns the installed
// check, or defaultCheck when none is installed (defaultCheck is
// not stored in that case).
func (p *internal) Check(fn func(current, victim int) bool) func(current, victim int) bool {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	switch {
	case fn != nil:
		p.pool.Check = fn
	case p.pool.Check != nil:
		fn = p.pool.Check
	default:
		fn = defaultCheck
	}
	return fn
}
// Get removes and returns a pointer from the pool. It first tries
// the ring slot selected by the current index, atomically swapping
// it for nil; on success the index is stepped back by one. If that
// slot was empty, it falls back to the mutex-guarded pool and
// returns whatever that yields.
func (p *internal) Get() unsafe.Pointer {
	slot := &p.ring[p.index.Load()%uint64(len(p.ring))]
	ptr := atomic.SwapPointer(slot, nil)
	if ptr == nil {
		// Fast path missed: consult the slow pool.
		p.mutex.Lock()
		ptr = p.pool.Get()
		p.mutex.Unlock()
		return ptr
	}
	// Adding all-ones is a two's-complement decrement by one.
	p.index.Add(^uint64(0))
	return ptr
}
// Put stores ptr in the pool. The index is advanced and, if the ring
// slot it now selects is empty, ptr is placed there via CAS;
// otherwise ptr is handed to the mutex-guarded pool instead.
func (p *internal) Put(ptr unsafe.Pointer) {
	slot := &p.ring[p.index.Add(1)%uint64(len(p.ring))]
	if !atomic.CompareAndSwapPointer(slot, nil, ptr) {
		// Slot already occupied: spill to the slow pool.
		p.mutex.Lock()
		p.pool.Put(ptr)
		p.mutex.Unlock()
	}
}
// GC clears every ring slot, dropping the pointers they held, and
// then calls GC on the underlying pool under the mutex.
func (p *internal) GC() {
	for i, n := 0, len(p.ring); i < n; i++ {
		atomic.StorePointer(&p.ring[i], nil)
	}

	p.mutex.Lock()
	p.pool.GC()
	p.mutex.Unlock()
}
// Size returns the underlying pool's Size, read under the mutex.
// Pointers currently parked in the ring buffer are not included.
func (p *internal) Size() (n int) {
	p.mutex.Lock()
	n = p.pool.Size()
	p.mutex.Unlock()
	return n
}