mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-16 14:03:00 -06:00
add tls support for db connection
This commit is contained in:
parent
7a072749c4
commit
66ea26ced2
5 changed files with 134 additions and 37 deletions
|
|
@ -117,6 +117,18 @@ func main() {
|
||||||
Value: defaults.DbDatabase,
|
Value: defaults.DbDatabase,
|
||||||
EnvVars: []string{envNames.DbDatabase},
|
EnvVars: []string{envNames.DbDatabase},
|
||||||
},
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: flagNames.DbTLSMode,
|
||||||
|
Usage: "Database tls mode",
|
||||||
|
Value: defaults.DBTlsMode,
|
||||||
|
EnvVars: []string{envNames.DbTLSMode},
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: flagNames.DbTLSCACert,
|
||||||
|
Usage: "Path to CA cert for db tls connection",
|
||||||
|
Value: defaults.DBTlsCACert,
|
||||||
|
EnvVars: []string{envNames.DbTLSCACert},
|
||||||
|
},
|
||||||
|
|
||||||
// TEMPLATE FLAGS
|
// TEMPLATE FLAGS
|
||||||
&cli.StringFlag{
|
&cli.StringFlag{
|
||||||
|
|
|
||||||
|
|
@ -165,6 +165,14 @@ func (c *Config) ParseCLIFlags(f KeyedFlags, version string) error {
|
||||||
c.DBConfig.Database = f.String(fn.DbDatabase)
|
c.DBConfig.Database = f.String(fn.DbDatabase)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.DBConfig.TLSMode == DBTLSModeUnset || f.IsSet(fn.DbTLSMode) {
|
||||||
|
c.DBConfig.TLSMode = DBTLSMode(f.String(fn.DbTLSMode))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.DBConfig.TLSCACert == "" || f.IsSet(fn.DbTLSCACert) {
|
||||||
|
c.DBConfig.TLSCACert = f.String(fn.DbTLSCACert)
|
||||||
|
}
|
||||||
|
|
||||||
// template flags
|
// template flags
|
||||||
if c.TemplateConfig.BaseDir == "" || f.IsSet(fn.TemplateBaseDir) {
|
if c.TemplateConfig.BaseDir == "" || f.IsSet(fn.TemplateBaseDir) {
|
||||||
c.TemplateConfig.BaseDir = f.String(fn.TemplateBaseDir)
|
c.TemplateConfig.BaseDir = f.String(fn.TemplateBaseDir)
|
||||||
|
|
@ -284,12 +292,14 @@ type Flags struct {
|
||||||
Host string
|
Host string
|
||||||
Protocol string
|
Protocol string
|
||||||
|
|
||||||
DbType string
|
DbType string
|
||||||
DbAddress string
|
DbAddress string
|
||||||
DbPort string
|
DbPort string
|
||||||
DbUser string
|
DbUser string
|
||||||
DbPassword string
|
DbPassword string
|
||||||
DbDatabase string
|
DbDatabase string
|
||||||
|
DbTLSMode string
|
||||||
|
DbTLSCACert string
|
||||||
|
|
||||||
TemplateBaseDir string
|
TemplateBaseDir string
|
||||||
AssetBaseDir string
|
AssetBaseDir string
|
||||||
|
|
@ -329,12 +339,14 @@ type Defaults struct {
|
||||||
Protocol string
|
Protocol string
|
||||||
SoftwareVersion string
|
SoftwareVersion string
|
||||||
|
|
||||||
DbType string
|
DbType string
|
||||||
DbAddress string
|
DbAddress string
|
||||||
DbPort int
|
DbPort int
|
||||||
DbUser string
|
DbUser string
|
||||||
DbPassword string
|
DbPassword string
|
||||||
DbDatabase string
|
DbDatabase string
|
||||||
|
DBTlsMode string
|
||||||
|
DBTlsCACert string
|
||||||
|
|
||||||
TemplateBaseDir string
|
TemplateBaseDir string
|
||||||
AssetBaseDir string
|
AssetBaseDir string
|
||||||
|
|
@ -375,12 +387,14 @@ func GetFlagNames() Flags {
|
||||||
Host: "host",
|
Host: "host",
|
||||||
Protocol: "protocol",
|
Protocol: "protocol",
|
||||||
|
|
||||||
DbType: "db-type",
|
DbType: "db-type",
|
||||||
DbAddress: "db-address",
|
DbAddress: "db-address",
|
||||||
DbPort: "db-port",
|
DbPort: "db-port",
|
||||||
DbUser: "db-user",
|
DbUser: "db-user",
|
||||||
DbPassword: "db-password",
|
DbPassword: "db-password",
|
||||||
DbDatabase: "db-database",
|
DbDatabase: "db-database",
|
||||||
|
DbTLSMode: "db-tls-mode",
|
||||||
|
DbTLSCACert: "db-tls-ca-cert",
|
||||||
|
|
||||||
TemplateBaseDir: "template-basedir",
|
TemplateBaseDir: "template-basedir",
|
||||||
AssetBaseDir: "asset-basedir",
|
AssetBaseDir: "asset-basedir",
|
||||||
|
|
@ -422,12 +436,14 @@ func GetEnvNames() Flags {
|
||||||
Host: "GTS_HOST",
|
Host: "GTS_HOST",
|
||||||
Protocol: "GTS_PROTOCOL",
|
Protocol: "GTS_PROTOCOL",
|
||||||
|
|
||||||
DbType: "GTS_DB_TYPE",
|
DbType: "GTS_DB_TYPE",
|
||||||
DbAddress: "GTS_DB_ADDRESS",
|
DbAddress: "GTS_DB_ADDRESS",
|
||||||
DbPort: "GTS_DB_PORT",
|
DbPort: "GTS_DB_PORT",
|
||||||
DbUser: "GTS_DB_USER",
|
DbUser: "GTS_DB_USER",
|
||||||
DbPassword: "GTS_DB_PASSWORD",
|
DbPassword: "GTS_DB_PASSWORD",
|
||||||
DbDatabase: "GTS_DB_DATABASE",
|
DbDatabase: "GTS_DB_DATABASE",
|
||||||
|
DbTLSMode: "GTS_DB_TLS_MODE",
|
||||||
|
DbTLSCACert: "GTS_DB_CA_CERT",
|
||||||
|
|
||||||
TemplateBaseDir: "GTS_TEMPLATE_BASEDIR",
|
TemplateBaseDir: "GTS_TEMPLATE_BASEDIR",
|
||||||
AssetBaseDir: "GTS_ASSET_BASEDIR",
|
AssetBaseDir: "GTS_ASSET_BASEDIR",
|
||||||
|
|
|
||||||
|
|
@ -20,11 +20,30 @@ package config
|
||||||
|
|
||||||
// DBConfig provides configuration options for the database connection
|
// DBConfig provides configuration options for the database connection
|
||||||
type DBConfig struct {
|
type DBConfig struct {
|
||||||
Type string `yaml:"type"`
|
Type string `yaml:"type"`
|
||||||
Address string `yaml:"address"`
|
Address string `yaml:"address"`
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port"`
|
||||||
User string `yaml:"user"`
|
User string `yaml:"user"`
|
||||||
Password string `yaml:"password"`
|
Password string `yaml:"password"`
|
||||||
Database string `yaml:"database"`
|
Database string `yaml:"database"`
|
||||||
ApplicationName string `yaml:"applicationName"`
|
ApplicationName string `yaml:"applicationName"`
|
||||||
|
TLSMode DBTLSMode `yaml:"tlsMode"`
|
||||||
|
TLSCACert string `yaml:"tlsCACert"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DBTLSMode describes a mode of connecting to a database with or without TLS.
|
||||||
|
type DBTLSMode string
|
||||||
|
|
||||||
|
// DBTLSModeDisable does not attempt to make a TLS connection to the database.
|
||||||
|
var DBTLSModeDisable DBTLSMode = "disable"
|
||||||
|
|
||||||
|
// DBTLSModeEnable attempts to make a TLS connection to the database, but doesn't fail if
|
||||||
|
// the certificate passed by the database isn't verified.
|
||||||
|
var DBTLSModeEnable DBTLSMode = "enable"
|
||||||
|
|
||||||
|
// DBTLSModeRequire attempts to make a TLS connection to the database, and requires
|
||||||
|
// that the certificate presented by the database is valid.
|
||||||
|
var DBTLSModeRequire DBTLSMode = "require"
|
||||||
|
|
||||||
|
// DBTLSModeUnset means that the TLS mode has not been set.
|
||||||
|
var DBTLSModeUnset DBTLSMode = ""
|
||||||
|
|
|
||||||
|
|
@ -120,12 +120,14 @@ func GetDefaults() Defaults {
|
||||||
Host: "",
|
Host: "",
|
||||||
Protocol: "https",
|
Protocol: "https",
|
||||||
|
|
||||||
DbType: "postgres",
|
DbType: "postgres",
|
||||||
DbAddress: "localhost",
|
DbAddress: "localhost",
|
||||||
DbPort: 5432,
|
DbPort: 5432,
|
||||||
DbUser: "postgres",
|
DbUser: "postgres",
|
||||||
DbPassword: "postgres",
|
DbPassword: "postgres",
|
||||||
DbDatabase: "postgres",
|
DbDatabase: "postgres",
|
||||||
|
DBTlsMode: "disable",
|
||||||
|
DBTlsCACert: "",
|
||||||
|
|
||||||
TemplateBaseDir: "./web/template/",
|
TemplateBaseDir: "./web/template/",
|
||||||
AssetBaseDir: "./web/assets/",
|
AssetBaseDir: "./web/assets/",
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,14 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -133,6 +137,49 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
|
||||||
return nil, errors.New("no database set")
|
return nil, errors.New("no database set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
switch c.DBConfig.TLSMode {
|
||||||
|
case config.DBTLSModeDisable, config.DBTLSModeUnset:
|
||||||
|
break // nothing to do
|
||||||
|
case config.DBTLSModeEnable:
|
||||||
|
tlsConfig = &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}
|
||||||
|
case config.DBTLSModeRequire:
|
||||||
|
tlsConfig = &tls.Config{
|
||||||
|
InsecureSkipVerify: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig != nil && c.DBConfig.TLSCACert != "" {
|
||||||
|
// load the system cert pool first -- we'll append the given CA cert to this
|
||||||
|
certPool, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error fetching system CA cert pool: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(caCertBytes) == 0 {
|
||||||
|
return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert)
|
||||||
|
}
|
||||||
|
|
||||||
|
caPem, _ := pem.Decode(caCertBytes)
|
||||||
|
if caPem == nil {
|
||||||
|
return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert)
|
||||||
|
}
|
||||||
|
|
||||||
|
caCert, err := x509.ParseCertificate(caPem.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPool.AddCert(caCert)
|
||||||
|
}
|
||||||
|
|
||||||
// We can rely on the pg library we're using to set
|
// We can rely on the pg library we're using to set
|
||||||
// sensible defaults for everything we don't set here.
|
// sensible defaults for everything we don't set here.
|
||||||
options := &pg.Options{
|
options := &pg.Options{
|
||||||
|
|
@ -141,6 +188,7 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
|
||||||
Password: c.DBConfig.Password,
|
Password: c.DBConfig.Password,
|
||||||
Database: c.DBConfig.Database,
|
Database: c.DBConfig.Database,
|
||||||
ApplicationName: c.ApplicationName,
|
ApplicationName: c.ApplicationName,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
return options, nil
|
return options, nil
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue