[chore] update dependencies, bump to Go 1.19.1 (#826)

* update dependencies, bump Go version to 1.19

* bump test image Go version

* update golangci-lint

* update gotosocial-drone-build

* sign

* linting, go fmt

* update swagger docs

* update swagger docs

* whitespace

* update contributing.md

* fuckin whoopsie doopsie

* linterino, linteroni

* fix followrequest test not starting processor

* fix other api/client tests not starting processor

* fix remaining tests where processor not started

* bump go-runners version

* don't check last-webfingered-at, processor may have updated this

* update swagger command

* update bun to latest version

* fix embed to work the same as before with new bun

Signed-off-by: kim <grufwub@gmail.com>
Co-authored-by: tsmethurst <tobi.smethurst@protonmail.com>
kim 2022-09-28 18:30:40 +01:00 committed by GitHub
commit a156188b3e
1135 changed files with 258905 additions and 137146 deletions

View file

@ -1,3 +1,21 @@
# 1.13.0 (August 6, 2022)
* Add sslpassword support (Eric McCormack and yun.xu)
* Add prefer-standby target_session_attrs support (sergey.bashilov)
* Fix GSS ErrorResponse handling (Oliver Tan)
# 1.12.1 (May 7, 2022)
* Fix: setting krbspn and krbsrvname in connection string (sireax)
* Add support for Unix sockets on Windows (Eno Compton)
* Stop ignoring ErrorResponse during SCRAM auth (Rafi Shamim)
# 1.12.0 (April 21, 2022)
* Add pluggable GSSAPI support (Oliver Tan)
* Fix: Consider any "0A000" error a possible cached plan changed error due to locale
* Better match psql fallback behavior with multiple hosts
# 1.11.0 (February 7, 2022)
* Support port in ip from LookupFunc to override config (James Hartig)

View file

@ -78,12 +78,14 @@ func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error)
if err != nil {
return nil, err
}
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
if ok {
return saslContinue, nil
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
@ -91,12 +93,14 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
if err != nil {
return nil, err
}
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
if ok {
return saslFinal, nil
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLFinal:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
}
type scramClient struct {
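The two hunks above replace a bare type assertion with a type switch so that an ErrorResponse sent by the server during SCRAM authentication becomes a *PgError instead of a generic "unexpected message" error. A minimal sketch of what that means for a caller, assuming the usual error wrapping in pgconn (the DSN and the 28P01 check are illustrative; imports assumed: context, errors, fmt, github.com/jackc/pgconn):

func connectOrExplain(ctx context.Context) (*pgconn.PgConn, error) {
	conn, err := pgconn.Connect(ctx, "postgres://app:wrongpass@db.example.com:5432/app")
	if err != nil {
		var pgErr *pgconn.PgError
		// With the change above, a rejected password during SCRAM surfaces the
		// server's SQLSTATE instead of "received unexpected message".
		if errors.As(err, &pgErr) && pgErr.Code == "28P01" { // invalid_password
			return nil, fmt.Errorf("credentials rejected by server: %w", pgErr)
		}
		return nil, err
	}
	return conn, nil
}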

View file

@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
@ -25,6 +26,7 @@ import (
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
type GetSSLPasswordFunc func(ctx context.Context) string
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
// manually initialized Config will cause ConnectConfig to panic.
@ -41,7 +43,9 @@ type Config struct {
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Fallbacks []*FallbackConfig
KerberosSrvName string
KerberosSpn string
Fallbacks []*FallbackConfig
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
@ -61,6 +65,13 @@ type Config struct {
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
type ParseConfigOptions struct {
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
// PQsetSSLKeyPassHook_OpenSSL.
GetSSLPassword GetSSLPasswordFunc
}
// Copy returns a deep copy of the config that is safe to use and modify.
// The only exception is the TLSConfig field:
// according to the tls.Config docs it must not be modified after creation.
@ -98,10 +109,29 @@ type FallbackConfig struct {
TLSConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
func isAbsolutePath(path string) bool {
isWindowsPath := func(p string) bool {
if len(p) < 3 {
return false
}
drive := p[0]
colon := p[1]
backslash := p[2]
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
return true
}
return false
}
return strings.HasPrefix(path, "/") || isWindowsPath(path)
}
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
// net.Dial.
func NetworkAddress(host string, port uint16) (network, address string) {
if strings.HasPrefix(host, "/") {
if isAbsolutePath(host) {
network = "unix"
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
} else {
@ -111,10 +141,10 @@ func NetworkAddress(host string, port uint16) (network, address string) {
return network, address
}
// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same
// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches
// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
//
// # Example DSN
@ -138,21 +168,22 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
// via database URL or DSN:
//
// PGHOST
// PGPORT
// PGDATABASE
// PGUSER
// PGPASSWORD
// PGPASSFILE
// PGSERVICE
// PGSERVICEFILE
// PGSSLMODE
// PGSSLCERT
// PGSSLKEY
// PGSSLROOTCERT
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
// PGHOST
// PGPORT
// PGDATABASE
// PGUSER
// PGPASSWORD
// PGPASSFILE
// PGSERVICE
// PGSERVICEFILE
// PGSSLMODE
// PGSSLCERT
// PGSSLKEY
// PGSSLROOTCERT
// PGSSLPASSWORD
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
//
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
//
@ -172,23 +203,29 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
// TLCConfig.
// TLSConfig.
//
// Other known differences with libpq:
//
// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first.
//
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
// does not.
//
// In addition, ParseConfig accepts the following options:
//
// min_read_buffer_size
// The minimum size of the internal read buffer. Default 8192.
// servicefile
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
// min_read_buffer_size
// The minimum size of the internal read buffer. Default 8192.
// servicefile
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
func ParseConfig(connString string) (*Config, error) {
var parseConfigOptions ParseConfigOptions
return ParseConfigWithOptions(connString, parseConfigOptions)
}
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
// get the SSL password.
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
defaultSettings := defaultSettings()
envSettings := parseEnvSettings()
@ -259,12 +296,23 @@ func ParseConfig(connString string) (*Config, error) {
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"sslpassword": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
"min_read_buffer_size": {},
"service": {},
"servicefile": {},
}
// Adding kerberos configuration
if _, present := settings["krbsrvname"]; present {
config.KerberosSrvName = settings["krbsrvname"]
}
if _, present := settings["krbspn"]; present {
config.KerberosSpn = settings["krbspn"]
}
for k, v := range settings {
if _, present := notRuntimeParams[k]; present {
continue
@ -297,7 +345,7 @@ func ParseConfig(connString string) (*Config, error) {
tlsConfigs = append(tlsConfigs, nil)
} else {
var err error
tlsConfigs, err = configTLS(settings, host)
tlsConfigs, err = configTLS(settings, host, options)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
}
@ -338,7 +386,9 @@ func ParseConfig(connString string) (*Config, error) {
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
case "standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
case "any", "prefer-standby":
case "prefer-standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
case "any":
// do nothing
default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
@ -375,6 +425,7 @@ func parseEnvSettings() map[string]string {
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLROOTCERT": "sslrootcert",
"PGSSLPASSWORD": "sslpassword",
"PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service",
"PGSERVICEFILE": "servicefile",
@ -561,12 +612,13 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
// necessary to allow returning multiple TLS configs as sslmode "allow" and
// "prefer" allow fallback.
func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) {
func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
host := thisHost
sslmode := settings["sslmode"]
sslrootcert := settings["sslrootcert"]
sslcert := settings["sslcert"]
sslkey := settings["sslkey"]
sslpassword := settings["sslpassword"]
// Match libpq default behavior
if sslmode == "" {
@ -654,11 +706,53 @@ func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, erro
}
if sslcert != "" && sslkey != "" {
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
buf, err := ioutil.ReadFile(sslkey)
if err != nil {
return nil, fmt.Errorf("unable to read sslkey: %w", err)
}
block, _ := pem.Decode(buf)
var pemKey []byte
var decryptedKey []byte
var decryptedError error
// If PEM is encrypted, attempt to decrypt using pass phrase
if x509.IsEncryptedPEMBlock(block) {
// Attempt decryption with pass phrase
// NOTE: only supports RSA (PKCS#1)
if sslpassword != "" {
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
}
//if sslpassword not provided or has decryption error when use it
//try to find sslpassword with callback function
if sslpassword == "" || decryptedError != nil {
if parseConfigOptions.GetSSLPassword != nil {
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
}
if sslpassword == "" {
return nil, fmt.Errorf("unable to find sslpassword")
}
}
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
// Should we also provide warning for PKCS#1 needed?
if decryptedError != nil {
return nil, fmt.Errorf("unable to decrypt key: %w", err)
}
pemBytes := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: decryptedKey,
}
pemKey = pem.EncodeToMemory(&pemBytes)
} else {
pemKey = pem.EncodeToMemory(block)
}
certfile, err := ioutil.ReadFile(sslcert)
if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err)
}
cert, err := tls.X509KeyPair(certfile, pemKey)
if err != nil {
return nil, fmt.Errorf("unable to load cert: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
@ -781,3 +875,18 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon
return nil
}
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=prefer-standby.
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "t" {
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
}
return nil
}
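The config.go changes add ParseConfigOptions so the pass phrase for an encrypted client key can come from a callback rather than the sslpassword parameter or PGSSLPASSWORD. A hedged sketch of the intended wiring (paths, host, and the environment variable are illustrative; imports assumed: context, os, github.com/jackc/pgconn):

func connectWithClientCert(ctx context.Context) (*pgconn.PgConn, error) {
	opts := pgconn.ParseConfigOptions{
		// Called only if the key turns out to be encrypted and no usable
		// sslpassword was supplied in the connection string or environment.
		GetSSLPassword: func(ctx context.Context) string { return os.Getenv("CLIENT_KEY_PASSPHRASE") },
	}
	cfg, err := pgconn.ParseConfigWithOptions(
		"host=db.example.com user=app dbname=app sslmode=require"+
			" sslcert=/etc/app/client.crt sslkey=/etc/app/client.key",
		opts,
	)
	if err != nil {
		return nil, err
	}
	return pgconn.ConnectConfig(ctx, cfg)
}

ConnectWithOptions (added in pgconn.go below) wraps the same two steps into a single call.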

View file

@ -1,3 +1,4 @@
//go:build !windows
// +build !windows
package pgconn

View file

@ -219,3 +219,20 @@ func redactURL(u *url.URL) string {
}
return u.String()
}
type NotPreferredError struct {
err error
safeToRetry bool
}
func (e *NotPreferredError) Error() string {
return fmt.Sprintf("standby server not found: %s", e.err.Error())
}
func (e *NotPreferredError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *NotPreferredError) Unwrap() error {
return e.err
}
NotPreferredError signals that a connection succeeded but the server is not the preferred kind (for example, not a hot standby when target_session_attrs=prefer-standby is requested); the connect loop in pgconn.go uses it to remember a usable fallback host instead of failing outright.

vendor/github.com/jackc/pgconn/krb5.go (generated, vendored, new file, 99 lines)
View file

@ -0,0 +1,99 @@
package pgconn
import (
"errors"
"fmt"
"github.com/jackc/pgproto3/v2"
)
// NewGSSFunc creates a GSS authentication provider, for use with
// RegisterGSSProvider.
type NewGSSFunc func() (GSS, error)
var newGSS NewGSSFunc
// RegisterGSSProvider registers a GSS authentication provider. For example, if
// you need to use Kerberos to authenticate with your server, add this to your
// main package:
//
// import "github.com/otan/gopgkrb5"
//
// func init() {
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
// }
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
newGSS = newGSSArg
}
// GSS provides GSSAPI authentication (e.g., Kerberos).
type GSS interface {
GetInitToken(host string, service string) ([]byte, error)
GetInitTokenFromSPN(spn string) ([]byte, error)
Continue(inToken []byte) (done bool, outToken []byte, err error)
}
func (c *PgConn) gssAuth() error {
if newGSS == nil {
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
}
cli, err := newGSS()
if err != nil {
return err
}
var nextData []byte
if c.config.KerberosSpn != "" {
// Use the supplied SPN if provided.
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
} else {
// Allow the kerberos service name to be overridden
service := "postgres"
if c.config.KerberosSrvName != "" {
service = c.config.KerberosSrvName
}
nextData, err = cli.GetInitToken(c.config.Host, service)
}
if err != nil {
return err
}
for {
gssResponse := &pgproto3.GSSResponse{
Data: nextData,
}
_, err = c.conn.Write(gssResponse.Encode(nil))
if err != nil {
return err
}
resp, err := c.rxGSSContinue()
if err != nil {
return err
}
var done bool
done, nextData, err = cli.Continue(resp.Data)
if err != nil {
return err
}
if done {
break
}
}
return nil
}
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationGSSContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
}
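The new krb5.go wires GSSAPI into the startup flow but deliberately ships no Kerberos implementation; a provider has to be registered once per process. A hedged sketch of the intended use, following the gopgkrb5 package named in the doc comment above (package name, host, user, and service name are illustrative):

package dbclient

import (
	"context"

	"github.com/jackc/pgconn"
	"github.com/otan/gopgkrb5"
)

// Register the GSS provider once per process, as the doc comment above suggests.
func init() {
	pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
}

// connectWithKerberos is a sketch: no password is supplied, the Kerberos ticket
// cache handles authentication.
func connectWithKerberos(ctx context.Context) (*pgconn.PgConn, error) {
	// krbsrvname overrides the default "postgres" service name; krbspn would set the full SPN instead.
	return pgconn.Connect(ctx, "host=kdc-db.example.com user=alice dbname=app krbsrvname=postgres")
}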

View file

@ -99,7 +99,7 @@ type PgConn struct {
}
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt.
// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt.
func Connect(ctx context.Context, connString string) (*PgConn, error) {
config, err := ParseConfig(connString)
if err != nil {
@ -109,6 +109,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
return ConnectConfig(ctx, config)
}
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
// and ParseConfigOptions to provide additional configuration. See documentation for ParseConfig for details. ctx can be
// used to cancel a connect attempt.
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
config, err := ParseConfigWithOptions(connString, parseConfigOptions)
if err != nil {
return nil, err
}
return ConnectConfig(ctx, config)
}
// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
// ParseConfig. ctx can be used to cancel a connect attempt.
//
@ -148,17 +160,36 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
}
foundBestServer := false
var fallbackConfig *FallbackConfig
for _, fc := range fallbackConfigs {
pgConn, err = connect(ctx, config, fc)
pgConn, err = connect(ctx, config, fc, false)
if err == nil {
foundBestServer = true
break
} else if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
ERRCODE_INVALID_PASSWORD := "28P01" // worng password
ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION := "28000" // db does not exist
if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION {
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
if pgerr.Code == ERRCODE_INVALID_PASSWORD ||
pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION ||
pgerr.Code == ERRCODE_INVALID_CATALOG_NAME ||
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break
}
} else if cerr, ok := err.(*connectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc
}
}
}
if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
}
}
@ -182,7 +213,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
for _, fb := range fallbacks {
// skip resolve for unix sockets
if strings.HasPrefix(fb.Host, "/") {
if isAbsolutePath(fb.Host) {
configs = append(configs, &FallbackConfig{
Host: fb.Host,
Port: fb.Port,
@ -222,7 +253,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
return configs, nil
}
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
ignoreNotPreferredErr bool) (*PgConn, error) {
pgConn := new(PgConn)
pgConn.config = config
pgConn.wbuf = make([]byte, 0, wbufLen)
@ -317,7 +349,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
}
case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth()
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed GSS auth", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
if config.ValidateConnect != nil {
@ -330,6 +367,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err := config.ValidateConnect(ctx, pgConn)
if err != nil {
if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok {
return pgConn, nil
}
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
}
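The connect loop above records the host that produced a NotPreferredError and retries it with ignoreNotPreferredErr once every other candidate has been tried, which is what makes target_session_attrs=prefer-standby match libpq. A small sketch (hosts are illustrative; import assumed: github.com/jackc/pgconn):

// Each host is tried in order; a host that is not in recovery fails the validate
// step with NotPreferredError, and the loop falls back to it only if no standby
// is found anywhere.
func connectPreferStandby(ctx context.Context) (*pgconn.PgConn, error) {
	return pgconn.Connect(ctx,
		"postgres://app@replica1.example.com:5432,replica2.example.com:5432/app"+
			"?target_session_attrs=prefer-standby")
}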

View file

@ -102,10 +102,14 @@ func (c *LRU) StatementErrored(sql string, err error) {
return
}
isInvalidCachedPlanError := pgErr.Severity == "ERROR" &&
pgErr.Code == "0A000" &&
pgErr.Message == "cached plan must not change result type"
if isInvalidCachedPlanError {
// https://github.com/jackc/pgx/issues/1162
//
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
// have so it should be safe.
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
if possibleInvalidCachedPlanError {
c.stmtsToClear = append(c.stmtsToClear, sql)
}
}

View file

@ -0,0 +1,58 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type AuthenticationGSS struct{}
func (a *AuthenticationGSS) Backend() {}
func (a *AuthenticationGSS) AuthenticationResponse() {}
func (a *AuthenticationGSS) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSS {
return errors.New("bad auth type")
}
return nil
}
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst
}
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSS",
})
}
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,67 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type AuthenticationGSSContinue struct {
Data []byte
}
func (a *AuthenticationGSSContinue) Backend() {}
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSSCont {
return errors.New("bad auth type")
}
a.Data = src[4:]
return nil
}
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return dst
}
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSSContinue",
Data: a.Data,
})
}
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
a.Data = msg.Data
return nil
}

View file

@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
@ -30,11 +31,10 @@ type Backend struct {
sync Sync
terminate Terminate
bodyLen int
msgType byte
partialMsg bool
authType uint32
bodyLen int
msgType byte
partialMsg bool
authType uint32
}
const (
@ -115,6 +115,9 @@ func (b *Backend) Receive() (FrontendMessage, error) {
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true
if b.bodyLen < 0 {
return nil, errors.New("invalid message with negative body length received")
}
}
var msg FrontendMessage
@ -147,6 +150,8 @@ func (b *Backend) Receive() (FrontendMessage, error) {
msg = &SASLResponse{}
case AuthTypeSASLFinal:
msg = &SASLResponse{}
case AuthTypeGSS, AuthTypeGSSCont:
msg = &GSSResponse{}
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:

View file

@ -48,7 +48,7 @@ func (src *CopyBothResponse) Encode(dst []byte) []byte {
dst = append(dst, 'W')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)

View file

@ -16,6 +16,8 @@ type Frontend struct {
authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword
authenticationMD5Password AuthenticationMD5Password
authenticationGSS AuthenticationGSS
authenticationGSSContinue AuthenticationGSSContinue
authenticationSASL AuthenticationSASL
authenticationSASLContinue AuthenticationSASLContinue
authenticationSASLFinal AuthenticationSASLFinal
@ -77,6 +79,9 @@ func (f *Frontend) Receive() (BackendMessage, error) {
f.msgType = header[0]
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
f.partialMsg = true
if f.bodyLen < 0 {
return nil, errors.New("invalid message with negative body length received")
}
}
msgBody, err := f.cr.Next(f.bodyLen)
@ -178,9 +183,9 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
case AuthTypeSCMCreds:
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
case AuthTypeGSS:
return nil, errors.New("AuthTypeGSS is unimplemented")
return &f.authenticationGSS, nil
case AuthTypeGSSCont:
return nil, errors.New("AuthTypeGSSCont is unimplemented")
return &f.authenticationGSSContinue, nil
case AuthTypeSSPI:
return nil, errors.New("AuthTypeSSPI is unimplemented")
case AuthTypeSASL:

vendor/github.com/jackc/pgproto3/v2/gss_response.go (generated, vendored, new file, 48 lines)
View file

@ -0,0 +1,48 @@
package pgproto3
import (
"encoding/json"
"github.com/jackc/pgio"
)
type GSSResponse struct {
Data []byte
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (g *GSSResponse) Frontend() {}
func (g *GSSResponse) Decode(data []byte) error {
g.Data = data
return nil
}
func (g *GSSResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
dst = append(dst, g.Data...)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "GSSResponse",
Data: g.Data,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
var msg struct {
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
g.Data = msg.Data
return nil
}

View file

@ -1,3 +1,24 @@
# 1.12.0 (August 6, 2022)
* Add JSONArray (Jakob Ackermann)
* Support Inet from fmt.Stringer and encoding.TextMarshaler (Ville Skyttä)
* Support UUID from fmt.Stringer interface (Lasse Hyldahl Jensen)
* Fix: shopspring-numeric extension does not panic on NaN
* Numeric can be assigned to string
* Fix: Do not send IPv4 networks as IPv4-mapped IPv6 (William Storey)
* Fix: PlanScan for interface{}(nil) (James Hartig)
* Fix: *sql.Scanner for NULL handling (James Hartig)
* Timestamp[tz].Set() supports string (Harmen)
* Fix: Hstore AssignTo with map of *string (Diego Becciolini)
# 1.11.0 (April 21, 2022)
* Add multirange for numeric, int4, and int8 (Vu)
* JSONBArray now supports json.RawMessage (Jens Emil Schulz Østergaard)
* Add RecordArray (WGH)
* Add UnmarshalJSON to pgtype.Int2
* Hstore.Set accepts map[string]Text
# 1.10.0 (February 7, 2022)
* Normalize UTC timestamps to comply with stdlib (Torkel Rogstad)

View file

@ -11,7 +11,7 @@ import (
// ArrayType represents an array type. While it implements Value, this is only in service of its type conversion duties
// when registered as a data type in a ConnType. It should not be used directly as a Value. ArrayType is a convenience
// type for types that do not have an concrete array type.
// type for types that do not have a concrete array type.
type ArrayType struct {
elements []ValueTranscoder
dimensions []ArrayDimension

View file

@ -37,14 +37,14 @@ func (dst *Date) Set(src interface{}) error {
switch value := src.(type) {
case time.Time:
*dst = Date{Time: value, Status: Present}
case string:
return dst.DecodeText(nil, []byte(value))
case *time.Time:
if value == nil {
*dst = Date{Status: Null}
} else {
return dst.Set(*value)
}
case string:
return dst.DecodeText(nil, []byte(value))
case *string:
if value == nil {
*dst = Date{Status: Null}

View file

@ -2,7 +2,7 @@ package pgtype
import "fmt"
// EnumType represents a enum type. While it implements Value, this is only in service of its type conversion duties
// EnumType represents an enum type. While it implements Value, this is only in service of its type conversion duties
// when registered as a data type in a ConnType. It should not be used directly as a Value.
type EnumType struct {
value string

View file

@ -50,6 +50,8 @@ func (dst *Hstore) Set(src interface{}) error {
}
}
*dst = Hstore{Map: m, Status: Present}
case map[string]Text:
*dst = Hstore{Map: value, Status: Present}
default:
return fmt.Errorf("cannot convert %v to Hstore", src)
}
@ -88,7 +90,8 @@ func (src *Hstore) AssignTo(dst interface{}) error {
case Null:
(*v)[k] = nil
case Present:
(*v)[k] = &val.String
str := val.String
(*v)[k] = &str
default:
return fmt.Errorf("cannot decode %#v into %T", src, dst)
}
@ -411,7 +414,7 @@ func parseHstore(s string) (k []string, v []Text, err error) {
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after ',', expcting space")
err = errors.New("Found EOS after ',', expecting space")
case (unicode.IsSpace(r)):
r, end = p.Consume()
state = hsKey
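Two behavior changes land in hstore.go: Set now accepts map[string]Text directly, and AssignTo used to take the address of the range variable's String field, so every key in the destination map could end up pointing at the same backing string; the fix copies each value first. A short sketch (errors elided for brevity; import assumed: github.com/jackc/pgtype):

var h pgtype.Hstore
_ = h.Set(map[string]pgtype.Text{
	"colour": {String: "blue", Status: pgtype.Present},
	"size":   {Status: pgtype.Null}, // NULL hstore value
})

out := map[string]*string{}
_ = h.AssignTo(&out) // "colour" now gets its own *string; "size" maps to nil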

View file

@ -2,8 +2,10 @@ package pgtype
import (
"database/sql/driver"
"encoding"
"fmt"
"net"
"strings"
)
// Network address family is dependent on server socket.h value for AF_INET.
@ -47,17 +49,26 @@ func (dst *Inet) Set(src interface{}) error {
case string:
ip, ipnet, err := net.ParseCIDR(value)
if err != nil {
ip = net.ParseIP(value)
ip := net.ParseIP(value)
if ip == nil {
return fmt.Errorf("unable to parse inet address: %s", value)
}
ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
if ipv4 := ip.To4(); ipv4 != nil {
ip = ipv4
ipnet.Mask = net.CIDRMask(32, 32)
if ipv4 := maybeGetIPv4(value, ip); ipv4 != nil {
ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)}
} else {
ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
}
} else {
ipnet.IP = ip
if ipv4 := maybeGetIPv4(value, ipnet.IP); ipv4 != nil {
ipnet.IP = ipv4
if len(ipnet.Mask) == 16 {
ipnet.Mask = ipnet.Mask[12:] // Not sure this is ever needed.
}
}
}
ipnet.IP = ip
*dst = Inet{IPNet: ipnet, Status: Present}
case *net.IPNet:
if value == nil {
@ -78,6 +89,16 @@ func (dst *Inet) Set(src interface{}) error {
return dst.Set(*value)
}
default:
if tv, ok := src.(encoding.TextMarshaler); ok {
text, err := tv.MarshalText()
if err != nil {
return fmt.Errorf("cannot marshal %v: %w", value, err)
}
return dst.Set(string(text))
}
if sv, ok := src.(fmt.Stringer); ok {
return dst.Set(sv.String())
}
if originalSrc, ok := underlyingPtrType(src); ok {
return dst.Set(originalSrc)
}
@ -87,6 +108,25 @@ func (dst *Inet) Set(src interface{}) error {
return nil
}
// Convert the net.IP to IPv4, if appropriate.
//
// When parsing a string to a net.IP using net.ParseIP() and the like, we get a
// 16 byte slice for IPv4 addresses as well as IPv6 addresses. This function
// calls To4() to convert them to a 4 byte slice. This is useful as it allows
// users of the net.IP check for IPv4 addresses based on the length and makes
// it clear we are handling IPv4 as opposed to IPv6 or IPv4-mapped IPv6
// addresses.
func maybeGetIPv4(input string, ip net.IP) net.IP {
// Do not do this if the provided input looks like IPv6. This is because
// To4() on IPv4-mapped IPv6 addresses converts them to IPv4, which behave
// different in some cases.
if strings.Contains(input, ":") {
return nil
}
return ip.To4()
}
func (dst Inet) Get() interface{} {
switch dst.Status {
case Present:
@ -118,6 +158,12 @@ func (src *Inet) AssignTo(dst interface{}) error {
copy(*v, src.IPNet.IP)
return nil
default:
if tv, ok := dst.(encoding.TextUnmarshaler); ok {
if err := tv.UnmarshalText([]byte(src.IPNet.String())); err != nil {
return fmt.Errorf("cannot unmarshal %v to %T: %w", src, dst, err)
}
return nil
}
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
@ -169,7 +215,7 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error {
}
if len(src) != 8 && len(src) != 20 {
return fmt.Errorf("Received an invalid size for a inet: %d", len(src))
return fmt.Errorf("Received an invalid size for an inet: %d", len(src))
}
// ignore family
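inet.go gains two things: Set falls back to encoding.TextMarshaler and fmt.Stringer values, and maybeGetIPv4 keeps plain IPv4 input as a 4-byte address so it is not sent as an IPv4-mapped IPv6 value. A sketch (net/netip is used only as a convenient TextMarshaler; imports assumed: net/netip, github.com/jackc/pgtype):

var a, b pgtype.Inet
_ = a.Set("10.0.0.1")                         // kept as a 4-byte /32, not ::ffff:10.0.0.1
_ = b.Set(netip.MustParseAddr("2001:db8::1")) // any TextMarshaler / Stringer is accepted now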

View file

@ -3,6 +3,7 @@ package pgtype
import (
"database/sql/driver"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"strconv"
@ -302,3 +303,19 @@ func (src Int2) MarshalJSON() ([]byte, error) {
return nil, errBadStatus
}
func (dst *Int2) UnmarshalJSON(b []byte) error {
var n *int16
err := json.Unmarshal(b, &n)
if err != nil {
return err
}
if n == nil {
*dst = Int2{Status: Null}
} else {
*dst = Int2{Int: *n, Status: Present}
}
return nil
}
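With UnmarshalJSON added, Int2 now round-trips through encoding/json, with JSON null mapping to Status: Null. A sketch (imports assumed: encoding/json, github.com/jackc/pgtype):

var n pgtype.Int2
_ = json.Unmarshal([]byte(`7`), &n)    // n.Int == 7, n.Status == pgtype.Present
_ = json.Unmarshal([]byte(`null`), &n) // n.Status == pgtype.Null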

vendor/github.com/jackc/pgtype/int4_multirange.go (generated, vendored, new file, 239 lines)
View file

@ -0,0 +1,239 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"github.com/jackc/pgio"
)
type Int4multirange struct {
Ranges []Int4range
Status Status
}
func (dst *Int4multirange) Set(src interface{}) error {
//untyped nil and typed nil interfaces are different
if src == nil {
*dst = Int4multirange{Status: Null}
return nil
}
switch value := src.(type) {
case Int4multirange:
*dst = value
case *Int4multirange:
*dst = *value
case string:
return dst.DecodeText(nil, []byte(value))
case []Int4range:
if value == nil {
*dst = Int4multirange{Status: Null}
} else if len(value) == 0 {
*dst = Int4multirange{Status: Present}
} else {
elements := make([]Int4range, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int4multirange{
Ranges: elements,
Status: Present,
}
}
case []*Int4range:
if value == nil {
*dst = Int4multirange{Status: Null}
} else if len(value) == 0 {
*dst = Int4multirange{Status: Present}
} else {
elements := make([]Int4range, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int4multirange{
Ranges: elements,
Status: Present,
}
}
default:
return fmt.Errorf("cannot convert %v to Int4multirange", src)
}
return nil
}
func (dst Int4multirange) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *Int4multirange) AssignTo(dst interface{}) error {
return fmt.Errorf("cannot assign %v to %T", src, dst)
}
func (dst *Int4multirange) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Int4multirange{Status: Null}
return nil
}
utmr, err := ParseUntypedTextMultirange(string(src))
if err != nil {
return err
}
var elements []Int4range
if len(utmr.Elements) > 0 {
elements = make([]Int4range, len(utmr.Elements))
for i, s := range utmr.Elements {
var elem Int4range
elemSrc := []byte(s)
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}
elements[i] = elem
}
}
*dst = Int4multirange{Ranges: elements, Status: Present}
return nil
}
func (dst *Int4multirange) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Int4multirange{Status: Null}
return nil
}
rp := 0
numElems := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if numElems == 0 {
*dst = Int4multirange{Status: Present}
return nil
}
elements := make([]Int4range, numElems)
for i := range elements {
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var elemSrc []byte
if elemLen >= 0 {
elemSrc = src[rp : rp+elemLen]
rp += elemLen
}
err := elements[i].DecodeBinary(ci, elemSrc)
if err != nil {
return err
}
}
*dst = Int4multirange{Ranges: elements, Status: Present}
return nil
}
func (src Int4multirange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = append(buf, '{')
inElemBuf := make([]byte, 0, 32)
for i, elem := range src.Ranges {
if i > 0 {
buf = append(buf, ',')
}
elemBuf, err := elem.EncodeText(ci, inElemBuf)
if err != nil {
return nil, err
}
if elemBuf == nil {
return nil, fmt.Errorf("multi-range does not allow null range")
} else {
buf = append(buf, string(elemBuf)...)
}
}
buf = append(buf, '}')
return buf, nil
}
func (src Int4multirange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = pgio.AppendInt32(buf, int32(len(src.Ranges)))
for i := range src.Ranges {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.Ranges[i].EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if elemBuf != nil {
buf = elemBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Int4multirange) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Int4multirange) Value() (driver.Value, error) {
return EncodeValueText(src)
}
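Int4multirange follows the usual pgtype pattern (Set, Get, AssignTo plus text and binary codecs), so an int4multirange column can be scanned or bound like any other value. A sketch with an illustrative literal (imports assumed: fmt, log, github.com/jackc/pgtype):

var mr pgtype.Int4multirange
if err := mr.Scan("{[1,5),[10,20)}"); err != nil {
	log.Fatal(err)
}
// mr.Ranges now holds two Int4range values; Value() renders an equivalent literal.
v, _ := mr.Value()
fmt.Println(v)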

vendor/github.com/jackc/pgtype/int8_multirange.go (generated, vendored, new file, 239 lines)
View file

@ -0,0 +1,239 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"github.com/jackc/pgio"
)
type Int8multirange struct {
Ranges []Int8range
Status Status
}
func (dst *Int8multirange) Set(src interface{}) error {
//untyped nil and typed nil interfaces are different
if src == nil {
*dst = Int8multirange{Status: Null}
return nil
}
switch value := src.(type) {
case Int8multirange:
*dst = value
case *Int8multirange:
*dst = *value
case string:
return dst.DecodeText(nil, []byte(value))
case []Int8range:
if value == nil {
*dst = Int8multirange{Status: Null}
} else if len(value) == 0 {
*dst = Int8multirange{Status: Present}
} else {
elements := make([]Int8range, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int8multirange{
Ranges: elements,
Status: Present,
}
}
case []*Int8range:
if value == nil {
*dst = Int8multirange{Status: Null}
} else if len(value) == 0 {
*dst = Int8multirange{Status: Present}
} else {
elements := make([]Int8range, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Int8multirange{
Ranges: elements,
Status: Present,
}
}
default:
return fmt.Errorf("cannot convert %v to Int8multirange", src)
}
return nil
}
func (dst Int8multirange) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *Int8multirange) AssignTo(dst interface{}) error {
return fmt.Errorf("cannot assign %v to %T", src, dst)
}
func (dst *Int8multirange) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Int8multirange{Status: Null}
return nil
}
utmr, err := ParseUntypedTextMultirange(string(src))
if err != nil {
return err
}
var elements []Int8range
if len(utmr.Elements) > 0 {
elements = make([]Int8range, len(utmr.Elements))
for i, s := range utmr.Elements {
var elem Int8range
elemSrc := []byte(s)
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}
elements[i] = elem
}
}
*dst = Int8multirange{Ranges: elements, Status: Present}
return nil
}
func (dst *Int8multirange) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Int8multirange{Status: Null}
return nil
}
rp := 0
numElems := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if numElems == 0 {
*dst = Int8multirange{Status: Present}
return nil
}
elements := make([]Int8range, numElems)
for i := range elements {
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var elemSrc []byte
if elemLen >= 0 {
elemSrc = src[rp : rp+elemLen]
rp += elemLen
}
err := elements[i].DecodeBinary(ci, elemSrc)
if err != nil {
return err
}
}
*dst = Int8multirange{Ranges: elements, Status: Present}
return nil
}
func (src Int8multirange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = append(buf, '{')
inElemBuf := make([]byte, 0, 32)
for i, elem := range src.Ranges {
if i > 0 {
buf = append(buf, ',')
}
elemBuf, err := elem.EncodeText(ci, inElemBuf)
if err != nil {
return nil, err
}
if elemBuf == nil {
return nil, fmt.Errorf("multi-range does not allow null range")
} else {
buf = append(buf, string(elemBuf)...)
}
}
buf = append(buf, '}')
return buf, nil
}
func (src Int8multirange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = pgio.AppendInt32(buf, int32(len(src.Ranges)))
for i := range src.Ranges {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.Ranges[i].EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if elemBuf != nil {
buf = elemBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Int8multirange) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Int8multirange) Value() (driver.Value, error) {
return EncodeValueText(src)
}

View file

@ -174,7 +174,7 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error {
}
if len(src) != 16 {
return fmt.Errorf("Received an invalid size for a interval: %d", len(src))
return fmt.Errorf("Received an invalid size for an interval: %d", len(src))
}
microseconds := int64(binary.BigEndian.Uint64(src))

vendor/github.com/jackc/pgtype/json_array.go (generated, vendored, new file, 546 lines)
View file

@ -0,0 +1,546 @@
// Code generated by erb. DO NOT EDIT.
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"encoding/json"
"fmt"
"reflect"
"github.com/jackc/pgio"
)
type JSONArray struct {
Elements []JSON
Dimensions []ArrayDimension
Status Status
}
func (dst *JSONArray) Set(src interface{}) error {
// untyped nil and typed nil interfaces are different
if src == nil {
*dst = JSONArray{Status: Null}
return nil
}
if value, ok := src.(interface{ Get() interface{} }); ok {
value2 := value.Get()
if value2 != value {
return dst.Set(value2)
}
}
// Attempt to match to select common types:
switch value := src.(type) {
case []string:
if value == nil {
*dst = JSONArray{Status: Null}
} else if len(value) == 0 {
*dst = JSONArray{Status: Present}
} else {
elements := make([]JSON, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = JSONArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
case [][]byte:
if value == nil {
*dst = JSONArray{Status: Null}
} else if len(value) == 0 {
*dst = JSONArray{Status: Present}
} else {
elements := make([]JSON, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = JSONArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
case []json.RawMessage:
if value == nil {
*dst = JSONArray{Status: Null}
} else if len(value) == 0 {
*dst = JSONArray{Status: Present}
} else {
elements := make([]JSON, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = JSONArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
case []JSON:
if value == nil {
*dst = JSONArray{Status: Null}
} else if len(value) == 0 {
*dst = JSONArray{Status: Present}
} else {
*dst = JSONArray{
Elements: value,
Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}},
Status: Present,
}
}
default:
// Fallback to reflection if an optimised match was not found.
// The reflection is necessary for arrays and multidimensional slices,
// but it comes with a 20-50% performance penalty for large arrays/slices
reflectedValue := reflect.ValueOf(src)
if !reflectedValue.IsValid() || reflectedValue.IsZero() {
*dst = JSONArray{Status: Null}
return nil
}
dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0)
if !ok {
return fmt.Errorf("cannot find dimensions of %v for JSONArray", src)
}
if elementsLength == 0 {
*dst = JSONArray{Status: Present}
return nil
}
if len(dimensions) == 0 {
if originalSrc, ok := underlyingSliceType(src); ok {
return dst.Set(originalSrc)
}
return fmt.Errorf("cannot convert %v to JSONArray", src)
}
*dst = JSONArray{
Elements: make([]JSON, elementsLength),
Dimensions: dimensions,
Status: Present,
}
elementCount, err := dst.setRecursive(reflectedValue, 0, 0)
if err != nil {
// Maybe the target was one dimension too far, try again:
if len(dst.Dimensions) > 1 {
dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1]
elementsLength = 0
for _, dim := range dst.Dimensions {
if elementsLength == 0 {
elementsLength = int(dim.Length)
} else {
elementsLength *= int(dim.Length)
}
}
dst.Elements = make([]JSON, elementsLength)
elementCount, err = dst.setRecursive(reflectedValue, 0, 0)
if err != nil {
return err
}
} else {
return err
}
}
if elementCount != len(dst.Elements) {
return fmt.Errorf("cannot convert %v to JSONArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount)
}
}
return nil
}
func (dst *JSONArray) setRecursive(value reflect.Value, index, dimension int) (int, error) {
switch value.Kind() {
case reflect.Array:
fallthrough
case reflect.Slice:
if len(dst.Dimensions) == dimension {
break
}
valueLen := value.Len()
if int32(valueLen) != dst.Dimensions[dimension].Length {
return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions")
}
for i := 0; i < valueLen; i++ {
var err error
index, err = dst.setRecursive(value.Index(i), index, dimension+1)
if err != nil {
return 0, err
}
}
return index, nil
}
if !value.CanInterface() {
return 0, fmt.Errorf("cannot convert all values to JSONArray")
}
if err := dst.Elements[index].Set(value.Interface()); err != nil {
return 0, fmt.Errorf("%v in JSONArray", err)
}
index++
return index, nil
}
func (dst JSONArray) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *JSONArray) AssignTo(dst interface{}) error {
switch src.Status {
case Present:
if len(src.Dimensions) <= 1 {
// Attempt to match to select common types:
switch v := dst.(type) {
case *[]string:
*v = make([]string, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[][]byte:
*v = make([][]byte, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
case *[]json.RawMessage:
*v = make([]json.RawMessage, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
}
}
// Try to convert to something AssignTo can use directly.
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
// Fallback to reflection if an optimised match was not found.
// The reflection is necessary for arrays and multidimensional slices,
// but it comes with a 20-50% performance penalty for large arrays/slices
value := reflect.ValueOf(dst)
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
switch value.Kind() {
case reflect.Array, reflect.Slice:
default:
return fmt.Errorf("cannot assign %T to %T", src, dst)
}
if len(src.Elements) == 0 {
if value.Kind() == reflect.Slice {
value.Set(reflect.MakeSlice(value.Type(), 0, 0))
return nil
}
}
elementCount, err := src.assignToRecursive(value, 0, 0)
if err != nil {
return err
}
if elementCount != len(src.Elements) {
return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount)
}
return nil
case Null:
return NullAssignTo(dst)
}
return fmt.Errorf("cannot decode %#v into %T", src, dst)
}
func (src *JSONArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) {
switch kind := value.Kind(); kind {
case reflect.Array:
fallthrough
case reflect.Slice:
if len(src.Dimensions) == dimension {
break
}
length := int(src.Dimensions[dimension].Length)
if reflect.Array == kind {
typ := value.Type()
if typ.Len() != length {
return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len())
}
value.Set(reflect.New(typ).Elem())
} else {
value.Set(reflect.MakeSlice(value.Type(), length, length))
}
var err error
for i := 0; i < length; i++ {
index, err = src.assignToRecursive(value.Index(i), index, dimension+1)
if err != nil {
return 0, err
}
}
return index, nil
}
if len(src.Dimensions) != dimension {
return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension)
}
if !value.CanAddr() {
return 0, fmt.Errorf("cannot assign all values from JSONArray")
}
addr := value.Addr()
if !addr.CanInterface() {
return 0, fmt.Errorf("cannot assign all values from JSONArray")
}
if err := src.Elements[index].AssignTo(addr.Interface()); err != nil {
return 0, err
}
index++
return index, nil
}
func (dst *JSONArray) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = JSONArray{Status: Null}
return nil
}
uta, err := ParseUntypedTextArray(string(src))
if err != nil {
return err
}
var elements []JSON
if len(uta.Elements) > 0 {
elements = make([]JSON, len(uta.Elements))
for i, s := range uta.Elements {
var elem JSON
var elemSrc []byte
if s != "NULL" || uta.Quoted[i] {
elemSrc = []byte(s)
}
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}
elements[i] = elem
}
}
*dst = JSONArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
return nil
}
func (dst *JSONArray) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = JSONArray{Status: Null}
return nil
}
var arrayHeader ArrayHeader
rp, err := arrayHeader.DecodeBinary(ci, src)
if err != nil {
return err
}
if len(arrayHeader.Dimensions) == 0 {
*dst = JSONArray{Dimensions: arrayHeader.Dimensions, Status: Present}
return nil
}
elementCount := arrayHeader.Dimensions[0].Length
for _, d := range arrayHeader.Dimensions[1:] {
elementCount *= d.Length
}
elements := make([]JSON, elementCount)
for i := range elements {
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var elemSrc []byte
if elemLen >= 0 {
elemSrc = src[rp : rp+elemLen]
rp += elemLen
}
err = elements[i].DecodeBinary(ci, elemSrc)
if err != nil {
return err
}
}
*dst = JSONArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
return nil
}
func (src JSONArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
if len(src.Dimensions) == 0 {
return append(buf, '{', '}'), nil
}
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
// dimElemCounts is the multiples of elements that each array lies on. For
// example, a single dimension array of length 4 would have a dimElemCounts of
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
// or '}'.
dimElemCounts := make([]int, len(src.Dimensions))
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
for i := len(src.Dimensions) - 2; i > -1; i-- {
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
}
inElemBuf := make([]byte, 0, 32)
for i, elem := range src.Elements {
if i > 0 {
buf = append(buf, ',')
}
for _, dec := range dimElemCounts {
if i%dec == 0 {
buf = append(buf, '{')
}
}
elemBuf, err := elem.EncodeText(ci, inElemBuf)
if err != nil {
return nil, err
}
if elemBuf == nil {
buf = append(buf, `NULL`...)
} else {
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
}
for _, dec := range dimElemCounts {
if (i+1)%dec == 0 {
buf = append(buf, '}')
}
}
}
return buf, nil
}
func (src JSONArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
arrayHeader := ArrayHeader{
Dimensions: src.Dimensions,
}
if dt, ok := ci.DataTypeForName("json"); ok {
arrayHeader.ElementOID = int32(dt.OID)
} else {
return nil, fmt.Errorf("unable to find oid for type name %v", "json")
}
for i := range src.Elements {
if src.Elements[i].Status == Null {
arrayHeader.ContainsNull = true
break
}
}
buf = arrayHeader.EncodeBinary(ci, buf)
for i := range src.Elements {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if elemBuf != nil {
buf = elemBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *JSONArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src JSONArray) Value() (driver.Value, error) {
buf, err := src.EncodeText(nil, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}
return string(buf), nil
}
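JSONArray exists so a slice of JSON documents can be bound as a single json[] parameter or scanned back out; it accepts []string, [][]byte, []json.RawMessage, and []JSON before falling back to reflection. A sketch (imports assumed: encoding/json, log, github.com/jackc/pgtype):

var arr pgtype.JSONArray
if err := arr.Set([]json.RawMessage{
	json.RawMessage(`{"id": 1}`),
	json.RawMessage(`{"id": 2}`),
}); err != nil {
	log.Fatal(err)
}

var out []json.RawMessage
_ = arr.AssignTo(&out) // round-trips back into raw JSON documents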

View file

@ -5,6 +5,7 @@ package pgtype
import (
"database/sql/driver"
"encoding/binary"
"encoding/json"
"fmt"
"reflect"
@ -72,6 +73,25 @@ func (dst *JSONBArray) Set(src interface{}) error {
}
}
case []json.RawMessage:
if value == nil {
*dst = JSONBArray{Status: Null}
} else if len(value) == 0 {
*dst = JSONBArray{Status: Present}
} else {
elements := make([]JSONB, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = JSONBArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
case []JSONB:
if value == nil {
*dst = JSONBArray{Status: Null}
@ -214,6 +234,15 @@ func (src *JSONBArray) AssignTo(dst interface{}) error {
}
return nil
case *[]json.RawMessage:
*v = make([]json.RawMessage, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
}
}

83 vendor/github.com/jackc/pgtype/multirange.go generated vendored Normal file
View file

@ -0,0 +1,83 @@
package pgtype
import (
"bytes"
"fmt"
)
type UntypedTextMultirange struct {
Elements []string
}
func ParseUntypedTextMultirange(src string) (*UntypedTextMultirange, error) {
utmr := &UntypedTextMultirange{}
utmr.Elements = make([]string, 0)
buf := bytes.NewBufferString(src)
skipWhitespace(buf)
r, _, err := buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid array: %v", err)
}
if r != '{' {
return nil, fmt.Errorf("invalid multirange, expected '{': %v", err)
}
parseValueLoop:
for {
r, _, err = buf.ReadRune()
if err != nil {
return nil, fmt.Errorf("invalid multirange: %v", err)
}
switch r {
case ',': // skip range separator
case '}':
break parseValueLoop
default:
buf.UnreadRune()
value, err := parseRange(buf)
if err != nil {
return nil, fmt.Errorf("invalid multirange value: %v", err)
}
utmr.Elements = append(utmr.Elements, value)
}
}
skipWhitespace(buf)
if buf.Len() > 0 {
return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
}
return utmr, nil
}
func parseRange(buf *bytes.Buffer) (string, error) {
s := &bytes.Buffer{}
boundSepRead := false
for {
r, _, err := buf.ReadRune()
if err != nil {
return "", err
}
switch r {
case ',', '}':
if r == ',' && !boundSepRead {
boundSepRead = true
break
}
buf.UnreadRune()
return s.String(), nil
}
s.WriteRune(r)
}
}
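The parser above only splits a multirange's text form into its individual range strings; each element is then decoded by the concrete range type. A rough usage sketch (calling the exported helper directly, outside of any decode path):

```go
package main

import (
	"fmt"

	"github.com/jackc/pgtype"
)

func main() {
	// PostgreSQL renders a nummultirange as, for example, {[1,3),[5,7)}.
	utmr, err := pgtype.ParseUntypedTextMultirange("{[1,3),[5,7)}")
	if err != nil {
		panic(err)
	}
	// Prints the two range substrings: [1,3) and [5,7).
	for _, e := range utmr.Elements {
		fmt.Println(e)
	}
}
```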

239 vendor/github.com/jackc/pgtype/num_multirange.go generated vendored Normal file
View file

@ -0,0 +1,239 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"github.com/jackc/pgio"
)
type Nummultirange struct {
Ranges []Numrange
Status Status
}
func (dst *Nummultirange) Set(src interface{}) error {
//untyped nil and typed nil interfaces are different
if src == nil {
*dst = Nummultirange{Status: Null}
return nil
}
switch value := src.(type) {
case Nummultirange:
*dst = value
case *Nummultirange:
*dst = *value
case string:
return dst.DecodeText(nil, []byte(value))
case []Numrange:
if value == nil {
*dst = Nummultirange{Status: Null}
} else if len(value) == 0 {
*dst = Nummultirange{Status: Present}
} else {
elements := make([]Numrange, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Nummultirange{
Ranges: elements,
Status: Present,
}
}
case []*Numrange:
if value == nil {
*dst = Nummultirange{Status: Null}
} else if len(value) == 0 {
*dst = Nummultirange{Status: Present}
} else {
elements := make([]Numrange, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = Nummultirange{
Ranges: elements,
Status: Present,
}
}
default:
return fmt.Errorf("cannot convert %v to Nummultirange", src)
}
return nil
}
func (dst Nummultirange) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *Nummultirange) AssignTo(dst interface{}) error {
return fmt.Errorf("cannot assign %v to %T", src, dst)
}
func (dst *Nummultirange) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Nummultirange{Status: Null}
return nil
}
utmr, err := ParseUntypedTextMultirange(string(src))
if err != nil {
return err
}
var elements []Numrange
if len(utmr.Elements) > 0 {
elements = make([]Numrange, len(utmr.Elements))
for i, s := range utmr.Elements {
var elem Numrange
elemSrc := []byte(s)
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}
elements[i] = elem
}
}
*dst = Nummultirange{Ranges: elements, Status: Present}
return nil
}
func (dst *Nummultirange) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Nummultirange{Status: Null}
return nil
}
rp := 0
numElems := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if numElems == 0 {
*dst = Nummultirange{Status: Present}
return nil
}
elements := make([]Numrange, numElems)
for i := range elements {
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var elemSrc []byte
if elemLen >= 0 {
elemSrc = src[rp : rp+elemLen]
rp += elemLen
}
err := elements[i].DecodeBinary(ci, elemSrc)
if err != nil {
return err
}
}
*dst = Nummultirange{Ranges: elements, Status: Present}
return nil
}
func (src Nummultirange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = append(buf, '{')
inElemBuf := make([]byte, 0, 32)
for i, elem := range src.Ranges {
if i > 0 {
buf = append(buf, ',')
}
elemBuf, err := elem.EncodeText(ci, inElemBuf)
if err != nil {
return nil, err
}
if elemBuf == nil {
return nil, fmt.Errorf("multi-range does not allow null range")
} else {
buf = append(buf, string(elemBuf)...)
}
}
buf = append(buf, '}')
return buf, nil
}
func (src Nummultirange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = pgio.AppendInt32(buf, int32(len(src.Ranges)))
for i := range src.Ranges {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.Ranges[i].EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if elemBuf != nil {
buf = elemBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *Nummultirange) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src Nummultirange) Value() (driver.Value, error) {
return EncodeValueText(src)
}
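As a rough usage sketch (not part of the vendored code), a Nummultirange can be populated from its text form via Set, which routes through DecodeText, and re-encoded through the driver.Valuer interface:

```go
package main

import (
	"fmt"

	"github.com/jackc/pgtype"
)

func main() {
	var mr pgtype.Nummultirange

	// Set with a string source calls DecodeText internally.
	if err := mr.Set("{[1,3),[5,7)}"); err != nil {
		panic(err)
	}
	fmt.Println(len(mr.Ranges)) // should print 2

	// Value re-encodes the multirange as text for database/sql.
	v, err := mr.Value()
	if err != nil {
		panic(err)
	}
	fmt.Println(v) // should print {[1,3),[5,7)}
}
```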

View file

@ -1,6 +1,7 @@
package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
@ -375,6 +376,12 @@ func (src *Numeric) AssignTo(dst interface{}) error {
return err
}
v.Set(rat)
case *string:
buf, err := encodeNumericText(*src, nil)
if err != nil {
return err
}
*v = string(buf)
default:
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
@ -792,3 +799,55 @@ func (src Numeric) Value() (driver.Value, error) {
return nil, errUndefined
}
}
func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) {
// if !n.Valid {
// return nil, nil
// }
if n.NaN {
buf = append(buf, "NaN"...)
return buf, nil
} else if n.InfinityModifier == Infinity {
buf = append(buf, "Infinity"...)
return buf, nil
} else if n.InfinityModifier == NegativeInfinity {
buf = append(buf, "-Infinity"...)
return buf, nil
}
buf = append(buf, n.numberTextBytes()...)
return buf, nil
}
// numberTextBytes returns the text representation of the number. Undefined if NaN, infinite, or NULL.
func (n Numeric) numberTextBytes() []byte {
intStr := n.Int.String()
buf := &bytes.Buffer{}
exp := int(n.Exp)
if exp > 0 {
buf.WriteString(intStr)
for i := 0; i < exp; i++ {
buf.WriteByte('0')
}
} else if exp < 0 {
if len(intStr) <= -exp {
buf.WriteString("0.")
leadingZeros := -exp - len(intStr)
for i := 0; i < leadingZeros; i++ {
buf.WriteByte('0')
}
buf.WriteString(intStr)
} else if len(intStr) > -exp {
dpPos := len(intStr) + exp
buf.WriteString(intStr[:dpPos])
buf.WriteByte('.')
buf.WriteString(intStr[dpPos:])
}
} else {
buf.WriteString(intStr)
}
return buf.Bytes()
}
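numberTextBytes shifts the decimal point of the big-integer mantissa by Exp, so an Int of 12345 with Exp -2 renders as "123.45" and a positive Exp appends zeros. The sketch below exercises the new *string AssignTo branch end to end; it assumes Numeric.Set accepts a numeric string, as in upstream pgtype.

```go
package main

import (
	"fmt"

	"github.com/jackc/pgtype"
)

func main() {
	var n pgtype.Numeric
	if err := n.Set("123.45"); err != nil {
		panic(err)
	}
	// Internally this is Int=12345, Exp=-2; the new *string AssignTo branch
	// re-renders it through encodeNumericText/numberTextBytes.
	var s string
	if err := n.AssignTo(&s); err != nil {
		panic(err)
	}
	fmt.Println(s) // should print 123.45
}
```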

View file

@ -26,6 +26,7 @@ const (
XIDOID = 28
CIDOID = 29
JSONOID = 114
JSONArrayOID = 199
PointOID = 600
LsegOID = 601
PathOID = 602
@ -74,12 +75,15 @@ const (
JSONBArrayOID = 3807
DaterangeOID = 3912
Int4rangeOID = 3904
Int4multirangeOID = 4451
NumrangeOID = 3906
NummultirangeOID = 4532
TsrangeOID = 3908
TsrangeArrayOID = 3909
TstzrangeOID = 3910
TstzrangeArrayOID = 3911
Int8rangeOID = 3926
Int8multirangeOID = 4536
)
type Status byte
@ -288,10 +292,13 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Value: &Int2{}, Name: "int2", OID: Int2OID})
ci.RegisterDataType(DataType{Value: &Int4{}, Name: "int4", OID: Int4OID})
ci.RegisterDataType(DataType{Value: &Int4range{}, Name: "int4range", OID: Int4rangeOID})
ci.RegisterDataType(DataType{Value: &Int4multirange{}, Name: "int4multirange", OID: Int4multirangeOID})
ci.RegisterDataType(DataType{Value: &Int8{}, Name: "int8", OID: Int8OID})
ci.RegisterDataType(DataType{Value: &Int8range{}, Name: "int8range", OID: Int8rangeOID})
ci.RegisterDataType(DataType{Value: &Int8multirange{}, Name: "int8multirange", OID: Int8multirangeOID})
ci.RegisterDataType(DataType{Value: &Interval{}, Name: "interval", OID: IntervalOID})
ci.RegisterDataType(DataType{Value: &JSON{}, Name: "json", OID: JSONOID})
ci.RegisterDataType(DataType{Value: &JSONArray{}, Name: "_json", OID: JSONArrayOID})
ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID})
ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID})
ci.RegisterDataType(DataType{Value: &Line{}, Name: "line", OID: LineOID})
@ -300,6 +307,7 @@ func NewConnInfo() *ConnInfo {
ci.RegisterDataType(DataType{Value: &Name{}, Name: "name", OID: NameOID})
ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID})
ci.RegisterDataType(DataType{Value: &Numrange{}, Name: "numrange", OID: NumrangeOID})
ci.RegisterDataType(DataType{Value: &Nummultirange{}, Name: "nummultirange", OID: NummultirangeOID})
ci.RegisterDataType(DataType{Value: &OIDValue{}, Name: "oid", OID: OIDOID})
ci.RegisterDataType(DataType{Value: &Path{}, Name: "path", OID: PathOID})
ci.RegisterDataType(DataType{Value: &Point{}, Name: "point", OID: PointOID})
@ -527,8 +535,22 @@ type scanPlanDataTypeSQLScanner DataType
func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
scanner, ok := dst.(sql.Scanner)
if !ok {
newPlan := ci.PlanScan(oid, formatCode, dst)
return newPlan.Scan(ci, oid, formatCode, src, dst)
dv := reflect.ValueOf(dst)
if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) {
newPlan := ci.PlanScan(oid, formatCode, dst)
return newPlan.Scan(ci, oid, formatCode, src, dst)
}
if src == nil {
// Ensure the pointer points to a zero version of the value
dv.Elem().Set(reflect.Zero(dv.Type().Elem()))
return nil
}
dv = dv.Elem()
// If the pointer is to a nil pointer then set that before scanning
if dv.Kind() == reflect.Ptr && dv.IsNil() {
dv.Set(reflect.New(dv.Type().Elem()))
}
scanner = dv.Interface().(sql.Scanner)
}
dt := (*DataType)(plan)
@ -587,7 +609,25 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode
type scanPlanSQLScanner struct{}
func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error {
scanner := dst.(sql.Scanner)
scanner, ok := dst.(sql.Scanner)
if !ok {
dv := reflect.ValueOf(dst)
if dv.Kind() != reflect.Ptr || !dv.Type().Elem().Implements(scannerType) {
newPlan := ci.PlanScan(oid, formatCode, dst)
return newPlan.Scan(ci, oid, formatCode, src, dst)
}
if src == nil {
// Ensure the pointer points to a zero version of the value
dv.Elem().Set(reflect.Zero(dv.Type()))
return nil
}
dv = dv.Elem()
// If the pointer is to a nil pointer then set that before scanning
if dv.Kind() == reflect.Ptr && dv.IsNil() {
dv.Set(reflect.New(dv.Type().Elem()))
}
scanner = dv.Interface().(sql.Scanner)
}
if src == nil {
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
// text format path would be converted to empty string.
@ -755,6 +795,18 @@ func (scanPlanString) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byt
return newPlan.Scan(ci, oid, formatCode, src, dst)
}
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
func isScanner(dst interface{}) bool {
if _, ok := dst.(sql.Scanner); ok {
return true
}
if t := reflect.TypeOf(dst); t != nil && t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) {
return true
}
return false
}
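isScanner now also reports true when dst is a pointer to a type that itself implements sql.Scanner, which is what lets the scan plans above dereference through one extra level of indirection. The helper is unexported, so the sketch below re-implements the same check purely for illustration (sql.NullString implements Scanner on its pointer receiver):

```go
package main

import (
	"database/sql"
	"fmt"
	"reflect"
)

var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()

// isScanner mirrors pgtype's unexported helper for illustration only.
func isScanner(dst interface{}) bool {
	if _, ok := dst.(sql.Scanner); ok {
		return true
	}
	if t := reflect.TypeOf(dst); t != nil && t.Kind() == reflect.Ptr && t.Elem().Implements(scannerType) {
		return true
	}
	return false
}

func main() {
	var ns sql.NullString
	var pns *sql.NullString

	fmt.Println(isScanner(&ns))  // true: *sql.NullString is itself a Scanner
	fmt.Println(isScanner(&pns)) // true: **sql.NullString points to a Scanner type
	fmt.Println(isScanner(ns))   // false: the bare value does not implement Scanner
}
```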
// PlanScan prepares a plan to scan a value into dst.
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) ScanPlan {
switch formatCode {
@ -819,13 +871,13 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan
}
if dt != nil {
if _, ok := dst.(sql.Scanner); ok {
if isScanner(dst) {
return (*scanPlanDataTypeSQLScanner)(dt)
}
return (*scanPlanDataTypeAssignTo)(dt)
}
if _, ok := dst.(sql.Scanner); ok {
if isScanner(dst) {
return scanPlanSQLScanner{}
}
@ -873,72 +925,76 @@ var nameValues map[string]Value
func init() {
nameValues = map[string]Value{
"_aclitem": &ACLItemArray{},
"_bool": &BoolArray{},
"_bpchar": &BPCharArray{},
"_bytea": &ByteaArray{},
"_cidr": &CIDRArray{},
"_date": &DateArray{},
"_float4": &Float4Array{},
"_float8": &Float8Array{},
"_inet": &InetArray{},
"_int2": &Int2Array{},
"_int4": &Int4Array{},
"_int8": &Int8Array{},
"_numeric": &NumericArray{},
"_text": &TextArray{},
"_timestamp": &TimestampArray{},
"_timestamptz": &TimestamptzArray{},
"_uuid": &UUIDArray{},
"_varchar": &VarcharArray{},
"_jsonb": &JSONBArray{},
"aclitem": &ACLItem{},
"bit": &Bit{},
"bool": &Bool{},
"box": &Box{},
"bpchar": &BPChar{},
"bytea": &Bytea{},
"char": &QChar{},
"cid": &CID{},
"cidr": &CIDR{},
"circle": &Circle{},
"date": &Date{},
"daterange": &Daterange{},
"float4": &Float4{},
"float8": &Float8{},
"hstore": &Hstore{},
"inet": &Inet{},
"int2": &Int2{},
"int4": &Int4{},
"int4range": &Int4range{},
"int8": &Int8{},
"int8range": &Int8range{},
"interval": &Interval{},
"json": &JSON{},
"jsonb": &JSONB{},
"line": &Line{},
"lseg": &Lseg{},
"macaddr": &Macaddr{},
"name": &Name{},
"numeric": &Numeric{},
"numrange": &Numrange{},
"oid": &OIDValue{},
"path": &Path{},
"point": &Point{},
"polygon": &Polygon{},
"record": &Record{},
"text": &Text{},
"tid": &TID{},
"timestamp": &Timestamp{},
"timestamptz": &Timestamptz{},
"tsrange": &Tsrange{},
"_tsrange": &TsrangeArray{},
"tstzrange": &Tstzrange{},
"_tstzrange": &TstzrangeArray{},
"unknown": &Unknown{},
"uuid": &UUID{},
"varbit": &Varbit{},
"varchar": &Varchar{},
"xid": &XID{},
"_aclitem": &ACLItemArray{},
"_bool": &BoolArray{},
"_bpchar": &BPCharArray{},
"_bytea": &ByteaArray{},
"_cidr": &CIDRArray{},
"_date": &DateArray{},
"_float4": &Float4Array{},
"_float8": &Float8Array{},
"_inet": &InetArray{},
"_int2": &Int2Array{},
"_int4": &Int4Array{},
"_int8": &Int8Array{},
"_numeric": &NumericArray{},
"_text": &TextArray{},
"_timestamp": &TimestampArray{},
"_timestamptz": &TimestamptzArray{},
"_uuid": &UUIDArray{},
"_varchar": &VarcharArray{},
"_json": &JSONArray{},
"_jsonb": &JSONBArray{},
"aclitem": &ACLItem{},
"bit": &Bit{},
"bool": &Bool{},
"box": &Box{},
"bpchar": &BPChar{},
"bytea": &Bytea{},
"char": &QChar{},
"cid": &CID{},
"cidr": &CIDR{},
"circle": &Circle{},
"date": &Date{},
"daterange": &Daterange{},
"float4": &Float4{},
"float8": &Float8{},
"hstore": &Hstore{},
"inet": &Inet{},
"int2": &Int2{},
"int4": &Int4{},
"int4range": &Int4range{},
"int4multirange": &Int4multirange{},
"int8": &Int8{},
"int8range": &Int8range{},
"int8multirange": &Int8multirange{},
"interval": &Interval{},
"json": &JSON{},
"jsonb": &JSONB{},
"line": &Line{},
"lseg": &Lseg{},
"macaddr": &Macaddr{},
"name": &Name{},
"numeric": &Numeric{},
"numrange": &Numrange{},
"nummultirange": &Nummultirange{},
"oid": &OIDValue{},
"path": &Path{},
"point": &Point{},
"polygon": &Polygon{},
"record": &Record{},
"text": &Text{},
"tid": &TID{},
"timestamp": &Timestamp{},
"timestamptz": &Timestamptz{},
"tsrange": &Tsrange{},
"_tsrange": &TsrangeArray{},
"tstzrange": &Tstzrange{},
"_tstzrange": &TstzrangeArray{},
"unknown": &Unknown{},
"uuid": &UUID{},
"varbit": &Varbit{},
"varchar": &Varchar{},
"xid": &XID{},
}
}

View file

@ -6,7 +6,7 @@ import (
)
// Record is the generic PostgreSQL record type such as is created with the
// "row" function. Record only implements BinaryEncoder and Value. The text
// "row" function. Record only implements BinaryDecoder and Value. The text
// format output format from PostgreSQL does not include type information and is
// therefore impossible to decode. No encoders are implemented because
// PostgreSQL does not support input of generic records.

318 vendor/github.com/jackc/pgtype/record_array.go generated vendored Normal file
View file

@ -0,0 +1,318 @@
// Code generated by erb. DO NOT EDIT.
package pgtype
import (
"encoding/binary"
"fmt"
"reflect"
)
type RecordArray struct {
Elements []Record
Dimensions []ArrayDimension
Status Status
}
func (dst *RecordArray) Set(src interface{}) error {
// untyped nil and typed nil interfaces are different
if src == nil {
*dst = RecordArray{Status: Null}
return nil
}
if value, ok := src.(interface{ Get() interface{} }); ok {
value2 := value.Get()
if value2 != value {
return dst.Set(value2)
}
}
// Attempt to match to select common types:
switch value := src.(type) {
case [][]Value:
if value == nil {
*dst = RecordArray{Status: Null}
} else if len(value) == 0 {
*dst = RecordArray{Status: Present}
} else {
elements := make([]Record, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = RecordArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}
case []Record:
if value == nil {
*dst = RecordArray{Status: Null}
} else if len(value) == 0 {
*dst = RecordArray{Status: Present}
} else {
*dst = RecordArray{
Elements: value,
Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}},
Status: Present,
}
}
default:
// Fallback to reflection if an optimised match was not found.
// The reflection is necessary for arrays and multidimensional slices,
// but it comes with a 20-50% performance penalty for large arrays/slices
reflectedValue := reflect.ValueOf(src)
if !reflectedValue.IsValid() || reflectedValue.IsZero() {
*dst = RecordArray{Status: Null}
return nil
}
dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0)
if !ok {
return fmt.Errorf("cannot find dimensions of %v for RecordArray", src)
}
if elementsLength == 0 {
*dst = RecordArray{Status: Present}
return nil
}
if len(dimensions) == 0 {
if originalSrc, ok := underlyingSliceType(src); ok {
return dst.Set(originalSrc)
}
return fmt.Errorf("cannot convert %v to RecordArray", src)
}
*dst = RecordArray{
Elements: make([]Record, elementsLength),
Dimensions: dimensions,
Status: Present,
}
elementCount, err := dst.setRecursive(reflectedValue, 0, 0)
if err != nil {
// Maybe the target was one dimension too far, try again:
if len(dst.Dimensions) > 1 {
dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1]
elementsLength = 0
for _, dim := range dst.Dimensions {
if elementsLength == 0 {
elementsLength = int(dim.Length)
} else {
elementsLength *= int(dim.Length)
}
}
dst.Elements = make([]Record, elementsLength)
elementCount, err = dst.setRecursive(reflectedValue, 0, 0)
if err != nil {
return err
}
} else {
return err
}
}
if elementCount != len(dst.Elements) {
return fmt.Errorf("cannot convert %v to RecordArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount)
}
}
return nil
}
func (dst *RecordArray) setRecursive(value reflect.Value, index, dimension int) (int, error) {
switch value.Kind() {
case reflect.Array:
fallthrough
case reflect.Slice:
if len(dst.Dimensions) == dimension {
break
}
valueLen := value.Len()
if int32(valueLen) != dst.Dimensions[dimension].Length {
return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions")
}
for i := 0; i < valueLen; i++ {
var err error
index, err = dst.setRecursive(value.Index(i), index, dimension+1)
if err != nil {
return 0, err
}
}
return index, nil
}
if !value.CanInterface() {
return 0, fmt.Errorf("cannot convert all values to RecordArray")
}
if err := dst.Elements[index].Set(value.Interface()); err != nil {
return 0, fmt.Errorf("%v in RecordArray", err)
}
index++
return index, nil
}
func (dst RecordArray) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *RecordArray) AssignTo(dst interface{}) error {
switch src.Status {
case Present:
if len(src.Dimensions) <= 1 {
// Attempt to match to select common types:
switch v := dst.(type) {
case *[][]Value:
*v = make([][]Value, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil
}
}
// Try to convert to something AssignTo can use directly.
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
// Fallback to reflection if an optimised match was not found.
// The reflection is necessary for arrays and multidimensional slices,
// but it comes with a 20-50% performance penalty for large arrays/slices
value := reflect.ValueOf(dst)
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
switch value.Kind() {
case reflect.Array, reflect.Slice:
default:
return fmt.Errorf("cannot assign %T to %T", src, dst)
}
if len(src.Elements) == 0 {
if value.Kind() == reflect.Slice {
value.Set(reflect.MakeSlice(value.Type(), 0, 0))
return nil
}
}
elementCount, err := src.assignToRecursive(value, 0, 0)
if err != nil {
return err
}
if elementCount != len(src.Elements) {
return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount)
}
return nil
case Null:
return NullAssignTo(dst)
}
return fmt.Errorf("cannot decode %#v into %T", src, dst)
}
func (src *RecordArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) {
switch kind := value.Kind(); kind {
case reflect.Array:
fallthrough
case reflect.Slice:
if len(src.Dimensions) == dimension {
break
}
length := int(src.Dimensions[dimension].Length)
if reflect.Array == kind {
typ := value.Type()
if typ.Len() != length {
return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len())
}
value.Set(reflect.New(typ).Elem())
} else {
value.Set(reflect.MakeSlice(value.Type(), length, length))
}
var err error
for i := 0; i < length; i++ {
index, err = src.assignToRecursive(value.Index(i), index, dimension+1)
if err != nil {
return 0, err
}
}
return index, nil
}
if len(src.Dimensions) != dimension {
return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension)
}
if !value.CanAddr() {
return 0, fmt.Errorf("cannot assign all values from RecordArray")
}
addr := value.Addr()
if !addr.CanInterface() {
return 0, fmt.Errorf("cannot assign all values from RecordArray")
}
if err := src.Elements[index].AssignTo(addr.Interface()); err != nil {
return 0, err
}
index++
return index, nil
}
func (dst *RecordArray) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = RecordArray{Status: Null}
return nil
}
var arrayHeader ArrayHeader
rp, err := arrayHeader.DecodeBinary(ci, src)
if err != nil {
return err
}
if len(arrayHeader.Dimensions) == 0 {
*dst = RecordArray{Dimensions: arrayHeader.Dimensions, Status: Present}
return nil
}
elementCount := arrayHeader.Dimensions[0].Length
for _, d := range arrayHeader.Dimensions[1:] {
elementCount *= d.Length
}
elements := make([]Record, elementCount)
for i := range elements {
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var elemSrc []byte
if elemLen >= 0 {
elemSrc = src[rp : rp+elemLen]
rp += elemLen
}
err = elements[i].DecodeBinary(ci, elemSrc)
if err != nil {
return err
}
}
*dst = RecordArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
return nil
}

View file

@ -46,6 +46,14 @@ func (dst *Timestamp) Set(src interface{}) error {
} else {
return dst.Set(*value)
}
case string:
return dst.DecodeText(nil, []byte(value))
case *string:
if value == nil {
*dst = Timestamp{Status: Null}
} else {
return dst.Set(*value)
}
case InfinityModifier:
*dst = Timestamp{InfinityModifier: value, Status: Present}
default:

View file

@ -48,6 +48,14 @@ func (dst *Timestamptz) Set(src interface{}) error {
} else {
return dst.Set(*value)
}
case string:
return dst.DecodeText(nil, []byte(value))
case *string:
if value == nil {
*dst = Timestamptz{Status: Null}
} else {
return dst.Set(*value)
}
case InfinityModifier:
*dst = Timestamptz{InfinityModifier: value, Status: Present}
default:

View file

@ -1,5 +1,17 @@
// Code generated by erb. DO NOT EDIT.
<%
# defaults when not explicitly set on command line
binary_format ||= "true"
text_format ||= "true"
text_null ||= "NULL"
encode_binary ||= binary_format
decode_binary ||= binary_format
%>
package pgtype
import (
@ -279,6 +291,7 @@ func (src *<%= pgtype_array_type %>) assignToRecursive(value reflect.Value, inde
return index, nil
}
<% if text_format == "true" %>
func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = <%= pgtype_array_type %>{Status: Null}
@ -314,8 +327,9 @@ func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error
return nil
}
<% end %>
<% if binary_format == "true" %>
<% if decode_binary == "true" %>
func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = <%= pgtype_array_type %>{Status: Null}
@ -359,6 +373,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) erro
}
<% end %>
<% if text_format == "true" %>
func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
@ -415,8 +430,9 @@ func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte
return buf, nil
}
<% end %>
<% if binary_format == "true" %>
<% if encode_binary == "true" %>
func (src <%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
@ -462,6 +478,7 @@ func (src <%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte
}
<% end %>
<% if text_format == "true" %>
// Scan implements the database/sql Scanner interface.
func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error {
if src == nil {
@ -492,3 +509,4 @@ func (src <%= pgtype_array_type %>) Value() (driver.Value, error) {
return string(buf), nil
}
<% end %>

View file

@ -1,28 +1,31 @@
erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go
erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go
erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go
erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go
erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go
erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go
erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange text_null=NULL binary_format=true typed_array.go.erb > tstzrange_array.go
erb pgtype_array_type=TsrangeArray pgtype_element_type=Tsrange go_array_types=[]Tsrange element_type_name=tsrange text_null=NULL binary_format=true typed_array.go.erb > tsrange_array.go
erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go
erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go
erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go
erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go
erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr text_null=NULL binary_format=true typed_array.go.erb > macaddr_array.go
erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go
erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text text_null=NULL binary_format=true typed_array.go.erb > text_array.go
erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar text_null=NULL binary_format=true typed_array.go.erb > varchar_array.go
erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar text_null=NULL binary_format=true typed_array.go.erb > bpchar_array.go
erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go
erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go
erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go
erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go
erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go
erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte element_type_name=jsonb text_null=NULL binary_format=true typed_array.go.erb > jsonb_array.go
erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int2 typed_array.go.erb > int2_array.go
erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int4 typed_array.go.erb > int4_array.go
erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int16,[]*int16,[]uint16,[]*uint16,[]int32,[]*int32,[]uint32,[]*uint32,[]int64,[]*int64,[]uint64,[]*uint64,[]int,[]*int,[]uint,[]*uint element_type_name=int8 typed_array.go.erb > int8_array.go
erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool,[]*bool element_type_name=bool typed_array.go.erb > bool_array.go
erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time,[]*time.Time element_type_name=date typed_array.go.erb > date_array.go
erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time,[]*time.Time element_type_name=timestamptz typed_array.go.erb > timestamptz_array.go
erb pgtype_array_type=TstzrangeArray pgtype_element_type=Tstzrange go_array_types=[]Tstzrange element_type_name=tstzrange typed_array.go.erb > tstzrange_array.go
erb pgtype_array_type=TsrangeArray pgtype_element_type=Tsrange go_array_types=[]Tsrange element_type_name=tsrange typed_array.go.erb > tsrange_array.go
erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time,[]*time.Time element_type_name=timestamp typed_array.go.erb > timestamp_array.go
erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32,[]*float32 element_type_name=float4 typed_array.go.erb > float4_array.go
erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64,[]*float64 element_type_name=float8 typed_array.go.erb > float8_array.go
erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=inet typed_array.go.erb > inet_array.go
erb pgtype_array_type=MacaddrArray pgtype_element_type=Macaddr go_array_types=[]net.HardwareAddr,[]*net.HardwareAddr element_type_name=macaddr typed_array.go.erb > macaddr_array.go
erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP,[]*net.IP element_type_name=cidr typed_array.go.erb > cidr_array.go
erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string,[]*string element_type_name=text typed_array.go.erb > text_array.go
erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string,[]*string element_type_name=varchar typed_array.go.erb > varchar_array.go
erb pgtype_array_type=BPCharArray pgtype_element_type=BPChar go_array_types=[]string,[]*string element_type_name=bpchar typed_array.go.erb > bpchar_array.go
erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea typed_array.go.erb > bytea_array.go
erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string,[]*string element_type_name=aclitem binary_format=false typed_array.go.erb > aclitem_array.go
erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore typed_array.go.erb > hstore_array.go
erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]*float32,[]float64,[]*float64,[]int64,[]*int64,[]uint64,[]*uint64 element_type_name=numeric typed_array.go.erb > numeric_array.go
erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string,[]*string element_type_name=uuid typed_array.go.erb > uuid_array.go
erb pgtype_array_type=JSONArray pgtype_element_type=JSON go_array_types=[]string,[][]byte,[]json.RawMessage element_type_name=json typed_array.go.erb > json_array.go
erb pgtype_array_type=JSONBArray pgtype_element_type=JSONB go_array_types=[]string,[][]byte,[]json.RawMessage element_type_name=jsonb typed_array.go.erb > jsonb_array.go
# While the binary format is theoretically possible, it is only practical to use the text format.
erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string text_null=NULL binary_format=false typed_array.go.erb > enum_array.go
erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string,[]*string binary_format=false typed_array.go.erb > enum_array.go
erb pgtype_array_type=RecordArray pgtype_element_type=Record go_array_types=[][]Value element_type_name=record text_null=NULL encode_binary=false text_format=false typed_array.go.erb > record_array.go
goimports -w *_array.go

239 vendor/github.com/jackc/pgtype/typed_multirange.go.erb generated vendored Normal file
View file

@ -0,0 +1,239 @@
package pgtype
import (
"database/sql/driver"
"encoding/binary"
"fmt"
"github.com/jackc/pgio"
)
type <%= multirange_type %> struct {
Ranges []<%= range_type %>
Status Status
}
func (dst *<%= multirange_type %>) Set(src interface{}) error {
//untyped nil and typed nil interfaces are different
if src == nil {
*dst = <%= multirange_type %>{Status: Null}
return nil
}
switch value := src.(type) {
case <%= multirange_type %>:
*dst = value
case *<%= multirange_type %>:
*dst = *value
case string:
return dst.DecodeText(nil, []byte(value))
case []<%= range_type %>:
if value == nil {
*dst = <%= multirange_type %>{Status: Null}
} else if len(value) == 0 {
*dst = <%= multirange_type %>{Status: Present}
} else {
elements := make([]<%= range_type %>, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = <%= multirange_type %>{
Ranges: elements,
Status: Present,
}
}
case []*<%= range_type %>:
if value == nil {
*dst = <%= multirange_type %>{Status: Null}
} else if len(value) == 0 {
*dst = <%= multirange_type %>{Status: Present}
} else {
elements := make([]<%= range_type %>, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = <%= multirange_type %>{
Ranges: elements,
Status: Present,
}
}
default:
return fmt.Errorf("cannot convert %v to <%= multirange_type %>", src)
}
return nil
}
func (dst <%= multirange_type %>) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}
func (src *<%= multirange_type %>) AssignTo(dst interface{}) error {
return fmt.Errorf("cannot assign %v to %T", src, dst)
}
func (dst *<%= multirange_type %>) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = <%= multirange_type %>{Status: Null}
return nil
}
utmr, err := ParseUntypedTextMultirange(string(src))
if err != nil {
return err
}
var elements []<%= range_type %>
if len(utmr.Elements) > 0 {
elements = make([]<%= range_type %>, len(utmr.Elements))
for i, s := range utmr.Elements {
var elem <%= range_type %>
elemSrc := []byte(s)
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}
elements[i] = elem
}
}
*dst = <%= multirange_type %>{Ranges: elements, Status: Present}
return nil
}
func (dst *<%= multirange_type %>) DecodeBinary(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = <%= multirange_type %>{Status: Null}
return nil
}
rp := 0
numElems := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if numElems == 0 {
*dst = <%= multirange_type %>{Status: Present}
return nil
}
elements := make([]<%= range_type %>, numElems)
for i := range elements {
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
var elemSrc []byte
if elemLen >= 0 {
elemSrc = src[rp : rp+elemLen]
rp += elemLen
}
err := elements[i].DecodeBinary(ci, elemSrc)
if err != nil {
return err
}
}
*dst = <%= multirange_type %>{Ranges: elements, Status: Present}
return nil
}
func (src <%= multirange_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = append(buf, '{')
inElemBuf := make([]byte, 0, 32)
for i, elem := range src.Ranges {
if i > 0 {
buf = append(buf, ',')
}
elemBuf, err := elem.EncodeText(ci, inElemBuf)
if err != nil {
return nil, err
}
if elemBuf == nil {
return nil, fmt.Errorf("multi-range does not allow null range")
} else {
buf = append(buf, string(elemBuf)...)
}
}
buf = append(buf, '}')
return buf, nil
}
func (src <%= multirange_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}
buf = pgio.AppendInt32(buf, int32(len(src.Ranges)))
for i := range src.Ranges {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
elemBuf, err := src.Ranges[i].EncodeBinary(ci, buf)
if err != nil {
return nil, err
}
if elemBuf != nil {
buf = elemBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
}
return buf, nil
}
// Scan implements the database/sql Scanner interface.
func (dst *<%= multirange_type %>) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}
switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}
return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
func (src <%= multirange_type %>) Value() (driver.Value, error) {
return EncodeValueText(src)
}

View file

@ -0,0 +1,8 @@
erb range_type=Numrange multirange_type=Nummultirange typed_multirange.go.erb > num_multirange.go
erb range_type=Int4range multirange_type=Int4multirange typed_multirange.go.erb > int4_multirange.go
erb range_type=Int8range multirange_type=Int8multirange typed_multirange.go.erb > int8_multirange.go
# TODO
# erb range_type=Tsrange multirange_type=Tsmultirange typed_multirange.go.erb > ts_multirange.go
# erb range_type=Tstzrange multirange_type=Tstzmultirange typed_multirange.go.erb > tstz_multirange.go
# erb range_type=Daterange multirange_type=Datemultirange typed_multirange.go.erb > date_multirange.go
goimports -w *multirange.go

View file

@ -18,14 +18,15 @@ func (dst *UUID) Set(src interface{}) error {
return nil
}
if value, ok := src.(interface{ Get() interface{} }); ok {
switch value := src.(type) {
case interface{ Get() interface{} }:
value2 := value.Get()
if value2 != value {
return dst.Set(value2)
}
}
switch value := src.(type) {
case fmt.Stringer:
value2 := value.String()
return dst.Set(value2)
case [16]byte:
*dst = UUID{Bytes: value, Status: Present}
case []byte:
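The new fmt.Stringer case means any value whose String() method returns canonical UUID text can be passed to Set directly. A rough sketch with a hypothetical caller-defined type (myUUID is not part of pgx or pgtype):

```go
package main

import (
	"fmt"

	"github.com/jackc/pgtype"
)

// myUUID is a hypothetical caller-defined type; only its String() output matters here.
type myUUID struct{ text string }

func (u myUUID) String() string { return u.text }

func main() {
	var dst pgtype.UUID
	// Routed through the fmt.Stringer case, then parsed as a UUID string.
	if err := dst.Set(myUUID{text: "6ba7b810-9dad-11d1-80b4-00c04fd430c8"}); err != nil {
		panic(err)
	}
	fmt.Printf("%x\n", dst.Bytes)
}
```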

View file

@ -1,3 +1,38 @@
# 4.17.2 (September 3, 2022)
* Fix panic when logging batch error (Tom Möller)
# 4.17.1 (August 27, 2022)
* Upgrade puddle to v1.3.0 - fixes context failing to cancel Acquire while the resource is being created, a regression introduced in v4.17.0 (James Hartig)
* Fix atomic alignment on 32-bit platforms
# 4.17.0 (August 6, 2022)
* Upgrade pgconn to v1.13.0
* Upgrade pgproto3 to v2.3.1
* Upgrade pgtype to v1.12.0
* Allow background pool connections to continue even if cause is canceled (James Hartig)
* Add LoggerFunc (Gabor Szabad)
* pgxpool: health check should avoid going below minConns (James Hartig)
* Add pgxpool.Conn.Hijack()
* Logging improvements (Stepan Rabotkin)
# 4.16.1 (May 7, 2022)
* Upgrade pgconn to v1.12.1
* Fix explicitly prepared statements with describe statement cache mode
# 4.16.0 (April 21, 2022)
* Upgrade pgconn to v1.12.0
* Upgrade pgproto3 to v2.3.0
* Upgrade pgtype to v1.11.0
* Fix: Do not panic when context cancelled while getting statement from cache.
* Fix: Less memory pinning from old Rows.
* Fix: Support '\r' line ending when sanitizing SQL comment.
* Add pluggable GSSAPI support (Oliver Tan)
# 4.15.0 (February 7, 2022)
* Upgrade to pgconn v1.11.0

View file

@ -1,6 +1,11 @@
[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v4)
[![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx)
---
This is the stable `v4` release. `v5` is now in beta testing with final release expected in September. See https://github.com/jackc/pgx/issues/1273 for more information. Please consider testing `v5`.
---
# pgx - PostgreSQL Driver and Toolkit
pgx is a pure Go driver and toolkit for PostgreSQL.
@ -98,26 +103,6 @@ There are three areas in particular where pgx can provide a significant performa
perform nearly 3x the number of queries per second.
3. Batched queries - Multiple queries can be batched together to minimize network round trips.
## Comparison with Alternatives
* [pq](http://godoc.org/github.com/lib/pq)
* [go-pg](https://github.com/go-pg/pg)
For prepared queries with small sets of simple data types, all drivers will have similar performance. However, if prepared statements aren't being explicitly used, pgx can have a significant performance advantage due to automatic statement preparation.
pgx also can perform better when using PostgreSQL-specific data types or query batching. See
[go_db_bench](https://github.com/jackc/go_db_bench) for some database driver benchmarks.
### Compatibility with `database/sql`
pq is exclusively used with `database/sql`. go-pg does not use `database/sql` at all. pgx supports `database/sql` as well as
its own interface.
### Level of access, ORM
go-pg is a PostgreSQL client and ORM. It includes many features that traditionally sit above the database driver, such as ORM, struct mapping, soft deletes, schema migrations, and sharding support.
pgx is "closer to the metal" and such abstractions are beyond the scope of the pgx project, which first and foremost, aims to be a performant driver and toolkit.
## Testing
pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE` environment
@ -201,3 +186,11 @@ pgerrcode contains constants for the PostgreSQL error codes.
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
Library for scanning data from a database into Go structs and more.
### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
Adds GSSAPI / Kerberos authentication support.
### [https://github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
Adds support for [`github.com/google/uuid`](https://github.com/google/uuid).

View file

@ -73,9 +73,8 @@ type Conn struct {
connInfo *pgtype.ConnInfo
wbuf []byte
preallocatedRows []connRows
eqb extendedQueryBuilder
wbuf []byte
eqb extendedQueryBuilder
}
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
@ -117,14 +116,14 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig
// does. In addition, it accepts the following options:
//
// statement_cache_capacity
// The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512.
// statement_cache_capacity
// The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512.
//
// statement_cache_mode
// Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server.
// "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the
// server. "describe" is primarily useful when the environment does not allow prepared statements such as when
// running a connection pooler like PgBouncer. Default: "prepare"
// statement_cache_mode
// Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server.
// "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the
// server. "describe" is primarily useful when the environment does not allow prepared statements such as when
// running a connection pooler like PgBouncer. Default: "prepare"
//
// prefer_simple_protocol
// Possible values: "true" and "false". Use the simple protocol instead of extended protocol. Default: false
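These options are plain key/value settings, so they can be supplied either in keyword/value form or as URL query parameters. A rough example of both; the host, database, and user values are placeholders:

```go
package main

import (
	"fmt"

	"github.com/jackc/pgx/v4"
)

func main() {
	// Keyword/value form with a smaller statement cache in describe mode,
	// e.g. for use behind a connection pooler like PgBouncer.
	cfg, err := pgx.ParseConfig(
		"host=localhost dbname=app statement_cache_capacity=128 statement_cache_mode=describe",
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(cfg.Host, cfg.Database)

	// The same style of options as URL query parameters.
	if _, err = pgx.ParseConfig(
		"postgres://app@localhost/app?statement_cache_mode=describe&prefer_simple_protocol=false",
	); err != nil {
		panic(err)
	}
}
```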
@ -366,30 +365,6 @@ func (c *Conn) Ping(ctx context.Context) error {
return err
}
func connInfoFromRows(rows Rows, err error) (map[string]uint32, error) {
if err != nil {
return nil, err
}
defer rows.Close()
nameOIDs := make(map[string]uint32, 256)
for rows.Next() {
var oid uint32
var name pgtype.Text
if err = rows.Scan(&oid, &name); err != nil {
return nil, err
}
nameOIDs[name.String] = oid
}
if err = rows.Err(); err != nil {
return nil, err
}
return nameOIDs, err
}
// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
// PostgreSQL connection than pgx exposes.
//
@ -414,7 +389,8 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
commandTag, err := c.exec(ctx, sql, arguments...)
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
endTime := time.Now()
c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err, "time": endTime.Sub(startTime)})
}
return commandTag, err
}
@ -537,12 +513,7 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
}
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
if len(c.preallocatedRows) == 0 {
c.preallocatedRows = make([]connRows, 64)
}
r := &c.preallocatedRows[len(c.preallocatedRows)-1]
c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
r := &connRows{}
r.ctx = ctx
r.logger = c
@ -674,7 +645,7 @@ optionLoop:
resultFormats = c.eqb.resultFormats
}
if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe {
if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe && !ok {
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats)
} else {
rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
@ -739,6 +710,8 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, sc
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
// is used again.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
startTime := time.Now()
simpleProtocol := c.config.PreferSimpleProtocol
var sb strings.Builder
if simpleProtocol {
@ -797,24 +770,23 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
var err error
sd, err = stmtCache.Get(ctx, bi.query)
if err != nil {
// the stmtCache was prefilled from distinctUnpreparedQueries above so we are guaranteed no errors
panic("BUG: unexpected error from stmtCache")
return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err})
}
}
if len(sd.ParamOIDs) != len(bi.arguments) {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")})
}
args, err := convertDriverValuers(bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err})
}
for i := range args {
err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i])
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err})
}
}
@ -833,13 +805,30 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
mrr := c.pgConn.ExecBatch(ctx, batch)
return &batchResults{
return c.logBatchResults(ctx, startTime, &batchResults{
ctx: ctx,
conn: c,
mrr: mrr,
b: b,
ix: 0,
})
}
func (c *Conn) logBatchResults(ctx context.Context, startTime time.Time, results *batchResults) BatchResults {
if results.err != nil {
if c.shouldLog(LogLevelError) {
endTime := time.Now()
c.log(ctx, LogLevelError, "SendBatch", map[string]interface{}{"err": results.err, "time": endTime.Sub(startTime)})
}
return results
}
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(ctx, LogLevelInfo, "SendBatch", map[string]interface{}{"batchLen": results.b.Len(), "time": endTime.Sub(startTime)})
}
return results
}
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) {

View file

@ -153,13 +153,13 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
<-doneChan
rowsAffected := commandTag.RowsAffected()
endTime := time.Now()
if err == nil {
if ct.conn.shouldLog(LogLevelInfo) {
endTime := time.Now()
ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected})
}
} else if ct.conn.shouldLog(LogLevelError) {
ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames})
ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)})
}
return rowsAffected, err

View file

@ -246,7 +246,7 @@ func oneLineCommentState(l *sqlLexer) stateFn {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\n':
case '\n', '\r':
return rawState
case utf8.RuneError:
if l.pos-l.start > 0 {

View file

@ -56,10 +56,10 @@ func (o *LargeObjects) Unlink(ctx context.Context, oid uint32) error {
// A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized
// in. It uses the context it was initialized with for all operations. It implements these interfaces:
//
// io.Writer
// io.Reader
// io.Seeker
// io.Closer
// io.Writer
// io.Reader
// io.Seeker
// io.Closer
type LargeObject struct {
ctx context.Context
tx Tx
@ -108,13 +108,13 @@ func (o *LargeObject) Tell() (n int64, err error) {
return n, err
}
// Trunctes the large object to size.
// Truncate the large object to size.
func (o *LargeObject) Truncate(size int64) (err error) {
_, err = o.tx.Exec(o.ctx, "select lo_truncate64($1, $2)", o.fd, size)
return err
}
// Close closees the large object descriptor.
// Close the large object descriptor.
func (o *LargeObject) Close() error {
_, err := o.tx.Exec(o.ctx, "select lo_close($1)", o.fd)
return err

View file

@ -47,9 +47,18 @@ type Logger interface {
Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
}
// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface
type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
// Log delegates the logging request to the wrapped function
func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) {
f(ctx, level, msg, data)
}
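LoggerFunc makes it possible to plug a bare function in as the connection logger without defining a new type. A rough sketch wiring it into a ConnConfig; the DSN is a placeholder and the log destination is just the standard library logger:

```go
package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v4"
)

func main() {
	cfg, err := pgx.ParseConfig("host=localhost dbname=app")
	if err != nil {
		log.Fatal(err)
	}

	// Any func with this signature satisfies pgx.Logger via LoggerFunc.
	cfg.Logger = pgx.LoggerFunc(func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
		log.Printf("pgx %s: %s %v", level, msg, data)
	})
	cfg.LogLevel = pgx.LogLevelInfo

	conn, err := pgx.ConnectConfig(context.Background(), cfg)
	if err != nil {
		log.Fatal(err) // expected to fail without a reachable server
	}
	defer conn.Close(context.Background())
}
```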
// LogLevelFromString converts log level string to constant
//
// Valid levels:
//
// trace
// debug
// info

View file

@ -143,14 +143,15 @@ func (rows *connRows) Close() {
}
if rows.logger != nil {
endTime := time.Now()
if rows.err == nil {
if rows.logger.shouldLog(LogLevelInfo) {
endTime := time.Now()
rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
}
} else {
if rows.logger.shouldLog(LogLevelError) {
rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)})
rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "time": endTime.Sub(rows.startTime), "args": logQueryArgs(rows.args)})
}
if rows.err != nil && rows.conn.stmtcache != nil {
rows.conn.stmtcache.StatementErrored(rows.sql, rows.err)

View file

@ -163,7 +163,7 @@ func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) err
return nil
}
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector {
c := connector{
ConnConfig: config,
BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default
@ -175,7 +175,11 @@ func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
for _, opt := range opts {
opt(&c)
}
return c
}
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
c := GetConnector(config, opts...)
return sql.OpenDB(c)
}
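GetConnector exposes the driver.Connector that OpenDB previously built internally, which is useful when the caller wants to construct the *sql.DB itself. A rough usage sketch via the stdlib package; the DSN is a placeholder:

```go
package main

import (
	"database/sql"
	"log"

	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/stdlib"
)

func main() {
	cfg, err := pgx.ParseConfig("host=localhost dbname=app")
	if err != nil {
		log.Fatal(err)
	}

	// GetConnector returns a driver.Connector; sql.OpenDB turns it into a *sql.DB,
	// which is exactly what stdlib.OpenDB now does internally.
	connector := stdlib.GetConnector(*cfg)
	db := sql.OpenDB(connector)
	defer db.Close()

	db.SetMaxOpenConns(4)
}
```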

32 vendor/github.com/jackc/pgx/v4/tx.go generated vendored
View file

@ -192,7 +192,7 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
return nil, err
}
return &dbSavepoint{tx: tx, savepointNum: tx.savepointNum}, nil
return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil
}
func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
@ -329,15 +329,15 @@ func (tx *dbTx) Conn() *Conn {
return tx.conn
}
// dbSavepoint represents a nested transaction implemented by a savepoint.
type dbSavepoint struct {
// dbSimulatedNestedTx represents a simulated nested transaction implemented by a savepoint.
type dbSimulatedNestedTx struct {
tx Tx
savepointNum int64
closed bool
}
// Begin starts a pseudo nested transaction implemented with a savepoint.
func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) {
func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) {
if sp.closed {
return nil, ErrTxClosed
}
@ -345,7 +345,7 @@ func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) {
return sp.tx.Begin(ctx)
}
func (sp *dbSavepoint) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
func (sp *dbSimulatedNestedTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
if sp.closed {
return ErrTxClosed
}
@ -354,7 +354,7 @@ func (sp *dbSavepoint) BeginFunc(ctx context.Context, f func(Tx) error) (err err
}
// Commit releases the savepoint essentially committing the pseudo nested transaction.
func (sp *dbSavepoint) Commit(ctx context.Context) error {
func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error {
if sp.closed {
return ErrTxClosed
}
@ -367,7 +367,7 @@ func (sp *dbSavepoint) Commit(ctx context.Context) error {
// Rollback rolls back to the savepoint essentially rolling back the pseudo nested transaction. Rollback will return
// ErrTxClosed if the dbSavepoint is already closed, but is otherwise safe to call multiple times. Hence, a defer sp.Rollback()
// is safe even if sp.Commit() will be called first in a non-error condition.
func (sp *dbSavepoint) Rollback(ctx context.Context) error {
func (sp *dbSimulatedNestedTx) Rollback(ctx context.Context) error {
if sp.closed {
return ErrTxClosed
}
@ -378,7 +378,7 @@ func (sp *dbSavepoint) Rollback(ctx context.Context) error {
}
// Exec delegates to the underlying Tx
func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
func (sp *dbSimulatedNestedTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
if sp.closed {
return nil, ErrTxClosed
}
@ -387,7 +387,7 @@ func (sp *dbSavepoint) Exec(ctx context.Context, sql string, arguments ...interf
}
// Prepare delegates to the underlying Tx
func (sp *dbSavepoint) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
func (sp *dbSimulatedNestedTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
if sp.closed {
return nil, ErrTxClosed
}
@ -396,7 +396,7 @@ func (sp *dbSavepoint) Prepare(ctx context.Context, name, sql string) (*pgconn.S
}
// Query delegates to the underlying Tx
func (sp *dbSavepoint) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
if sp.closed {
// Because checking for errors can be deferred to the *Rows, build one with the error
err := ErrTxClosed
@ -407,13 +407,13 @@ func (sp *dbSavepoint) Query(ctx context.Context, sql string, args ...interface{
}
// QueryRow delegates to the underlying Tx
func (sp *dbSavepoint) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
rows, _ := sp.Query(ctx, sql, args...)
return (*connRow)(rows.(*connRows))
}
// QueryFunc delegates to the underlying Tx.
func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
func (sp *dbSimulatedNestedTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
if sp.closed {
return nil, ErrTxClosed
}
@ -422,7 +422,7 @@ func (sp *dbSavepoint) QueryFunc(ctx context.Context, sql string, args []interfa
}
// CopyFrom delegates to the underlying *Conn
func (sp *dbSavepoint) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
func (sp *dbSimulatedNestedTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
if sp.closed {
return 0, ErrTxClosed
}
@ -431,7 +431,7 @@ func (sp *dbSavepoint) CopyFrom(ctx context.Context, tableName Identifier, colum
}
// SendBatch delegates to the underlying *Conn
func (sp *dbSavepoint) SendBatch(ctx context.Context, b *Batch) BatchResults {
func (sp *dbSimulatedNestedTx) SendBatch(ctx context.Context, b *Batch) BatchResults {
if sp.closed {
return &batchResults{err: ErrTxClosed}
}
@ -439,10 +439,10 @@ func (sp *dbSavepoint) SendBatch(ctx context.Context, b *Batch) BatchResults {
return sp.tx.SendBatch(ctx, b)
}
func (sp *dbSavepoint) LargeObjects() LargeObjects {
func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects {
return LargeObjects{tx: sp}
}
func (sp *dbSavepoint) Conn() *Conn {
func (sp *dbSimulatedNestedTx) Conn() *Conn {
return sp.tx.Conn()
}
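
The dbSavepoint type is renamed to dbSimulatedNestedTx; behaviour is unchanged. For context, a sketch of the savepoint-backed pseudo nested transaction this type implements; the DSN, table, and values are placeholders.

package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v4"
)

func main() {
	ctx := context.Background()

	conn, err := pgx.Connect(ctx, "postgres://user:pass@localhost:5432/db") // placeholder DSN
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	outer, err := conn.Begin(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer outer.Rollback(ctx)

	// Begin on an open transaction returns the savepoint-backed
	// simulated nested transaction rather than a real nested one.
	inner, err := outer.Begin(ctx)
	if err != nil {
		log.Fatal(err)
	}

	if _, err := inner.Exec(ctx, "insert into widgets (name) values ($1)", "sprocket"); err != nil {
		// Rolling back the inner Tx only rolls back to the savepoint;
		// the outer transaction stays usable.
		inner.Rollback(ctx)
	} else if err := inner.Commit(ctx); err != nil {
		log.Fatal(err)
	}

	if err := outer.Commit(ctx); err != nil {
		log.Fatal(err)
	}
}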