mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 05:32:25 -05:00 
			
		
		
		
	
		
			
	
	
		
			135 lines
		
	
	
	
		
			2.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			135 lines
		
	
	
	
		
			2.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package fastcopy | ||
|  | 
 | ||
|  | import ( | ||
|  | 	"io" | ||
|  | 	"sync" | ||
|  | 	_ "unsafe" // link to io.errInvalidWrite. | ||
|  | ) | ||
|  | 
 | ||
|  | var ( | ||
|  | 	// global pool instance. | ||
|  | 	pool = CopyPool{size: 4096} | ||
|  | 
 | ||
|  | 	//go:linkname errInvalidWrite io.errInvalidWrite | ||
|  | 	errInvalidWrite error | ||
|  | ) | ||
|  | 
 | ||
// CopyPool is a pool of reusable byte buffers used when copying
// from an io.Reader to an io.Writer, so that repeated copies need
// not allocate a fresh buffer each time.
//
// The zero value is usable: an unset size falls back to a default
// buffer length. A CopyPool contains a sync.Pool and so must not be
// copied after first use.
type CopyPool struct {
	// pool holds buffers released by finished copies.
	pool sync.Pool

	// size is the length given to newly allocated buffers;
	// values below 1 mean "use the default".
	size int
}
|  | 
 | ||
|  | // See CopyPool.Buffer(). | ||
|  | func Buffer(sz int) int { | ||
|  | 	return pool.Buffer(sz) | ||
|  | } | ||
|  | 
 | ||
|  | // See CopyPool.CopyN(). | ||
|  | func CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) { | ||
|  | 	return pool.CopyN(dst, src, n) | ||
|  | } | ||
|  | 
 | ||
|  | // See CopyPool.Copy(). | ||
|  | func Copy(dst io.Writer, src io.Reader) (int64, error) { | ||
|  | 	return pool.Copy(dst, src) | ||
|  | } | ||
|  | 
 | ||
|  | // Buffer sets the pool buffer size to allocate. Returns current size. | ||
|  | // Note this is NOT atomically safe, please call BEFORE other calls to CopyPool. | ||
|  | func (cp *CopyPool) Buffer(sz int) int { | ||
|  | 	if sz > 0 { | ||
|  | 		// update size | ||
|  | 		cp.size = sz | ||
|  | 	} else if cp.size < 1 { | ||
|  | 		// default size | ||
|  | 		return 4096 | ||
|  | 	} | ||
|  | 	return cp.size | ||
|  | } | ||
|  | 
 | ||
|  | // CopyN performs the same logic as io.CopyN(), with the difference | ||
|  | // being that the byte buffer is acquired from a memory pool. | ||
|  | func (cp *CopyPool) CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) { | ||
|  | 	written, err := cp.Copy(dst, io.LimitReader(src, n)) | ||
|  | 	if written == n { | ||
|  | 		return n, nil | ||
|  | 	} | ||
|  | 	if written < n && err == nil { | ||
|  | 		// src stopped early; must have been EOF. | ||
|  | 		err = io.EOF | ||
|  | 	} | ||
|  | 	return written, err | ||
|  | } | ||
|  | 
 | ||
|  | // Copy performs the same logic as io.Copy(), with the difference | ||
|  | // being that the byte buffer is acquired from a memory pool. | ||
|  | func (cp *CopyPool) Copy(dst io.Writer, src io.Reader) (int64, error) { | ||
|  | 	// Prefer using io.WriterTo to do the copy (avoids alloc + copy) | ||
|  | 	if wt, ok := src.(io.WriterTo); ok { | ||
|  | 		return wt.WriteTo(dst) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	// Prefer using io.ReaderFrom to do the copy. | ||
|  | 	if rt, ok := dst.(io.ReaderFrom); ok { | ||
|  | 		return rt.ReadFrom(src) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	var buf []byte | ||
|  | 
 | ||
|  | 	if b, ok := cp.pool.Get().([]byte); ok { | ||
|  | 		// Acquired buf from pool | ||
|  | 		buf = b | ||
|  | 	} else { | ||
|  | 		// Allocate new buffer of size | ||
|  | 		buf = make([]byte, cp.Buffer(0)) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	// Defer release to pool | ||
|  | 	defer cp.pool.Put(buf) | ||
|  | 
 | ||
|  | 	var n int64 | ||
|  | 	for { | ||
|  | 		// Perform next read into buf | ||
|  | 		nr, err := src.Read(buf) | ||
|  | 		if nr > 0 { | ||
|  | 			// We error check AFTER checking | ||
|  | 			// no. read bytes so incomplete | ||
|  | 			// read still gets written up to nr. | ||
|  | 
 | ||
|  | 			// Perform next write from buf | ||
|  | 			nw, ew := dst.Write(buf[0:nr]) | ||
|  | 
 | ||
|  | 			// Check for valid write | ||
|  | 			if nw < 0 || nr < nw { | ||
|  | 				if ew == nil { | ||
|  | 					ew = errInvalidWrite | ||
|  | 				} | ||
|  | 				return n, ew | ||
|  | 			} | ||
|  | 
 | ||
|  | 			// Incr total count | ||
|  | 			n += int64(nw) | ||
|  | 
 | ||
|  | 			// Check write error | ||
|  | 			if ew != nil { | ||
|  | 				return n, ew | ||
|  | 			} | ||
|  | 
 | ||
|  | 			// Check unequal read/writes | ||
|  | 			if nr != nw { | ||
|  | 				return n, io.ErrShortWrite | ||
|  | 			} | ||
|  | 		} | ||
|  | 
 | ||
|  | 		// Return on err | ||
|  | 		if err != nil { | ||
|  | 			if err == io.EOF { | ||
|  | 				err = nil // expected | ||
|  | 			} | ||
|  | 			return n, err | ||
|  | 		} | ||
|  | 	} | ||
|  | } |