This commit is contained in:
tsmethurst 2021-08-24 16:54:54 +02:00
commit 526a14a92d
486 changed files with 84353 additions and 23865 deletions

View file

@ -1,24 +0,0 @@
Copyright (c) 2021 Vladimir Mihailenco. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -1,36 +0,0 @@
# pgdriver
[![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/bun/driver/pgdriver)](https://pkg.go.dev/github.com/uptrace/bun/driver/pgdriver)
pgdriver is a database/sql driver for PostgreSQL based on [go-pg](https://github.com/go-pg/pg) code.
You can install it with:
```shell
github.com/uptrace/bun/driver/pgdriver
```
And then create a `sql.DB` using it:
```go
import _ "github.com/uptrace/bun/driver/pgdriver"
dsn := "postgres://postgres:@localhost:5432/test"
db, err := sql.Open("pg", dsn)
```
Alternatively:
```go
dsn := "postgres://postgres:@localhost:5432/test"
db := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
```
[Benchmark](https://github.com/go-bun/bun-benchmark):
```
BenchmarkInsert/pg-12 7254 148380 ns/op 900 B/op 13 allocs/op
BenchmarkInsert/pgx-12 6494 166391 ns/op 2076 B/op 26 allocs/op
BenchmarkSelect/pg-12 9100 132952 ns/op 1417 B/op 18 allocs/op
BenchmarkSelect/pgx-12 8199 154920 ns/op 3679 B/op 60 allocs/op
```

View file

@ -1,192 +0,0 @@
package pgdriver
import (
"encoding/hex"
"fmt"
"io"
"strconv"
"strings"
"time"
)
const (
pgBool = 16
pgInt2 = 21
pgInt4 = 23
pgInt8 = 20
pgFloat4 = 700
pgFloat8 = 701
pgText = 25
pgVarchar = 1043
pgBytea = 17
pgDate = 1082
pgTimestamp = 1114
pgTimestamptz = 1184
)
func readColumnValue(rd *reader, dataType int32, dataLen int) (interface{}, error) {
if dataLen == -1 {
return nil, nil
}
switch dataType {
case pgBool:
return readBoolCol(rd, dataLen)
case pgInt2:
return readIntCol(rd, dataLen, 16)
case pgInt4:
return readIntCol(rd, dataLen, 32)
case pgInt8:
return readIntCol(rd, dataLen, 64)
case pgFloat4:
return readFloatCol(rd, dataLen, 32)
case pgFloat8:
return readFloatCol(rd, dataLen, 64)
case pgTimestamp:
return readTimeCol(rd, dataLen)
case pgTimestamptz:
return readTimeCol(rd, dataLen)
case pgDate:
return readTimeCol(rd, dataLen)
case pgText, pgVarchar:
return readStringCol(rd, dataLen)
case pgBytea:
return readBytesCol(rd, dataLen)
}
b := make([]byte, dataLen)
if _, err := io.ReadFull(rd, b); err != nil {
return nil, err
}
return b, nil
}
func readBoolCol(rd *reader, n int) (interface{}, error) {
tmp, err := rd.ReadTemp(n)
if err != nil {
return nil, err
}
return len(tmp) == 1 && (tmp[0] == 't' || tmp[0] == '1'), nil
}
func readIntCol(rd *reader, n int, bitSize int) (interface{}, error) {
if n <= 0 {
return 0, nil
}
tmp, err := rd.ReadTemp(n)
if err != nil {
return 0, err
}
return strconv.ParseInt(bytesToString(tmp), 10, bitSize)
}
func readFloatCol(rd *reader, n int, bitSize int) (interface{}, error) {
if n <= 0 {
return 0, nil
}
tmp, err := rd.ReadTemp(n)
if err != nil {
return 0, err
}
return strconv.ParseFloat(bytesToString(tmp), bitSize)
}
func readStringCol(rd *reader, n int) (interface{}, error) {
if n <= 0 {
return "", nil
}
b := make([]byte, n)
if _, err := io.ReadFull(rd, b); err != nil {
return nil, err
}
return bytesToString(b), nil
}
func readBytesCol(rd *reader, n int) (interface{}, error) {
if n <= 0 {
return []byte{}, nil
}
tmp, err := rd.ReadTemp(n)
if err != nil {
return nil, err
}
if len(tmp) < 2 || tmp[0] != '\\' || tmp[1] != 'x' {
return nil, fmt.Errorf("pgdriver: can't parse bytea: %q", tmp)
}
tmp = tmp[2:] // Cut off "\x".
b := make([]byte, hex.DecodedLen(len(tmp)))
if _, err := hex.Decode(b, tmp); err != nil {
return nil, err
}
return b, nil
}
func readTimeCol(rd *reader, n int) (interface{}, error) {
if n <= 0 {
return time.Time{}, nil
}
tmp, err := rd.ReadTemp(n)
if err != nil {
return time.Time{}, err
}
tm, err := parseTime(bytesToString(tmp))
if err != nil {
return time.Time{}, err
}
return tm, nil
}
const (
dateFormat = "2006-01-02"
timeFormat = "15:04:05.999999999"
timestampFormat = "2006-01-02 15:04:05.999999999"
timestamptzFormat = "2006-01-02 15:04:05.999999999-07:00:00"
timestamptzFormat2 = "2006-01-02 15:04:05.999999999-07:00"
timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07"
)
func parseTime(s string) (time.Time, error) {
switch l := len(s); {
case l < len("15:04:05"):
return time.Time{}, fmt.Errorf("pgdriver: can't parse time=%q", s)
case l <= len(timeFormat):
if s[2] == ':' {
return time.ParseInLocation(timeFormat, s, time.UTC)
}
return time.ParseInLocation(dateFormat, s, time.UTC)
default:
if s[10] == 'T' {
return time.Parse(time.RFC3339Nano, s)
}
if c := s[l-9]; c == '+' || c == '-' {
return time.Parse(timestamptzFormat, s)
}
if c := s[l-6]; c == '+' || c == '-' {
return time.Parse(timestamptzFormat2, s)
}
if c := s[l-3]; c == '+' || c == '-' {
if strings.HasSuffix(s, "+00") {
s = s[:len(s)-3]
return time.ParseInLocation(timestampFormat, s, time.UTC)
}
return time.Parse(timestamptzFormat3, s)
}
return time.ParseInLocation(timestampFormat, s, time.UTC)
}
}

View file

@ -1,233 +0,0 @@
package pgdriver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"os"
"strings"
"time"
)
type Config struct {
// Network type, either tcp or unix.
// Default is tcp.
Network string
// TCP host:port or Unix socket depending on Network.
Addr string
// Dial timeout for establishing new connections.
// Default is 5 seconds.
DialTimeout time.Duration
// Dialer creates new network connection and has priority over
// Network and Addr options.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// TLS config for secure connections.
TLSConfig *tls.Config
User string
Password string
Database string
AppName string
// Timeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking.
ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands will fail
// with a timeout instead of blocking.
WriteTimeout time.Duration
}
func newDefaultConfig() *Config {
host := env("PGHOST", "localhost")
port := env("PGPORT", "5432")
cfg := &Config{
Network: "tcp",
Addr: net.JoinHostPort(host, port),
DialTimeout: 5 * time.Second,
User: env("PGUSER", "postgres"),
Database: env("PGDATABASE", "postgres"),
ReadTimeout: 10 * time.Second,
WriteTimeout: 5 * time.Second,
}
cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{
Timeout: cfg.DialTimeout,
KeepAlive: 5 * time.Minute,
}
return netDialer.DialContext(ctx, network, addr)
}
return cfg
}
type DriverOption func(*Connector)
func WithAddr(addr string) DriverOption {
if addr == "" {
panic("addr is empty")
}
return func(d *Connector) {
d.cfg.Addr = addr
}
}
func WithTLSConfig(cfg *tls.Config) DriverOption {
return func(d *Connector) {
d.cfg.TLSConfig = cfg
}
}
func WithUser(user string) DriverOption {
if user == "" {
panic("user is empty")
}
return func(d *Connector) {
d.cfg.User = user
}
}
func WithPassword(password string) DriverOption {
return func(d *Connector) {
d.cfg.Password = password
}
}
func WithDatabase(database string) DriverOption {
if database == "" {
panic("database is empty")
}
return func(d *Connector) {
d.cfg.Database = database
}
}
func WithApplicationName(appName string) DriverOption {
return func(d *Connector) {
d.cfg.AppName = appName
}
}
func WithTimeout(timeout time.Duration) DriverOption {
return func(d *Connector) {
d.cfg.DialTimeout = timeout
d.cfg.ReadTimeout = timeout
d.cfg.WriteTimeout = timeout
}
}
func WithDialTimeout(dialTimeout time.Duration) DriverOption {
return func(d *Connector) {
d.cfg.DialTimeout = dialTimeout
}
}
func WithReadTimeout(readTimeout time.Duration) DriverOption {
return func(d *Connector) {
d.cfg.ReadTimeout = readTimeout
}
}
func WithWriteTimeout(writeTimeout time.Duration) DriverOption {
return func(d *Connector) {
d.cfg.WriteTimeout = writeTimeout
}
}
func WithDSN(dsn string) DriverOption {
return func(d *Connector) {
opts, err := parseDSN(dsn)
if err != nil {
panic(err)
}
for _, opt := range opts {
opt(d)
}
}
}
func parseDSN(dsn string) ([]DriverOption, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, err
}
if u.Scheme != "postgres" && u.Scheme != "postgresql" {
return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
}
query, err := url.ParseQuery(u.RawQuery)
if err != nil {
return nil, err
}
var opts []DriverOption
if u.Host != "" {
addr := u.Host
if !strings.Contains(addr, ":") {
addr += ":5432"
}
opts = append(opts, WithAddr(addr))
}
if u.User != nil {
opts = append(opts, WithUser(u.User.Username()))
if password, ok := u.User.Password(); ok {
opts = append(opts, WithPassword(password))
}
}
if len(u.Path) > 1 {
opts = append(opts, WithDatabase(u.Path[1:]))
}
if appName := query.Get("application_name"); appName != "" {
opts = append(opts, WithApplicationName(appName))
}
delete(query, "application_name")
if sslMode := query.Get("sslmode"); sslMode != "" {
switch sslMode {
case "verify-ca", "verify-full":
opts = append(opts, WithTLSConfig(new(tls.Config)))
case "allow", "prefer", "require":
opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
case "disable":
// no TLS config
default:
return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode)
}
} else {
opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
}
delete(query, "sslmode")
for key := range query {
return nil, fmt.Errorf("pgdriver: unsupported option=%q", key)
}
return opts, nil
}
func env(key, defValue string) string {
if s := os.Getenv(key); s != "" {
return s
}
return defValue
}
// verify is a method to make sure if the config is legitimate
// in the case it detects any errors, it returns with a non-nil error
// it can be extended to check other parameters
func (c *Config) verify() error {
if c.User == "" {
return errors.New("pgdriver: User option is empty (to configure, use WithUser).")
}
return nil
}

