Skip to content

Commit

Permalink
Add max_parallel parameter to MySQL backend. (#2760)
Browse files Browse the repository at this point in the history
* Add max_parallel parameter to MySQL backend.

This limits the number of concurrent connections, so that vault does not die
suddenly from "Too many connections".

This can happen when e.g. vault starts up, and tries to load all the
existing leases in parallel. At the time of writing this, the value
ExpirationRestoreWorkerCount in vault/helper/consts/const.go is set to
64, meaning that if there are enough leases in the vault's DB, it will
generate AT LEAST 64 concurrent connections to MySQL when loading the
data during start-up. On certain configurations, e.g. smaller AWS
RDS/Aurora instances, this will cause Vault to fail startup.

* Fix a typo in mysql storage readme
  • Loading branch information
ikatson authored and briankassouf committed Jun 1, 2017
1 parent 5adcb9c commit 32c7efe
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
31 changes: 31 additions & 0 deletions physical/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ import (
"io/ioutil"
"net/url"
"sort"
"strconv"
"strings"
"time"

log "github.com/mgutz/logxi/v1"

"github.com/armon/go-metrics"
mysql "github.com/go-sql-driver/mysql"
"github.com/hashicorp/errwrap"
)

// Unreserved tls key
Expand All @@ -28,11 +30,14 @@ type MySQLBackend struct {
client *sql.DB
statements map[string]*sql.Stmt
logger log.Logger
permitPool *PermitPool
}

// newMySQLBackend constructs a MySQL backend using the given API client and
// server address and credential for accessing mysql database.
func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error) {
var err error

// Get the MySQL credentials to perform read/write operations.
username, ok := conf["username"]
if !ok || username == "" {
Expand Down Expand Up @@ -60,6 +65,18 @@ func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error)
}
dbTable := database + "." + table

maxParStr, ok := conf["max_parallel"]
var maxParInt int
if ok {
maxParInt, err = strconv.Atoi(maxParStr)
if err != nil {
return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
}
if logger.IsDebug() {
logger.Debug("mysql: max_parallel set", "max_parallel", maxParInt)
}
}

dsnParams := url.Values{}
tlsCaFile, ok := conf["tls_ca_file"]
if ok {
Expand Down Expand Up @@ -95,6 +112,7 @@ func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error)
client: db,
statements: make(map[string]*sql.Stmt),
logger: logger,
permitPool: NewPermitPool(maxParInt),
}

// Prepare all the statements required
Expand All @@ -110,6 +128,7 @@ func newMySQLBackend(conf map[string]string, logger log.Logger) (Backend, error)
return nil, err
}
}

return m, nil
}

Expand All @@ -127,6 +146,9 @@ func (m *MySQLBackend) prepare(name, query string) error {
func (m *MySQLBackend) Put(entry *Entry) error {
defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())

m.permitPool.Acquire()
defer m.permitPool.Release()

_, err := m.statements["put"].Exec(entry.Key, entry.Value)
if err != nil {
return err
Expand All @@ -138,6 +160,9 @@ func (m *MySQLBackend) Put(entry *Entry) error {
func (m *MySQLBackend) Get(key string) (*Entry, error) {
defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())

m.permitPool.Acquire()
defer m.permitPool.Release()

var result []byte
err := m.statements["get"].QueryRow(key).Scan(&result)
if err == sql.ErrNoRows {
Expand All @@ -158,6 +183,9 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) {
func (m *MySQLBackend) Delete(key string) error {
defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())

m.permitPool.Acquire()
defer m.permitPool.Release()

_, err := m.statements["delete"].Exec(key)
if err != nil {
return err
Expand All @@ -170,6 +198,9 @@ func (m *MySQLBackend) Delete(key string) error {
func (m *MySQLBackend) List(prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())

m.permitPool.Acquire()
defer m.permitPool.Release()

// Add the % wildcard to the prefix to do the prefix search
likePrefix := prefix + "%"
rows, err := m.statements["list"].Query(likePrefix)
Expand Down
3 changes: 3 additions & 0 deletions website/source/docs/configuration/storage/mysql.html.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ storage "mysql" {
- `tls_ca_file` `(string: "")` – Specifies the path to the CA certificate to
connect using TLS.

- `max_parallel` `(string: "128")` – Specifies the maximum number of concurrent
requests to MySQL.

Additionally, Vault requires the following authentication information.

- `username` `(string: <required>)` – Specifies the MySQL username to connect to
Expand Down

0 comments on commit 32c7efe

Please sign in to comment.