mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-30 02:16:14 -06:00
233 lines
5 KiB
Go
233 lines
5 KiB
Go
package pgdriver
|
|
|
|
import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/url"
	"os"
	"sort"
	"strings"
	"time"
)
|
|
|
|
// Config holds the connection settings used by pgdriver. newDefaultConfig
// fills it from PG* environment variables and built-in defaults, and the
// With* DriverOptions overwrite individual fields.
type Config struct {
	// Network is the dial network, either "tcp" or "unix".
	// newDefaultConfig sets "tcp".
	Network string
	// Addr is a TCP "host:port" pair or a Unix socket path, depending on
	// Network.
	Addr string
	// DialTimeout limits how long establishing a new connection may take.
	// newDefaultConfig sets 5 seconds.
	DialTimeout time.Duration
	// Dialer opens the network connection and takes priority over the
	// Network and Addr options.
	Dialer func(ctx context.Context, network, addr string) (net.Conn, error)

	// TLSConfig configures secure connections. A DSN with sslmode=disable
	// does not touch it.
	TLSConfig *tls.Config

	// User is the database user name (PGUSER, default "postgres").
	// verify rejects an empty value.
	User string
	// Password is the password for User; WithPassword accepts "".
	Password string
	// Database is the database name (PGDATABASE, default "postgres").
	Database string
	// AppName is set by WithApplicationName or by the application_name
	// DSN parameter.
	AppName string

	// ReadTimeout bounds socket reads so that commands fail with a timeout
	// instead of blocking. newDefaultConfig sets 10 seconds.
	ReadTimeout time.Duration
	// WriteTimeout bounds socket writes so that commands fail with a
	// timeout instead of blocking. newDefaultConfig sets 5 seconds.
	WriteTimeout time.Duration
}
|
|
|
|
func newDefaultConfig() *Config {
|
|
host := env("PGHOST", "localhost")
|
|
port := env("PGPORT", "5432")
|
|
|
|
cfg := &Config{
|
|
Network: "tcp",
|
|
Addr: net.JoinHostPort(host, port),
|
|
DialTimeout: 5 * time.Second,
|
|
|
|
User: env("PGUSER", "postgres"),
|
|
Database: env("PGDATABASE", "postgres"),
|
|
|
|
ReadTimeout: 10 * time.Second,
|
|
WriteTimeout: 5 * time.Second,
|
|
}
|
|
|
|
cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
netDialer := &net.Dialer{
|
|
Timeout: cfg.DialTimeout,
|
|
KeepAlive: 5 * time.Minute,
|
|
}
|
|
return netDialer.DialContext(ctx, network, addr)
|
|
}
|
|
|
|
return cfg
|
|
}
|
|
|
|
// DriverOption customizes a Connector. Each With* constructor in this file
// returns a DriverOption that, when applied, overwrites one or more fields of
// the connector's cfg.
type DriverOption func(*Connector)
|
|
|
|
func WithAddr(addr string) DriverOption {
|
|
if addr == "" {
|
|
panic("addr is empty")
|
|
}
|
|
return func(d *Connector) {
|
|
d.cfg.Addr = addr
|
|
}
|
|
}
|
|
|
|
func WithTLSConfig(cfg *tls.Config) DriverOption {
|
|
return func(d *Connector) {
|
|
d.cfg.TLSConfig = cfg
|
|
}
|
|
}
|
|
|
|
func WithUser(user string) DriverOption {
|
|
if user == "" {
|
|
panic("user is empty")
|
|
}
|
|
return func(d *Connector) {
|
|
d.cfg.User = user
|
|
}
|
|
}
|
|
|
|
func WithPassword(password string) DriverOption {
|
|
return func(d *Connector) {
|
|
d.cfg.Password = password
|
|
}
|
|
}
|
|
|
|
func WithDatabase(database string) DriverOption {
|
|
if database == "" {
|
|
panic("database is empty")
|
|
}
|
|
return func(d *Connector) {
|
|
d.cfg.Database = database
|
|
}
|
|
}
|
|
|
|
func WithApplicationName(appName string) DriverOption {
|
|
return func(d *Connector) {
|
|
d.cfg.AppName = appName
|
|
}
|
|
}
|
|
|
|
func WithTimeout(timeout time.Duration) DriverOption {
|
|
return func(d *Connector) {
|
|
d.cfg.DialTimeout = timeout
|
|
d.cfg.ReadTimeout = timeout
|
|
d.cfg.WriteTimeout = timeout
|
|
}
|
|
}
|
|
|
|
func WithDialTimeout(dialTimeout time.Duration) DriverOption {
|
|
return func(d *Connector) {
|
|
d.cfg.DialTimeout = dialTimeout
|
|
}
|
|
}
|
|
|
|
func WithReadTimeout(readTimeout time.Duration) DriverOption {
|
|
return func(d *Connector) {
|
|
d.cfg.ReadTimeout = readTimeout
|
|
}
|
|
}
|
|
|
|
func WithWriteTimeout(writeTimeout time.Duration) DriverOption {
|
|
return func(d *Connector) {
|
|
d.cfg.WriteTimeout = writeTimeout
|
|
}
|
|
}
|
|
|
|
func WithDSN(dsn string) DriverOption {
|
|
return func(d *Connector) {
|
|
opts, err := parseDSN(dsn)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(d)
|
|
}
|
|
}
|
|
}
|
|
|
|
func parseDSN(dsn string) ([]DriverOption, error) {
|
|
u, err := url.Parse(dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if u.Scheme != "postgres" && u.Scheme != "postgresql" {
|
|
return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
|
|
}
|
|
|
|
query, err := url.ParseQuery(u.RawQuery)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var opts []DriverOption
|
|
|
|
if u.Host != "" {
|
|
addr := u.Host
|
|
if !strings.Contains(addr, ":") {
|
|
addr += ":5432"
|
|
}
|
|
opts = append(opts, WithAddr(addr))
|
|
}
|
|
if u.User != nil {
|
|
opts = append(opts, WithUser(u.User.Username()))
|
|
if password, ok := u.User.Password(); ok {
|
|
opts = append(opts, WithPassword(password))
|
|
}
|
|
}
|
|
if len(u.Path) > 1 {
|
|
opts = append(opts, WithDatabase(u.Path[1:]))
|
|
}
|
|
|
|
if appName := query.Get("application_name"); appName != "" {
|
|
opts = append(opts, WithApplicationName(appName))
|
|
}
|
|
delete(query, "application_name")
|
|
|
|
if sslMode := query.Get("sslmode"); sslMode != "" {
|
|
switch sslMode {
|
|
case "verify-ca", "verify-full":
|
|
opts = append(opts, WithTLSConfig(new(tls.Config)))
|
|
case "allow", "prefer", "require":
|
|
opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
|
|
case "disable":
|
|
// no TLS config
|
|
default:
|
|
return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode)
|
|
}
|
|
} else {
|
|
opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
|
|
}
|
|
delete(query, "sslmode")
|
|
|
|
for key := range query {
|
|
return nil, fmt.Errorf("pgdriver: unsupported option=%q", key)
|
|
}
|
|
|
|
return opts, nil
|
|
}
|
|
|
|
// env returns the value of the environment variable key. It returns
// defValue when the variable is unset or set to the empty string.
func env(key, defValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defValue
	}
	return value
}
|
|
|
|
// verify is a method to make sure if the config is legitimate
|
|
// in the case it detects any errors, it returns with a non-nil error
|
|
// it can be extended to check other parameters
|
|
func (c *Config) verify() error {
|
|
if c.User == "" {
|
|
return errors.New("pgdriver: User option is empty (to configure, use WithUser).")
|
|
}
|
|
return nil
|
|
}
|