View file

@ -1,606 +0,0 @@
package pgdriver
import (
"bufio"
"bytes"
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
)
func init() {
sql.Register("pg", NewDriver())
}
type logging interface {
Printf(ctx context.Context, format string, v ...interface{})
}
type logger struct {
log *log.Logger
}
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
var Logger logging = &logger{
log: log.New(os.Stderr, "pgdriver: ", log.LstdFlags|log.Lshortfile),
}
//------------------------------------------------------------------------------
type Driver struct {
connector *Connector
}
var _ driver.DriverContext = (*Driver)(nil)
func NewDriver() Driver {
return Driver{}
}
func (d Driver) OpenConnector(name string) (driver.Connector, error) {
opts, err := parseDSN(name)
if err != nil {
return nil, err
}
return NewConnector(opts...), nil
}
func (d Driver) Open(name string) (driver.Conn, error) {
connector, err := d.OpenConnector(name)
if err != nil {
return nil, err
}
return connector.Connect(context.TODO())
}
//------------------------------------------------------------------------------
type DriverStats struct {
Queries uint64
Errors uint64
}
type Connector struct {
cfg *Config
stats DriverStats
}
func NewConnector(opts ...DriverOption) *Connector {
d := &Connector{cfg: newDefaultConfig()}
for _, opt := range opts {
opt(d)
}
return d
}
var _ driver.Connector = (*Connector)(nil)
func (d *Connector) Connect(ctx context.Context) (driver.Conn, error) {
if err := d.cfg.verify(); err != nil {
return nil, err
}
return newConn(ctx, d)
}
func (d *Connector) Driver() driver.Driver {
return Driver{connector: d}
}
func (d *Connector) Config() *Config {
return d.cfg
}
func (d *Connector) Stats() DriverStats {
return DriverStats{
Queries: atomic.LoadUint64(&d.stats.Queries),
Errors: atomic.LoadUint64(&d.stats.Errors),
}
}
//------------------------------------------------------------------------------
type Conn struct {
driver *Connector
netConn net.Conn
rd *reader
processID int32
secretKey int32
stmtCount int
closed int32
}
func newConn(ctx context.Context, driver *Connector) (*Conn, error) {
netConn, err := driver.cfg.Dialer(ctx, driver.cfg.Network, driver.cfg.Addr)
if err != nil {
return nil, err
}
cn := &Conn{
driver: driver,
netConn: netConn,
rd: newReader(netConn),
}
if cn.driver.cfg.TLSConfig != nil {
if err := enableSSL(ctx, cn, cn.driver.cfg.TLSConfig); err != nil {
return nil, err
}
}
if err := startup(ctx, cn); err != nil {
return nil, err
}
return cn, nil
}
func (cn *Conn) reader(ctx context.Context, timeout time.Duration) *reader {
cn.setReadDeadline(ctx, timeout)
return cn.rd
}
func (cn *Conn) withWriter(
ctx context.Context,
timeout time.Duration,
fn func(wr *bufio.Writer) error,
) error {
wr := getBufioWriter()
cn.setWriteDeadline(ctx, timeout)
wr.Reset(cn.netConn)
err := fn(wr)
if err == nil {
err = wr.Flush()
}
putBufioWriter(wr)
return err
}
var _ driver.Conn = (*Conn)(nil)
func (cn *Conn) Prepare(query string) (driver.Stmt, error) {
if cn.isClosed() {
return nil, driver.ErrBadConn
}
ctx := context.TODO()
name := fmt.Sprintf("pgdriver-%d", cn.stmtCount)
cn.stmtCount++
if err := writeParseDescribeSync(ctx, cn, name, query); err != nil {
return nil, err
}
rowDesc, err := readParseDescribeSync(ctx, cn)
if err != nil {
return nil, err
}
return newStmt(cn, name, rowDesc), nil
}
func (cn *Conn) Close() error {
if !atomic.CompareAndSwapInt32(&cn.closed, 0, 1) {
return nil
}
return cn.netConn.Close()
}
func (cn *Conn) isClosed() bool {
return atomic.LoadInt32(&cn.closed) == 1
}
func (cn *Conn) Begin() (driver.Tx, error) {
return cn.BeginTx(context.Background(), driver.TxOptions{})
}
var _ driver.ConnBeginTx = (*Conn)(nil)
func (cn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
// No need to check if the conn is closed. ExecContext below handles that.
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
return nil, errors.New("pgdriver: custom IsolationLevel is not supported")
}
if opts.ReadOnly {
return nil, errors.New("pgdriver: ReadOnly transactions are not supported")
}
if _, err := cn.ExecContext(ctx, "BEGIN", nil); err != nil {
return nil, err
}
return tx{cn: cn}, nil
}
var _ driver.ExecerContext = (*Conn)(nil)
func (cn *Conn) ExecContext(
ctx context.Context, query string, args []driver.NamedValue,
) (driver.Result, error) {
if cn.isClosed() {
return nil, driver.ErrBadConn
}
res, err := cn.exec(ctx, query, args)
if err != nil {
return nil, cn.checkBadConn(err)
}
return res, nil
}
func (cn *Conn) exec(
ctx context.Context, query string, args []driver.NamedValue,
) (driver.Result, error) {
query, err := formatQuery(query, args)
if err != nil {
return nil, err
}
if err := writeQuery(ctx, cn, query); err != nil {
return nil, err
}
return readQuery(ctx, cn)
}
var _ driver.QueryerContext = (*Conn)(nil)
func (cn *Conn) QueryContext(
ctx context.Context, query string, args []driver.NamedValue,
) (driver.Rows, error) {
if cn.isClosed() {
return nil, driver.ErrBadConn
}
rows, err := cn.query(ctx, query, args)
if err != nil {
return nil, cn.checkBadConn(err)
}
return rows, nil
}
func (cn *Conn) query(
ctx context.Context, query string, args []driver.NamedValue,
) (driver.Rows, error) {
query, err := formatQuery(query, args)
if err != nil {
return nil, err
}
if err := writeQuery(ctx, cn, query); err != nil {
return nil, err
}
return readQueryData(ctx, cn)
}
var _ driver.Pinger = (*Conn)(nil)
func (cn *Conn) Ping(ctx context.Context) error {
_, err := cn.ExecContext(ctx, "SELECT 1", nil)
return err
}
func (cn *Conn) setReadDeadline(ctx context.Context, timeout time.Duration) {
if timeout == -1 {
timeout = cn.driver.cfg.ReadTimeout
}
_ = cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))
}
func (cn *Conn) setWriteDeadline(ctx context.Context, timeout time.Duration) {
if timeout == -1 {
timeout = cn.driver.cfg.WriteTimeout
}
_ = cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
deadline, ok := ctx.Deadline()
if !ok {
if timeout == 0 {
return time.Time{}
}
return time.Now().Add(timeout)
}
if timeout == 0 {
return deadline
}
if tm := time.Now().Add(timeout); tm.Before(deadline) {
return tm
}
return deadline
}
var _ driver.Validator = (*Conn)(nil)
func (cn *Conn) IsValid() bool {
return !cn.isClosed()
}
func (cn *Conn) checkBadConn(err error) error {
if isBadConn(err, false) {
// Close and return driver.ErrBadConn next time the conn is used.
_ = cn.Close()
}
// Always return the original error.
return err
}
//------------------------------------------------------------------------------
type rows struct {
cn *Conn
rowDesc *rowDescription
reusable bool
closed bool
}
var _ driver.Rows = (*rows)(nil)
func newRows(cn *Conn, rowDesc *rowDescription, reusable bool) *rows {
return &rows{
cn: cn,
rowDesc: rowDesc,
reusable: reusable,
}
}
func (r *rows) Columns() []string {
if r.closed || r.rowDesc == nil {
return nil
}
return r.rowDesc.names
}
func (r *rows) Close() error {
if r.closed {
return nil
}
defer r.close()
for {
switch err := r.Next(nil); err {
case nil, io.EOF:
return nil
default: // unexpected error
_ = r.cn.Close()
return err
}
}
}
func (r *rows) close() {
r.closed = true
if r.rowDesc != nil {
if r.reusable {
rowDescPool.Put(r.rowDesc)
}
r.rowDesc = nil
}
}
func (r *rows) Next(dest []driver.Value) error {
if r.closed {
return io.EOF
}
eof, err := r.next(dest)
if err == io.EOF {
return io.ErrUnexpectedEOF
} else if err != nil {
return err
}
if eof {
return io.EOF
}
return nil
}
func (r *rows) next(dest []driver.Value) (eof bool, _ error) {
rd := r.cn.reader(context.TODO(), -1)
var firstErr error
for {
c, msgLen, err := readMessageType(rd)
if err != nil {
return false, err
}
switch c {
case dataRowMsg:
return false, r.readDataRow(rd, dest)
case commandCompleteMsg:
if err := rd.Discard(msgLen); err != nil {
return false, err
}
case readyForQueryMsg:
r.close()
if err := rd.Discard(msgLen); err != nil {
return false, err
}
if firstErr != nil {
return false, firstErr
}
return true, nil
case errorResponseMsg:
e, err := readError(rd)
if err != nil {
return false, err
}
if firstErr == nil {
firstErr = e
}
default:
return false, fmt.Errorf("pgdriver: Next: unexpected message %q", c)
}
}
}
func (r *rows) readDataRow(rd *reader, dest []driver.Value) error {
numCol, err := readInt16(rd)
if err != nil {
return err
}
if len(dest) != int(numCol) {
return fmt.Errorf("pgdriver: query returned %d columns, but Scan dest has %d items",
numCol, len(dest))
}
for colIdx := int16(0); colIdx < numCol; colIdx++ {
dataLen, err := readInt32(rd)
if err != nil {
return err
}
value, err := readColumnValue(rd, r.rowDesc.types[colIdx], int(dataLen))
if err != nil {
return err
}
if dest != nil {
dest[colIdx] = value
}
}
return nil
}
//------------------------------------------------------------------------------
func parseResult(b []byte) (driver.RowsAffected, error) {
i := bytes.LastIndexByte(b, ' ')
if i == -1 {
return 0, nil
}
b = b[i+1 : len(b)-1]
affected, err := strconv.ParseUint(bytesToString(b), 10, 64)
if err != nil {
return 0, nil
}
return driver.RowsAffected(affected), nil
}
//------------------------------------------------------------------------------
type tx struct {
cn *Conn
}
var _ driver.Tx = (*tx)(nil)
func (tx tx) Commit() error {
_, err := tx.cn.ExecContext(context.Background(), "COMMIT", nil)
return err
}
func (tx tx) Rollback() error {
_, err := tx.cn.ExecContext(context.Background(), "ROLLBACK", nil)
return err
}
//------------------------------------------------------------------------------
type stmt struct {
cn *Conn
name string
rowDesc *rowDescription
}
var (
_ driver.Stmt = (*stmt)(nil)
_ driver.StmtExecContext = (*stmt)(nil)
_ driver.StmtQueryContext = (*stmt)(nil)
)
func newStmt(cn *Conn, name string, rowDesc *rowDescription) *stmt {
return &stmt{
cn: cn,
name: name,
rowDesc: rowDesc,
}
}
func (stmt *stmt) Close() error {
if stmt.rowDesc != nil {
rowDescPool.Put(stmt.rowDesc)
stmt.rowDesc = nil
}
ctx := context.TODO()
if err := writeCloseStmt(ctx, stmt.cn, stmt.name); err != nil {
return err
}
if err := readCloseStmtComplete(ctx, stmt.cn); err != nil {
return err
}
return nil
}
func (stmt *stmt) NumInput() int {
if stmt.rowDesc == nil {
return -1
}
return int(stmt.rowDesc.numInput)
}
func (stmt *stmt) Exec(args []driver.Value) (driver.Result, error) {
panic("not implemented")
}
func (stmt *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if err := writeBindExecute(ctx, stmt.cn, stmt.name, args); err != nil {
return nil, err
}
return readExtQuery(ctx, stmt.cn)
}
func (stmt *stmt) Query(args []driver.Value) (driver.Rows, error) {
panic("not implemented")
}
func (stmt *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if err := writeBindExecute(ctx, stmt.cn, stmt.name, args); err != nil {
return nil, err
}
return readExtQueryData(ctx, stmt.cn, stmt.rowDesc)
}
//------------------------------------------------------------------------------
var bufioWriterPool = sync.Pool{
New: func() interface{} {
return bufio.NewWriter(nil)
},
}
func getBufioWriter() *bufio.Writer {
return bufioWriterPool.Get().(*bufio.Writer)
}
func putBufioWriter(wr *bufio.Writer) {
bufioWriterPool.Put(wr)
}

