package sql import ( "bytes" "context" "crypto/tls" "crypto/x509" "database/sql" "errors" "fmt" "os" "strconv" "strings" "text/template" "time" "github.com/go-logr/logr" "github.com/go-sql-driver/mysql" mariadbv1alpha1 "github.com/mariadb-operator/mariadb-operator/v25/api/v1alpha1" "github.com/mariadb-operator/mariadb-operator/v25/pkg/environment" "github.com/mariadb-operator/mariadb-operator/v25/pkg/interfaces" "github.com/mariadb-operator/mariadb-operator/v25/pkg/pki" "github.com/mariadb-operator/mariadb-operator/v25/pkg/refresolver" "github.com/mariadb-operator/mariadb-operator/v25/pkg/statefulset" "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" ) var ( ErrWaitReplicaTimeout = errors.New("timeout waiting for replica to be synced") ) type Opts struct { Username string Password string Host string Port int32 Database string MariadbName string MaxscaleName string ExternalMariadbName string Namespace string ClientName string TLSCACert []byte TLSClientCert []byte TLSClientPrivateKey []byte Params map[string]string Timeout *time.Duration } type Opt func(*Opts) func WithUsername(username string) Opt { return func(o *Opts) { o.Username = username } } func WithPassword(password string) Opt { return func(o *Opts) { o.Password = password } } func WitHost(host string) Opt { return func(o *Opts) { o.Host = host } } func WithPort(port int32) Opt { return func(o *Opts) { o.Port = port } } func WithDatabase(database string) Opt { return func(o *Opts) { o.Database = database } } func WithMariadbTLS(name, namespace string, tlsCaCert []byte) Opt { return func(o *Opts) { o.MariadbName = name o.Namespace = namespace o.TLSCACert = tlsCaCert } } func WithMaxscaleTLS(name, namespace string, tlsCaCert []byte) Opt { return func(o *Opts) { o.MaxscaleName = name o.Namespace = namespace o.TLSCACert = tlsCaCert } } func WithTLSClientCert(clientName string, cert, privateKey []byte) Opt { return func(o *Opts) { o.ClientName = clientName o.TLSClientCert = cert o.TLSClientPrivateKey = privateKey } } func WithParams(params map[string]string) Opt { return func(o *Opts) { o.Params = params } } func WithTimeout(d time.Duration) Opt { return func(o *Opts) { o.Timeout = &d } } type Client struct { db *sql.DB } func NewClient(clientOpts ...Opt) (*Client, error) { opts := Opts{} for _, setOpt := range clientOpts { setOpt(&opts) } dsn, err := BuildDSN(opts) if err != nil { return nil, fmt.Errorf("error building DSN: %v", err) } db, err := Connect(dsn) if err != nil { return nil, err } return &Client{ db: db, }, nil } func NewClientWithMariaDB(ctx context.Context, mariadb interfaces.MariaDBObject, refResolver *refresolver.RefResolver, clientOpts ...Opt) (*Client, error) { if mariadb.GetSUCredential() == nil { return nil, fmt.Errorf("error: superuser credential is nil") } password, err := refResolver.SecretKeyRef(ctx, *mariadb.GetSUCredential(), mariadb.GetNamespace()) var opts []Opt if err != nil { return nil, fmt.Errorf("error reading root password secret: %v", err) } opts = []Opt{ WithUsername(mariadb.GetSUName()), WithPassword(password), WitHost(mariadb.GetHost()), WithPort(mariadb.GetPort()), } if mariadb.IsTLSEnabled() { caCert, err := refResolver.SecretKeyRef(ctx, mariadb.TLSCABundleSecretKeyRef(), mariadb.GetNamespace()) if err != nil { return nil, fmt.Errorf("error getting CA certificate: %v", err) } opts = append(opts, WithMariadbTLS(mariadb.GetName(), mariadb.GetNamespace(), []byte(caCert))) clientSecretKey := types.NamespacedName{ Name: mariadb.TLSClientCertSecretKey().Name, Namespace: mariadb.GetNamespace(), } clientCertSelector := mariadbv1alpha1.SecretKeySelector{ LocalObjectReference: mariadbv1alpha1.LocalObjectReference{ Name: clientSecretKey.Name, }, Key: pki.TLSCertKey, } clientCert, err := refResolver.SecretKeyRef(ctx, clientCertSelector, clientSecretKey.Namespace) if err != nil { return nil, fmt.Errorf("error getting client certificate: %v", err) } clientPrivateKeySelector := mariadbv1alpha1.SecretKeySelector{ LocalObjectReference: mariadbv1alpha1.LocalObjectReference{ Name: clientSecretKey.Name, }, Key: pki.TLSKeyKey, } clientPrivateKey, err := refResolver.SecretKeyRef(ctx, clientPrivateKeySelector, clientSecretKey.Namespace) if err != nil { return nil, fmt.Errorf("error getting client private key: %v", err) } opts = append(opts, WithTLSClientCert(clientCertSelector.Name, []byte(clientCert), []byte(clientPrivateKey))) } opts = append(opts, clientOpts...) return NewClient(opts...) } func NewInternalClientWithPodIndex(ctx context.Context, mariadb *mariadbv1alpha1.MariaDB, refResolver *refresolver.RefResolver, podIndex int, clientOpts ...Opt) (*Client, error) { opts := []Opt{ WitHost( statefulset.PodFQDNWithService( mariadb.ObjectMeta, podIndex, mariadb.InternalServiceKey().Name, ), ), } opts = append(opts, clientOpts...) return NewClientWithMariaDB(ctx, mariadb, refResolver, opts...) } func NewLocalClientWithPodEnv(ctx context.Context, env *environment.PodEnvironment, clientOpts ...Opt) (*Client, error) { port, err := env.Port() if err != nil { return nil, fmt.Errorf("error getting port: %v", err) } opts := []Opt{ WithUsername("root"), WithPassword(env.MariadbRootPassword), WitHost("localhost"), WithPort(port), } isTLSEnabled, err := env.IsTLSEnabled() if err != nil { return nil, fmt.Errorf("error checking whether TLS is enabled in environment: %v", err) } if isTLSEnabled { caCert, err := os.ReadFile(env.TLSCACertPath) if err != nil { return nil, fmt.Errorf("error reading CA certificate: %v", err) } opts = append(opts, WithMariadbTLS(env.MariadbName, env.PodNamespace, caCert)) } opts = append(opts, clientOpts...) return NewClient(opts...) } func BuildDSN(opts Opts) (string, error) { if opts.Host == "" || opts.Port == 0 { return "", errors.New("invalid opts: host and port are mandatory") } config := mysql.NewConfig() config.Net = "tcp" config.Addr = fmt.Sprintf("%s:%d", opts.Host, opts.Port) if opts.Timeout != nil { config.Timeout = *opts.Timeout } else { config.Timeout = 5 * time.Second } if opts.Username != "" { config.User = opts.Username } if opts.Password != "" { config.Passwd = opts.Password } if opts.Database != "" { config.DBName = opts.Database } if opts.Params != nil { config.Params = opts.Params } if (opts.MariadbName != "" || opts.MaxscaleName != "" || opts.ExternalMariadbName != "") && opts.Namespace != "" && opts.TLSCACert != nil { configName, err := configureTLS(opts) if err != nil { return "", fmt.Errorf("error configuring TLS: %v", err) } config.TLSConfig = configName } return config.FormatDSN(), nil } func configureTLS(opts Opts) (string, error) { configName, err := configTLSName(opts) if err != nil { return "", fmt.Errorf("error getting TLS config name: %v", err) } var tlsCfg tls.Config caBundle := x509.NewCertPool() if ok := caBundle.AppendCertsFromPEM(opts.TLSCACert); ok { tlsCfg.RootCAs = caBundle } else { return "", errors.New("failed parse pem-encoded CA certificates") } if opts.TLSClientCert != nil && opts.TLSClientPrivateKey != nil { keyPair, err := tls.X509KeyPair(opts.TLSClientCert, opts.TLSClientPrivateKey) if err != nil { return "", fmt.Errorf("error parsing client keypair: %v", err) } tlsCfg.Certificates = []tls.Certificate{keyPair} } if err := mysql.RegisterTLSConfig(configName, &tlsCfg); err != nil { return "", fmt.Errorf("error registering TLS config \"%s\": %v", configName, err) } return configName, nil } func configTLSName(opts Opts) (string, error) { var configName string if opts.MariadbName != "" { configName = fmt.Sprintf("mariadb-%s-%s", opts.MariadbName, opts.Namespace) } else if opts.MaxscaleName != "" { configName = fmt.Sprintf("maxscale-%s-%s", opts.MaxscaleName, opts.Namespace) } else if opts.ExternalMariadbName != "" { configName = fmt.Sprintf("mariadb-%s-%s", opts.ExternalMariadbName, opts.Namespace) } else { return "", errors.New("unable to create config name: either MariaDB or MaxScale names must be set") } if opts.ClientName != "" { configName += fmt.Sprintf("-client-%s", opts.ClientName) } return configName, nil } func Connect(dsn string) (*sql.DB, error) { db, err := sql.Open("mysql", dsn) if err != nil { return nil, err } if err := db.PingContext(context.Background()); err != nil { db.Close() return nil, err } return db, nil } func ConnectWithOpts(opts Opts) (*sql.DB, error) { dsn, err := BuildDSN(opts) if err != nil { return nil, fmt.Errorf("error building DNS: %v", err) } return Connect(dsn) } func (c *Client) Close() error { return c.db.Close() } func (c *Client) Exec(ctx context.Context, sql string, args ...any) error { _, err := c.db.ExecContext(ctx, sql, args...) return err } func (c Client) Exists(ctx context.Context, sql string, args ...any) (bool, error) { rows, err := c.db.QueryContext(ctx, sql, args...) if err != nil { return false, err } defer rows.Close() return rows.Next(), nil } // QueryColumnMaps executes a query and returns all rows as []map[column]value. func (c Client) QueryColumnMaps(ctx context.Context, sql string) ([]map[string]string, error) { rows, err := c.db.QueryContext(ctx, sql) if err != nil { return nil, err } defer rows.Close() columns, err := rows.Columns() if err != nil { return nil, err } var results []map[string]string for rows.Next() { raw := make([]interface{}, len(columns)) dest := make([]interface{}, len(columns)) for i := range raw { dest[i] = &raw[i] } if err := rows.Scan(dest...); err != nil { return nil, err } row := make(map[string]string, len(columns)) for i, col := range columns { row[col] = toString(raw[i]) } results = append(results, row) } if err := rows.Err(); err != nil { return nil, err } return results, nil } // QueryColumnMap invokes QueryColumnMaps and returns the first row. func (c Client) QueryColumnMap(ctx context.Context, sql string) (map[string]string, error) { rows, err := c.QueryColumnMaps(ctx, sql) if err != nil { return nil, err } if len(rows) == 0 { return nil, fmt.Errorf("no rows returned for query %q", sql) } return rows[0], nil } type CreateUserOpts struct { IdentifiedBy string IdentifiedByPassword string IdentifiedVia string IdentifiedViaUsing string Require *mariadbv1alpha1.TLSRequirements MaxUserConnections int32 } type CreateUserOpt func(*CreateUserOpts) func WithIdentifiedBy(password string) CreateUserOpt { return func(cuo *CreateUserOpts) { cuo.IdentifiedBy = password } } func WithIdentifiedByPassword(password string) CreateUserOpt { return func(cuo *CreateUserOpts) { cuo.IdentifiedByPassword = password } } func WithIdentifiedVia(via string) CreateUserOpt { return func(cuo *CreateUserOpts) { cuo.IdentifiedVia = via } } func WithIdentifiedViaUsing(viaUsing string) CreateUserOpt { return func(cuo *CreateUserOpts) { cuo.IdentifiedViaUsing = viaUsing } } func WithTLSRequirements(require *mariadbv1alpha1.TLSRequirements) CreateUserOpt { return func(cuo *CreateUserOpts) { cuo.Require = require } } func WithMaxUserConnections(maxConns int32) CreateUserOpt { return func(cuo *CreateUserOpts) { cuo.MaxUserConnections = maxConns } } func (c *Client) CreateUser(ctx context.Context, accountName string, createUserOpts ...CreateUserOpt) error { opts := CreateUserOpts{} for _, setOpt := range createUserOpts { setOpt(&opts) } query := fmt.Sprintf("CREATE USER IF NOT EXISTS %s ", accountName) if opts.IdentifiedVia != "" { query += fmt.Sprintf("IDENTIFIED VIA %s ", opts.IdentifiedVia) if opts.IdentifiedViaUsing != "" { query += fmt.Sprintf("USING '%s' ", opts.IdentifiedViaUsing) } } else if opts.IdentifiedByPassword != "" { query += fmt.Sprintf("IDENTIFIED BY PASSWORD '%s' ", opts.IdentifiedByPassword) } else if opts.IdentifiedBy != "" { query += fmt.Sprintf("IDENTIFIED BY '%s' ", opts.IdentifiedBy) } if require := opts.Require; require != nil { requireSubQuery, err := requireQuery(require) if err != nil { return fmt.Errorf("error processing require subquery: %v", err) } query += fmt.Sprintf("%s ", requireSubQuery) } query += fmt.Sprintf("WITH MAX_USER_CONNECTIONS %d ", opts.MaxUserConnections) if opts.IdentifiedBy == "" && opts.IdentifiedByPassword == "" && opts.IdentifiedVia == "" && opts.Require == nil { query += "ACCOUNT LOCK PASSWORD EXPIRE " } query += ";" return c.Exec(ctx, query) } func (c *Client) DropUser(ctx context.Context, accountName string) error { query := fmt.Sprintf("DROP USER IF EXISTS %s;", accountName) return c.Exec(ctx, query) } func (c *Client) AlterUser(ctx context.Context, accountName string, createUserOpts ...CreateUserOpt) error { opts := CreateUserOpts{} for _, setOpt := range createUserOpts { setOpt(&opts) } query := fmt.Sprintf("ALTER USER %s ", accountName) if opts.IdentifiedVia != "" { query += fmt.Sprintf("IDENTIFIED VIA %s ", opts.IdentifiedVia) if opts.IdentifiedViaUsing != "" { query += fmt.Sprintf("USING '%s' ", opts.IdentifiedViaUsing) } } else if opts.IdentifiedByPassword != "" { query += fmt.Sprintf("IDENTIFIED BY PASSWORD '%s' ", opts.IdentifiedByPassword) } else if opts.IdentifiedBy != "" { query += fmt.Sprintf("IDENTIFIED BY '%s' ", opts.IdentifiedBy) } if require := opts.Require; require != nil { requireSubQuery, err := requireQuery(require) if err != nil { return fmt.Errorf("error processing require subquery: %v", err) } query += fmt.Sprintf("%s ", requireSubQuery) } query += fmt.Sprintf("WITH MAX_USER_CONNECTIONS %d ", opts.MaxUserConnections) query += ";" return c.Exec(ctx, query) } func (c *Client) UserExists(ctx context.Context, username, host string) (bool, error) { row := c.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM mysql.user WHERE user=? AND host=?", username, host) var count int if err := row.Scan(&count); err != nil { return false, err } return count > 0, nil } type grantOpts struct { grantOption bool } type GrantOption func(*grantOpts) func WithGrantOption() GrantOption { return func(o *grantOpts) { o.grantOption = true } } func (c *Client) Grant( ctx context.Context, privileges []string, database string, table string, accountName string, opts ...GrantOption, ) error { var grantOpts grantOpts for _, setOpt := range opts { setOpt(&grantOpts) } query := fmt.Sprintf("GRANT %s ON %s.%s TO %s ", strings.Join(privileges, ","), escapeWildcard(database), escapeWildcard(table), accountName, ) if grantOpts.grantOption { query += "WITH GRANT OPTION " } query += ";" return c.Exec(ctx, query) } func (c *Client) Revoke( ctx context.Context, privileges []string, database string, table string, accountName string, opts ...GrantOption, ) error { var grantOpts grantOpts for _, setOpt := range opts { setOpt(&grantOpts) } if grantOpts.grantOption { privileges = append(privileges, "GRANT OPTION") } query := fmt.Sprintf("REVOKE %s ON %s.%s FROM %s", strings.Join(privileges, ","), escapeWildcard(database), escapeWildcard(table), accountName, ) return c.Exec(ctx, query) } func escapeWildcard(s string) string { if s == "*" { return s } return fmt.Sprintf("`%s`", s) } type DatabaseOpts struct { CharacterSet string Collate string } func (c *Client) CreateDatabase(ctx context.Context, database string, opts DatabaseOpts) error { sql := fmt.Sprintf("SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s')", database) row := c.db.QueryRowContext(ctx, sql) var dbExists string if err := row.Scan(&dbExists); err != nil { return err } if dbExists == "1" { return nil } query := fmt.Sprintf("CREATE DATABASE `%s` ", database) if opts.CharacterSet != "" { query += fmt.Sprintf("CHARACTER SET = '%s' ", opts.CharacterSet) } if opts.Collate != "" { query += fmt.Sprintf("COLLATE = '%s' ", opts.Collate) } query += ";" return c.Exec(ctx, query) } func (c *Client) DropDatabase(ctx context.Context, database string) error { return c.Exec(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS `%s`;", database)) } func (c *Client) SystemVariable(ctx context.Context, variable string) (string, error) { sql := fmt.Sprintf("SELECT @@global.%s;", variable) row := c.db.QueryRowContext(ctx, sql) var val string if err := row.Scan(&val); err != nil { return "", nil } return val, nil } func (c *Client) IsSystemVariableEnabled(ctx context.Context, variable string) (bool, error) { val, err := c.SystemVariable(ctx, variable) if err != nil { return false, err } return val == "1" || val == "ON", nil } func (c *Client) SetSystemVariable(ctx context.Context, variable string, value string) error { sql := fmt.Sprintf("SET @@global.%s=%s;", variable, value) return c.Exec(ctx, sql) } func (c *Client) LockTablesWithReadLock(ctx context.Context) error { return c.Exec(ctx, "FLUSH TABLES WITH READ LOCK;") } func (c *Client) UnlockTables(ctx context.Context) error { return c.Exec(ctx, "UNLOCK TABLES;") } func (c *Client) EnableReadOnly(ctx context.Context) error { return c.SetSystemVariable(ctx, "read_only", "1") } func (c *Client) DisableReadOnly(ctx context.Context) error { return c.SetSystemVariable(ctx, "read_only", "0") } func (c *Client) ResetMaster(ctx context.Context) error { return c.Exec(ctx, "RESET MASTER;") } func (c *Client) StartSlave(ctx context.Context) error { return c.Exec(ctx, "START SLAVE;") } func (c *Client) StopAllSlaves(ctx context.Context) error { return c.Exec(ctx, "STOP ALL SLAVES;") } func (c *Client) ResetAllSlaves(ctx context.Context) error { return c.Exec(ctx, "RESET SLAVE ALL;") } func (c *Client) WaitForReplicaGtid(ctx context.Context, gtid string, timeout time.Duration) error { sql := fmt.Sprintf("SELECT MASTER_GTID_WAIT('%s', %d);", gtid, int(timeout.Seconds())) row := c.db.QueryRowContext(ctx, sql) var result int if err := row.Scan(&result); err != nil { return fmt.Errorf("error scanning result: %v", err) } switch result { case 0: return nil case -1: return ErrWaitReplicaTimeout default: return fmt.Errorf("unexpected result: %d", result) } } func (c *Client) GtidDomainId(ctx context.Context) (*uint32, error) { rawGtidDomainId, err := c.SystemVariable(ctx, "gtid_domain_id") if err != nil { return nil, err } gtidDomainId, err := strconv.ParseUint(rawGtidDomainId, 10, 32) if err != nil { return nil, fmt.Errorf("error parsing gtid_domain_id: %v", err) } return ptr.To(uint32(gtidDomainId)), nil } func (c *Client) GtidBinlogPos(ctx context.Context) (string, error) { return c.SystemVariable(ctx, "gtid_binlog_pos") } func (c *Client) GtidSlavePos(ctx context.Context) (string, error) { return c.SystemVariable(ctx, "gtid_slave_pos") } func (c *Client) GtidCurrentPos(ctx context.Context) (string, error) { return c.SystemVariable(ctx, "gtid_current_pos") } func (c *Client) SetGtidSlavePos(ctx context.Context, gtid string) error { if gtid == "" { return errors.New("gtid must not be empty") } return c.Exec(ctx, fmt.Sprintf("SET @@global.gtid_slave_pos='%s';", gtid)) } func (c *Client) ResetGtidSlavePos(ctx context.Context) error { return c.Exec(ctx, "SET @@global.gtid_slave_pos='';") } func (c Client) IsReplicationPrimary(ctx context.Context) (bool, error) { return c.Exists(ctx, "SHOW MASTER STATUS") } func (c Client) IsReplicationReplica(ctx context.Context) (bool, error) { return c.Exists(ctx, "SHOW REPLICA STATUS") } // See: https://mariadb.com/docs/server/reference/sql-statements/administrative-sql-statements/show/show-replica-status func (c Client) ReplicaStatus(ctx context.Context, logger logr.Logger) (*mariadbv1alpha1.ReplicaStatusVars, error) { row, err := c.QueryColumnMap(ctx, "SHOW REPLICA STATUS") if err != nil { return nil, fmt.Errorf("error getting replica status: %v", err) } status := mariadbv1alpha1.ReplicaStatusVars{} if lastIOErrno, ok := row["Last_IO_Errno"]; ok { errno, err := strconv.Atoi(lastIOErrno) if err != nil { logger.Error(err, "error parsing Last_IO_Errno") } else { status.LastIOErrno = ptr.To(errno) } } if lastIOError, ok := row["Last_IO_Error"]; ok { status.LastIOError = ptr.To(lastIOError) } if lastSQLErrno, ok := row["Last_SQL_Errno"]; ok { errno, err := strconv.Atoi(lastSQLErrno) if err != nil { logger.Error(err, "error parsing Last_SQL_Errno") } else { status.LastSQLErrno = ptr.To(errno) } } if lastSQLError, ok := row["Last_SQL_Error"]; ok { status.LastSQLError = ptr.To(lastSQLError) } if slaveIORunning, ok := row["Slave_IO_Running"]; ok { running, err := parseThreadRunning(slaveIORunning) if err != nil { logger.Error(err, "error parsing Slave_IO_Running") } else { status.SlaveIORunning = ptr.To(running) } } if slaveSQLRunning, ok := row["Slave_SQL_Running"]; ok { running, err := parseThreadRunning(slaveSQLRunning) if err != nil { logger.Error(err, "error parsing Slave_SQL_Running") } else { status.SlaveSQLRunning = ptr.To(running) } } // Seconds_Behind_Master may be empty when any of the replication threads are not running. Do not treat nil as 0! if secondsBehindMaster, ok := row["Seconds_Behind_Master"]; ok && secondsBehindMaster != "" { seconds, err := strconv.Atoi(secondsBehindMaster) if err != nil { logger.Error(err, "error parsing Seconds_Behind_Master") } else { status.SecondsBehindMaster = ptr.To(seconds) } } if gtidIOPos, ok := row["Gtid_IO_Pos"]; ok && gtidIOPos != "" { status.GtidIOPos = >idIOPos } gtidCurrentPos, err := c.GtidCurrentPos(ctx) if err != nil { return nil, fmt.Errorf("error getting gtid_current_pos: %v", err) } if gtidCurrentPos != "" { status.GtidCurrentPos = >idCurrentPos } return &status, nil } func (c Client) HasConnectedReplicas(ctx context.Context) (bool, error) { rows, err := c.QueryColumnMaps(ctx, "SHOW PROCESSLIST") if err != nil { return false, err } for _, row := range rows { cmd := row["Command"] if cmd == "Binlog Dump" || cmd == "Binlog Dump GTID" { return true, nil } } return false, nil } type ChangeMasterOpts struct { Host string Port int32 User string Password string Gtid string Retries int SSLEnabled bool SSLCertPath string SSLKeyPath string SSLCAPath string } type ChangeMasterOpt func(*ChangeMasterOpts) func WithChangeMasterHost(host string) ChangeMasterOpt { return func(cmo *ChangeMasterOpts) { cmo.Host = host } } func WithChangeMasterPort(port int32) ChangeMasterOpt { return func(cmo *ChangeMasterOpts) { cmo.Port = port } } func WithChangeMasterCredentials(user, password string) ChangeMasterOpt { return func(cmo *ChangeMasterOpts) { cmo.User = user cmo.Password = password } } func WithChangeMasterGtid(gtid string) ChangeMasterOpt { return func(cmo *ChangeMasterOpts) { cmo.Gtid = gtid } } func WithChangeMasterRetries(retries int) ChangeMasterOpt { return func(cmo *ChangeMasterOpts) { cmo.Retries = retries } } func WithChangeMasterSSL(certPath, keyPath, caPath string) ChangeMasterOpt { return func(cmo *ChangeMasterOpts) { cmo.SSLEnabled = true cmo.SSLCertPath = certPath cmo.SSLKeyPath = keyPath cmo.SSLCAPath = caPath } } func (c *Client) ChangeMaster(ctx context.Context, changeMasterOpts ...ChangeMasterOpt) error { query, err := buildChangeMasterQuery(changeMasterOpts...) if err != nil { return fmt.Errorf("error building CHANGE MASTER query: %v", err) } return c.Exec(ctx, query) } func buildChangeMasterQuery(changeMasterOpts ...ChangeMasterOpt) (string, error) { opts := ChangeMasterOpts{ Port: 3306, Gtid: "CurrentPos", } for _, setOpt := range changeMasterOpts { setOpt(&opts) } if opts.Host == "" { return "", errors.New("host must be provided") } if opts.User == "" || opts.Password == "" { return "", errors.New("credentials must be provided") } if opts.SSLEnabled && (opts.SSLCertPath == "" || opts.SSLKeyPath == "" || opts.SSLCAPath == "") { return "", errors.New("all SSL paths must be provided when SSL is enabled") } tpl := createTpl("change-master.sql", `CHANGE MASTER TO {{- if .SSLEnabled }} MASTER_SSL=1, MASTER_SSL_CERT='{{ .SSLCertPath }}', MASTER_SSL_KEY='{{ .SSLKeyPath }}', MASTER_SSL_CA='{{ .SSLCAPath }}', MASTER_SSL_VERIFY_SERVER_CERT=1, {{- end }} MASTER_HOST='{{ .Host }}', MASTER_PORT={{ .Port }}, MASTER_USER='{{ .User }}', MASTER_PASSWORD='{{ .Password }}', {{- with .Retries }} MASTER_CONNECT_RETRY={{ . }}, {{- end }} MASTER_USE_GTID={{ .Gtid }}; `) buf := new(bytes.Buffer) err := tpl.Execute(buf, opts) if err != nil { return "", fmt.Errorf("error rendering CHANGE MASTER template: %v", err) } return buf.String(), nil } const statusVariableSql = "SELECT variable_value FROM information_schema.global_status WHERE variable_name=?;" func (c *Client) StatusVariable(ctx context.Context, variable string) (string, error) { row := c.db.QueryRowContext(ctx, statusVariableSql, variable) var val string if err := row.Scan(&val); err != nil { return "", err } return val, nil } func (c *Client) StatusVariableInt(ctx context.Context, variable string) (int, error) { row := c.db.QueryRowContext(ctx, statusVariableSql, variable) var val int if err := row.Scan(&val); err != nil { return 0, err } return val, nil } func (c *Client) GaleraClusterSize(ctx context.Context) (int, error) { return c.StatusVariableInt(ctx, "wsrep_cluster_size") } func (c *Client) GaleraClusterStatus(ctx context.Context) (string, error) { return c.StatusVariable(ctx, "wsrep_cluster_status") } func (c *Client) GaleraLocalState(ctx context.Context) (string, error) { return c.StatusVariable(ctx, "wsrep_local_state_comment") } func (c *Client) MaxScaleConfigSyncVersion(ctx context.Context) (int, error) { row := c.db.QueryRowContext(ctx, "SELECT version FROM maxscale_config") var version int if err := row.Scan(&version); err != nil { return 0, err } return version, nil } func (c *Client) TruncateMaxScaleConfig(ctx context.Context) error { return c.Exec(ctx, "TRUNCATE TABLE maxscale_config") } func (c *Client) DropMaxScaleConfig(ctx context.Context) error { return c.Exec(ctx, "DROP TABLE maxscale_config") } func requireQuery(require *mariadbv1alpha1.TLSRequirements) (string, error) { if require == nil { return "", errors.New("TLS requirements must be set") } if err := require.Validate(); err != nil { return "", fmt.Errorf("invalid TLS requirements: %v", err) } var tlsOptions []string if require.SSL != nil && *require.SSL { tlsOptions = append(tlsOptions, "SSL") } if require.X509 != nil && *require.X509 { tlsOptions = append(tlsOptions, "X509") } if require.Issuer != nil && *require.Issuer != "" { tlsOptions = append(tlsOptions, fmt.Sprintf("ISSUER '%s'", *require.Issuer)) } if require.Subject != nil && *require.Subject != "" { tlsOptions = append(tlsOptions, fmt.Sprintf("SUBJECT '%s'", *require.Subject)) } if len(tlsOptions) == 0 { return "", errors.New("no valid TLS requirements specified") } return fmt.Sprintf("REQUIRE %s", strings.Join(tlsOptions, " AND ")), nil } func createTpl(name, t string) *template.Template { return template.Must(template.New(name).Parse(t)) } func toString(v interface{}) string { if v == nil { return "" } switch val := v.(type) { case []byte: return string(val) case string: return val case int64: return fmt.Sprintf("%d", val) default: return fmt.Sprintf("%v", val) } } func parseThreadRunning(s string) (bool, error) { switch strings.ToLower(s) { case "connecting": return true, nil case "preparing": return false, nil default: return parseBool(s) } } func parseBool(s string) (bool, error) { switch strings.ToLower(s) { case "yes", "on", "true", "1": return true, nil case "no", "off", "false", "0": return false, nil case "": return false, errors.New("invalid bool value: empty string") default: return false, fmt.Errorf("invalid bool value: %s", s) } }