313 lines
8.7 KiB
Go
313 lines
8.7 KiB
Go
package pq
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/lib/pq/internal/pqutil"
|
|
)
|
|
|
|
// Registry of custom tls.Configs, keyed by the name used in sslmode without
// the "pqgo-" prefix. All access to tlsConfs must hold tlsConfsMu.
var (
	tlsConfsMu sync.RWMutex
	tlsConfs   = map[string]*tls.Config{}
)
|
|
|
|
// RegisterTLSConfig registers a custom [tls.Config]. They are used by using
|
|
// sslmode=pqgo-«key» in the connection string.
|
|
//
|
|
// Set the config to nil to remove a configuration.
|
|
func RegisterTLSConfig(key string, config *tls.Config) error {
|
|
key = strings.TrimPrefix(key, "pqgo-")
|
|
if config == nil {
|
|
tlsConfsMu.Lock()
|
|
delete(tlsConfs, key)
|
|
tlsConfsMu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
tlsConfsMu.Lock()
|
|
tlsConfs[key] = config
|
|
tlsConfsMu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func hasTLSConfig(key string) bool {
|
|
tlsConfsMu.RLock()
|
|
defer tlsConfsMu.RUnlock()
|
|
_, ok := tlsConfs[key]
|
|
return ok
|
|
}
|
|
|
|
func getTLSConfigClone(key string) *tls.Config {
|
|
tlsConfsMu.RLock()
|
|
defer tlsConfsMu.RUnlock()
|
|
if v, ok := tlsConfs[key]; ok {
|
|
return v.Clone()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
|
|
// related settings. The function is nil when no upgrade should take place.
|
|
//
|
|
// Don't refer to Config.SSLMode here, as the mode in arguments may be different
|
|
// in case of sslmode=allow or prefer.
|
|
func ssl(cfg Config, mode SSLMode) (func(net.Conn) (net.Conn, error), error) {
|
|
var (
|
|
home = pqutil.Home()
|
|
// Don't set defaults here, because tlsConf may be overwritten if a
|
|
// custom one was registered. Set it after the sslmode switch.
|
|
tlsConf = &tls.Config{}
|
|
// Only verify the CA signing but not the hostname.
|
|
verifyCaOnly = false
|
|
)
|
|
if mode.useSSL() && !cfg.SSLInline && cfg.SSLRootCert == "" && home != "" {
|
|
f := filepath.Join(home, "root.crt")
|
|
if _, err := os.Stat(f); err == nil {
|
|
cfg.SSLRootCert = f
|
|
}
|
|
}
|
|
switch {
|
|
case mode == SSLModeDisable || mode == SSLModeAllow:
|
|
return nil, nil
|
|
|
|
case mode == SSLModeRequire || mode == SSLModePrefer:
|
|
// Skip TLS's own verification since it requires full verification.
|
|
tlsConf.InsecureSkipVerify = true
|
|
|
|
// From http://www.postgresql.org/docs/current/static/libpq-ssl.html:
|
|
//
|
|
// For backwards compatibility with earlier versions of PostgreSQL, if a
|
|
// root CA file exists, the behavior of sslmode=require will be the same
|
|
// as that of verify-ca, meaning the server certificate is validated
|
|
// against the CA. Relying on this behavior is discouraged, and
|
|
// applications that need certificate validation should always use
|
|
// verify-ca or verify-full.
|
|
if cfg.SSLRootCert != "" {
|
|
if cfg.SSLInline {
|
|
verifyCaOnly = true
|
|
} else if _, err := os.Stat(cfg.SSLRootCert); err == nil {
|
|
verifyCaOnly = true
|
|
} else if cfg.SSLRootCert != "system" {
|
|
cfg.SSLRootCert = ""
|
|
}
|
|
}
|
|
case mode == SSLModeVerifyCA:
|
|
// Skip TLS's own verification since it requires full verification.
|
|
tlsConf.InsecureSkipVerify = true
|
|
verifyCaOnly = true
|
|
case mode == SSLModeVerifyFull:
|
|
tlsConf.ServerName = cfg.Host
|
|
case strings.HasPrefix(string(mode), "pqgo-"):
|
|
tlsConf = getTLSConfigClone(string(mode[5:]))
|
|
if tlsConf == nil {
|
|
return nil, fmt.Errorf(`pq: unknown custom sslmode %q`, mode)
|
|
}
|
|
default:
|
|
panic("unreachable")
|
|
}
|
|
|
|
tlsConf.MinVersion = cfg.SSLMinProtocolVersion.tlsconf()
|
|
tlsConf.MaxVersion = cfg.SSLMaxProtocolVersion.tlsconf()
|
|
|
|
// RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 or
|
|
// IPv6). This check is coded already crypto.tls.hostnameInSNI, so just
|
|
// always set ServerName here and let crypto/tls do the filtering.
|
|
if cfg.SSLSNI {
|
|
tlsConf.ServerName = cfg.Host
|
|
}
|
|
|
|
err := sslClientCertificates(tlsConf, cfg, home)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rootPem, err := sslCertificateAuthority(tlsConf, cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sslAppendIntermediates(tlsConf, cfg, rootPem)
|
|
|
|
// Accept renegotiation requests initiated by the backend.
|
|
//
|
|
// Renegotiation was deprecated then removed from PostgreSQL 9.5, but the
|
|
// default configuration of older versions has it enabled. Redshift also
|
|
// initiates renegotiations and cannot be reconfigured.
|
|
//
|
|
// TODO: I think this can be removed?
|
|
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient
|
|
|
|
return func(conn net.Conn) (net.Conn, error) {
|
|
client := tls.Client(conn, tlsConf)
|
|
if verifyCaOnly {
|
|
err := client.Handshake()
|
|
if err != nil {
|
|
return client, err
|
|
}
|
|
var (
|
|
certs = client.ConnectionState().PeerCertificates
|
|
opts = x509.VerifyOptions{Intermediates: x509.NewCertPool(), Roots: tlsConf.RootCAs}
|
|
)
|
|
for _, cert := range certs[1:] {
|
|
opts.Intermediates.AddCert(cert)
|
|
}
|
|
_, err = certs[0].Verify(opts)
|
|
return client, err
|
|
}
|
|
return client, nil
|
|
}, nil
|
|
}
|
|
|
|
// sslClientCertificates adds the certificate specified in the "sslcert" and
|
|
//
|
|
// "sslkey" settings, or if they aren't set, from the .postgresql directory
|
|
// in the user's home directory. The configured files must exist and have
|
|
// the correct permissions.
|
|
func sslClientCertificates(tlsConf *tls.Config, cfg Config, home string) error {
|
|
if cfg.SSLInline {
|
|
cert, err := tls.X509KeyPair([]byte(cfg.SSLCert), []byte(cfg.SSLKey))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Use GetClientCertificate instead of the Certificates field. When
|
|
// Certificates is set, Go's TLS client only sends the cert if the
|
|
// server's CertificateRequest includes a CA that issued it. When the
|
|
// client cert was signed by an intermediate CA but the server only
|
|
// advertises the root CA, Go skips sending the cert entirely.
|
|
// GetClientCertificate bypasses this filtering.
|
|
tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
return &cert, nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Only load client certificate and key if the setting is not blank, like libpq.
|
|
if cfg.SSLCert == "" && home != "" {
|
|
cfg.SSLCert = filepath.Join(home, "postgresql.crt")
|
|
}
|
|
if cfg.SSLCert == "" {
|
|
return nil
|
|
}
|
|
_, err := os.Stat(cfg.SSLCert)
|
|
if err != nil {
|
|
if pqutil.ErrNotExists(err) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// In libpq, the ssl key is only loaded if the setting is not blank.
|
|
if cfg.SSLKey == "" && home != "" {
|
|
cfg.SSLKey = filepath.Join(home, "postgresql.key")
|
|
}
|
|
if cfg.SSLKey != "" {
|
|
err := pqutil.SSLKeyPermissions(cfg.SSLKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Using GetClientCertificate instead of Certificates per comment above.
|
|
tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
return &cert, nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var testSystemRoots *x509.CertPool
|
|
|
|
// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting.
|
|
func sslCertificateAuthority(tlsConf *tls.Config, cfg Config) ([]byte, error) {
|
|
// Only load root certificate if not blank, like libpq.
|
|
if cfg.SSLRootCert == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
if cfg.SSLRootCert == "system" {
|
|
// No work to do as system CAs are used by default if RootCAs is nil.
|
|
tlsConf.RootCAs = testSystemRoots
|
|
return nil, nil
|
|
}
|
|
|
|
tlsConf.RootCAs = x509.NewCertPool()
|
|
|
|
var cert []byte
|
|
if cfg.SSLInline {
|
|
cert = []byte(cfg.SSLRootCert)
|
|
} else {
|
|
var err error
|
|
cert, err = os.ReadFile(cfg.SSLRootCert)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if !tlsConf.RootCAs.AppendCertsFromPEM(cert) {
|
|
return nil, errors.New("pq: couldn't parse pem from sslrootcert")
|
|
}
|
|
return cert, nil
|
|
}
|
|
|
|
// sslAppendIntermediates appends intermediate CA certificates from sslrootcert
|
|
// to the client certificate chain. This is needed so the server can verify the
|
|
// client cert when it was signed by an intermediate CA — without this, the TLS
|
|
// handshake only sends the leaf client cert.
|
|
func sslAppendIntermediates(tlsConf *tls.Config, cfg Config, rootPem []byte) {
|
|
if cfg.SSLRootCert == "" || tlsConf.GetClientCertificate == nil || len(rootPem) == 0 {
|
|
return
|
|
}
|
|
|
|
var (
|
|
pemData = slices.Clone(rootPem)
|
|
intermediates [][]byte
|
|
)
|
|
for {
|
|
var block *pem.Block
|
|
block, pemData = pem.Decode(pemData)
|
|
if block == nil {
|
|
break
|
|
}
|
|
if block.Type != "CERTIFICATE" {
|
|
continue
|
|
}
|
|
cert, err := x509.ParseCertificate(block.Bytes)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
// Skip self-signed root CAs; only append intermediates.
|
|
if cert.IsCA && !bytes.Equal(cert.RawIssuer, cert.RawSubject) {
|
|
intermediates = append(intermediates, block.Bytes)
|
|
}
|
|
}
|
|
if len(intermediates) == 0 {
|
|
return
|
|
}
|
|
|
|
// Wrap the existing GetClientCertificate to append intermediate certs to
|
|
// the certificate chain returned during the TLS handshake.
|
|
origGetCert := tlsConf.GetClientCertificate
|
|
tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
|
cert, err := origGetCert(info)
|
|
if err != nil {
|
|
return cert, err
|
|
}
|
|
cert.Certificate = append(cert.Certificate, intermediates...)
|
|
return cert, nil
|
|
}
|
|
}
|