View file

@ -1,66 +0,0 @@
package pgdriver
import (
"fmt"
"net"
)
// Error represents an error returned by PostgreSQL server
// using PostgreSQL ErrorResponse protocol.
//
// https://www.postgresql.org/docs/current/static/protocol-message-formats.html
type Error struct {
m map[byte]string
}
// Field returns a string value associated with an error field.
//
// https://www.postgresql.org/docs/current/static/protocol-error-fields.html
func (err Error) Field(k byte) string {
return err.m[k]
}
// IntegrityViolation reports whether an error is a part of
// Integrity Constraint Violation class of errors.
//
// https://www.postgresql.org/docs/current/static/errcodes-appendix.html
func (err Error) IntegrityViolation() bool {
switch err.Field('C') {
case "23000", "23001", "23502", "23503", "23505", "23514", "23P01":
return true
default:
return false
}
}
func (err Error) Error() string {
return fmt.Sprintf("%s #%s %s",
err.Field('S'), err.Field('C'), err.Field('M'))
}
func isBadConn(err error, allowTimeout bool) bool {
if err == nil {
return false
}
if err, ok := err.(Error); ok {
switch err.Field('V') {
case "FATAL", "PANIC":
return true
}
switch err.Field('C') {
case "25P02", // current transaction is aborted
"57014": // canceling statement due to user request
return true
}
return false
}
if allowTimeout {
if err, ok := err.(net.Error); ok && err.Timeout() {
return !err.Temporary()
}
}
return true
}

View file

@ -1,188 +0,0 @@
package pgdriver
import (
"database/sql/driver"
"encoding/hex"
"fmt"
"math"
"strconv"
"time"
"unicode/utf8"
)
func formatQuery(query string, args []driver.NamedValue) (string, error) {
if len(args) == 0 {
return query, nil
}
dst := make([]byte, 0, 2*len(query))
p := newParser(query)
for p.Valid() {
switch c := p.Next(); c {
case '$':
if i, ok := p.Number(); ok {
if i > len(args) {
return "", fmt.Errorf("pgdriver: got %d args, wanted %d", len(args), i)
}
var err error
dst, err = appendArg(dst, args[i-1].Value)
if err != nil {
return "", err
}
} else {
dst = append(dst, '$')
}
case '\'':
if b, ok := p.QuotedString(); ok {
dst = append(dst, b...)
} else {
dst = append(dst, '\'')
}
default:
dst = append(dst, c)
}
}
return bytesToString(dst), nil
}
func appendArg(b []byte, v interface{}) ([]byte, error) {
switch v := v.(type) {
case nil:
return append(b, "NULL"...), nil
case int64:
return strconv.AppendInt(b, v, 10), nil
case float64:
switch {
case math.IsNaN(v):
return append(b, "'NaN'"...), nil
case math.IsInf(v, 1):
return append(b, "'Infinity'"...), nil
case math.IsInf(v, -1):
return append(b, "'-Infinity'"...), nil
default:
return strconv.AppendFloat(b, v, 'f', -1, 64), nil
}
case bool:
if v {
return append(b, "TRUE"...), nil
}
return append(b, "FALSE"...), nil
case []byte:
if v == nil {
return append(b, "NULL"...), nil
}
b = append(b, `'\x`...)
s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(v)))...)
hex.Encode(b[s:], v)
b = append(b, "'"...)
return b, nil
case string:
b = append(b, '\'')
for _, r := range v {
if r == '\000' {
continue
}
if r == '\'' {
b = append(b, '\'', '\'')
continue
}
if r < utf8.RuneSelf {
b = append(b, byte(r))
continue
}
l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
b = append(b, '\'')
return b, nil
case time.Time:
if v.IsZero() {
return append(b, "NULL"...), nil
}
return v.UTC().AppendFormat(b, "'2006-01-02 15:04:05.999999-07:00'"), nil
default:
return nil, fmt.Errorf("pgdriver: unexpected arg: %T", v)
}
}
type parser struct {
b []byte
i int
}
func newParser(s string) *parser {
return &parser{
b: stringToBytes(s),
}
}
func (p *parser) Valid() bool {
return p.i < len(p.b)
}
func (p *parser) Next() byte {
c := p.b[p.i]
p.i++
return c
}
func (p *parser) Number() (int, bool) {
start := p.i
end := len(p.b)
for i := p.i; i < len(p.b); i++ {
c := p.b[i]
if !isNum(c) {
end = i
break
}
}
p.i = end
b := p.b[start:end]
n, err := strconv.Atoi(bytesToString(b))
if err != nil {
return 0, false
}
return n, true
}
func (p *parser) QuotedString() ([]byte, bool) {
start := p.i - 1
end := len(p.b)
var c byte
for i := p.i; i < len(p.b); i++ {
next := p.b[i]
if c == '\'' && next != '\'' {
end = i
break
}
c = next
}
p.i = end
b := p.b[start:end]
return b, true
}
func isNum(c byte) bool {
return c >= '0' && c <= '9'
}

View file

@ -1,11 +0,0 @@
module github.com/uptrace/bun/driver/pgdriver
go 1.16
replace github.com/uptrace/bun => ../..
require (
github.com/stretchr/testify v1.7.0
github.com/uptrace/bun v0.4.3
mellium.im/sasl v0.2.1
)

View file

@ -1,27 +0,0 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/vmihailenco/msgpack/v5 v5.3.4 h1:qMKAwOV+meBw2Y8k9cVwAy7qErtYCwBzZ2ellBfvnqc=
github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b h1:2b9XGzhjiYsYPnKXoEfL7klWZQIt8IfyRCz62gCqqlQ=
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w=
mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ=

View file

@ -1,392 +0,0 @@
package pgdriver
import (
"context"
"errors"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/uptrace/bun"
)
const pingChannel = "bun:ping"
var (
errListenerClosed = errors.New("bun: listener is closed")
errPingTimeout = errors.New("bun: ping timeout")
)
type Listener struct {
db *bun.DB
driver *Connector
channels []string
mu sync.Mutex
cn *Conn
closed bool
exit chan struct{}
}
func NewListener(db *bun.DB) *Listener {
return &Listener{
db: db,
driver: db.Driver().(Driver).connector,
exit: make(chan struct{}),
}
}
// Close closes the listener, releasing any open resources.
func (ln *Listener) Close() error {
return ln.withLock(func() error {
if ln.closed {
return errListenerClosed
}
ln.closed = true
close(ln.exit)
return ln.closeConn(errListenerClosed)
})
}
func (ln *Listener) withLock(fn func() error) error {
ln.mu.Lock()
defer ln.mu.Unlock()
return fn()
}
func (ln *Listener) conn(ctx context.Context) (*Conn, error) {
if ln.closed {
return nil, errListenerClosed
}
if ln.cn != nil {
return ln.cn, nil
}
atomic.AddUint64(&ln.driver.stats.Queries, 1)
cn, err := ln._conn(ctx)
if err != nil {
atomic.AddUint64(&ln.driver.stats.Errors, 1)
return nil, err
}
ln.cn = cn
return cn, nil
}
func (ln *Listener) _conn(ctx context.Context) (*Conn, error) {
driverConn, err := ln.driver.Connect(ctx)
if err != nil {
return nil, err
}
cn := driverConn.(*Conn)
if len(ln.channels) > 0 {
err := ln.listen(ctx, cn, ln.channels...)
if err != nil {
_ = cn.Close()
return nil, err
}
}
return cn, nil
}
func (ln *Listener) checkConn(ctx context.Context, cn *Conn, err error, allowTimeout bool) {
_ = ln.withLock(func() error {
if ln.closed || ln.cn != cn {
return nil
}
if isBadConn(err, allowTimeout) {
ln.reconnect(ctx, err)
}
return nil
})
}
func (ln *Listener) reconnect(ctx context.Context, reason error) {
if ln.cn != nil {
Logger.Printf(ctx, "bun: discarding bad listener connection: %s", reason)
_ = ln.closeConn(reason)
}
_, _ = ln.conn(ctx)
}
func (ln *Listener) closeConn(reason error) error {
if ln.cn == nil {
return nil
}
err := ln.cn.Close()
ln.cn = nil
return err
}
// Listen starts listening for notifications on channels.
func (ln *Listener) Listen(ctx context.Context, channels ...string) error {
var cn *Conn
if err := ln.withLock(func() error {
ln.channels = appendIfNotExists(ln.channels, channels...)
var err error
cn, err = ln.conn(ctx)
return err
}); err != nil {
return err
}
if err := ln.listen(ctx, cn, channels...); err != nil {
ln.checkConn(ctx, cn, err, false)
return err
}
return nil
}
func (ln *Listener) listen(ctx context.Context, cn *Conn, channels ...string) error {
for _, channel := range channels {
if err := writeQuery(ctx, cn, "LISTEN "+strconv.Quote(channel)); err != nil {
return err
}
}
return nil
}
// Unlisten stops listening for notifications on channels.
func (ln *Listener) Unlisten(ctx context.Context, channels ...string) error {
var cn *Conn
if err := ln.withLock(func() error {
ln.channels = removeIfExists(ln.channels, channels...)
var err error
cn, err = ln.conn(ctx)
return err
}); err != nil {
return err
}
if err := ln.unlisten(ctx, cn, channels...); err != nil {
ln.checkConn(ctx, cn, err, false)
return err
}
return nil
}
func (ln *Listener) unlisten(ctx context.Context, cn *Conn, channels ...string) error {
for _, channel := range channels {
if err := writeQuery(ctx, cn, "UNLISTEN "+strconv.Quote(channel)); err != nil {
return err
}
}
return nil
}
// Receive indefinitely waits for a notification. This is low-level API
// and in most cases Channel should be used instead.
func (ln *Listener) Receive(ctx context.Context) (channel string, payload string, err error) {
return ln.ReceiveTimeout(ctx, 0)
}
// ReceiveTimeout waits for a notification until timeout is reached.
// This is low-level API and in most cases Channel should be used instead.
func (ln *Listener) ReceiveTimeout(
ctx context.Context, timeout time.Duration,
) (channel, payload string, err error) {
var cn *Conn
if err := ln.withLock(func() error {
var err error
cn, err = ln.conn(ctx)
return err
}); err != nil {
return "", "", err
}
rd := cn.reader(ctx, timeout)
channel, payload, err = readNotification(ctx, rd)
if err != nil {
ln.checkConn(ctx, cn, err, timeout > 0)
return "", "", err
}
return channel, payload, nil
}
// Channel returns a channel for concurrently receiving notifications.
// It periodically sends Ping notification to test connection health.
//
// The channel is closed with Listener. Receive* APIs can not be used
// after channel is created.
func (ln *Listener) Channel(opts ...ChannelOption) <-chan Notification {
return newChannel(ln, opts).ch
}
//------------------------------------------------------------------------------
// Notification received with LISTEN command.
type Notification struct {
Channel string
Payload string
}
type ChannelOption func(c *channel)
func WithChannelSize(size int) ChannelOption {
return func(c *channel) {
c.size = size
}
}
type channel struct {
ctx context.Context
ln *Listener
size int
pingTimeout time.Duration
chanSendTimeout time.Duration
ch chan Notification
pingCh chan struct{}
}
func newChannel(ln *Listener, opts []ChannelOption) *channel {
c := &channel{
ctx: context.TODO(),
ln: ln,
size: 100,
pingTimeout: 5 * time.Second,
chanSendTimeout: time.Minute,
}
for _, opt := range opts {
opt(c)
}
c.ch = make(chan Notification, c.size)
c.pingCh = make(chan struct{}, 1)
_ = c.ln.Listen(c.ctx, pingChannel)
go c.startReceive()
go c.startPing()
return c
}
func (c *channel) startReceive() {
timer := time.NewTimer(time.Minute)
timer.Stop()
var errCount int
for {
channel, payload, err := c.ln.Receive(c.ctx)
if err != nil {
if err == errListenerClosed {
close(c.ch)
return
}
if errCount > 0 {
time.Sleep(500 * time.Millisecond)
}
errCount++
continue
}
errCount = 0
// Any notification is as good as a ping.
select {
case c.pingCh <- struct{}{}:
default:
}
switch channel {
case pingChannel:
// ignore
default:
timer.Reset(c.chanSendTimeout)
select {
case c.ch <- Notification{channel, payload}:
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
Logger.Printf(
c.ctx,
"pgdriver: %s channel is full for %s (notification is dropped)",
c,
c.chanSendTimeout,
)
}
}
}
}
func (c *channel) startPing() {
timer := time.NewTimer(time.Minute)
timer.Stop()
healthy := true
for {
timer.Reset(c.pingTimeout)
select {
case <-c.pingCh:
healthy = true
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
pingErr := c.ping(c.ctx)
if healthy {
healthy = false
} else {
if pingErr == nil {
pingErr = errPingTimeout
}
_ = c.ln.withLock(func() error {
c.ln.reconnect(c.ctx, pingErr)
return nil
})
}
case <-c.ln.exit:
return
}
}
}
func (c *channel) ping(ctx context.Context) error {
_, err := c.ln.db.ExecContext(ctx, "NOTIFY "+strconv.Quote(pingChannel))
return err
}
func appendIfNotExists(ss []string, es ...string) []string {
loop:
for _, e := range es {
for _, s := range ss {
if s == e {
continue loop
}
}
ss = append(ss, e)
}
return ss
}
func removeIfExists(ss []string, es ...string) []string {
for _, e := range es {
for i, s := range ss {
if s == e {
last := len(ss) - 1
ss[i] = ss[last]
ss = ss[:last]
break
}
}
}
return ss
}

File diff suppressed because it is too large Load diff

View file

@ -1,11 +0,0 @@
// +build appengine
package internal
func bytesToString(b []byte) string {
return string(b)
}
func stringToBytes(s string) []byte {
return []byte(s)
}

View file

@ -1,19 +0,0 @@
// +build !appengine
package pgdriver
import "unsafe"
func bytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
//nolint:deadcode,unused
func stringToBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
}

View file

@ -1,112 +0,0 @@
package pgdriver
import (
"encoding/binary"
"io"
"sync"
)
var wbPool = sync.Pool{
New: func() interface{} {
return newWriteBuffer()
},
}
func getWriteBuffer() *writeBuffer {
wb := wbPool.Get().(*writeBuffer)
return wb
}
func putWriteBuffer(wb *writeBuffer) {
wb.Reset()
wbPool.Put(wb)
}
type writeBuffer struct {
Bytes []byte
msgStart int
paramStart int
}
func newWriteBuffer() *writeBuffer {
return &writeBuffer{
Bytes: make([]byte, 0, 1024),
}
}
func (b *writeBuffer) Reset() {
b.Bytes = b.Bytes[:0]
}
func (b *writeBuffer) StartMessage(c byte) {
if c == 0 {
b.msgStart = len(b.Bytes)
b.Bytes = append(b.Bytes, 0, 0, 0, 0)
} else {
b.msgStart = len(b.Bytes) + 1
b.Bytes = append(b.Bytes, c, 0, 0, 0, 0)
}
}
func (b *writeBuffer) FinishMessage() {
binary.BigEndian.PutUint32(
b.Bytes[b.msgStart:], uint32(len(b.Bytes)-b.msgStart))
}
func (b *writeBuffer) Query() []byte {
return b.Bytes[b.msgStart+4 : len(b.Bytes)-1]
}
func (b *writeBuffer) StartParam() {
b.paramStart = len(b.Bytes)
b.Bytes = append(b.Bytes, 0, 0, 0, 0)
}
func (b *writeBuffer) FinishParam() {
binary.BigEndian.PutUint32(
b.Bytes[b.paramStart:], uint32(len(b.Bytes)-b.paramStart-4))
}
var nullParamLength = int32(-1)
func (b *writeBuffer) FinishNullParam() {
binary.BigEndian.PutUint32(
b.Bytes[b.paramStart:], uint32(nullParamLength))
}
func (b *writeBuffer) Write(data []byte) (int, error) {
b.Bytes = append(b.Bytes, data...)
return len(data), nil
}
func (b *writeBuffer) WriteInt16(num int16) {
b.Bytes = append(b.Bytes, 0, 0)
binary.BigEndian.PutUint16(b.Bytes[len(b.Bytes)-2:], uint16(num))
}
func (b *writeBuffer) WriteInt32(num int32) {
b.Bytes = append(b.Bytes, 0, 0, 0, 0)
binary.BigEndian.PutUint32(b.Bytes[len(b.Bytes)-4:], uint32(num))
}
func (b *writeBuffer) WriteString(s string) {
b.Bytes = append(b.Bytes, s...)
b.Bytes = append(b.Bytes, 0)
}
func (b *writeBuffer) WriteBytes(data []byte) {
b.Bytes = append(b.Bytes, data...)
b.Bytes = append(b.Bytes, 0)
}
func (b *writeBuffer) WriteByte(c byte) error {
b.Bytes = append(b.Bytes, c)
return nil
}
func (b *writeBuffer) ReadFrom(r io.Reader) (int64, error) {
n, err := r.Read(b.Bytes[len(b.Bytes):cap(b.Bytes)])
b.Bytes = b.Bytes[:len(b.Bytes)+n]
return int64(n), err
}