diff --git a/client/manager_sql.go b/client/manager_sql.go new file mode 100644 index 00000000000..140ad8a7aee --- /dev/null +++ b/client/manager_sql.go @@ -0,0 +1,140 @@ +package client + +import ( + "encoding/json" + "github.com/imdario/mergo" + "github.com/jmoiron/sqlx" + "github.com/ory-am/fosite" + "github.com/pkg/errors" +) + +var sqlSchema = []string{ + `CREATE TABLE IF NOT EXISTS hydra_client ( + id varchar(255) NOT NULL PRIMARY KEY, + version int NOT NULL DEFAULT 0, + client json NOT NULL)`, +} + +type SQLManager struct { + Hasher fosite.Hasher + DB *sqlx.DB +} + +type sqlData struct { + ID string `db:"id"` + Version int `db:"version"` + Client []byte `db:"client"` +} + +// CreateSchemas creates ladon_policy tables +func (s *SQLManager) CreateSchemas() error { + for _, query := range sqlSchema { + if _, err := s.DB.Exec(query); err != nil { + return errors.Wrapf(err, "Could not create schema:\n%s", query) + } + } + return nil +} + +func (m *SQLManager) GetConcreteClient(id string) (*Client, error) { + var d sqlData + var c Client + if err := m.DB.Get(&d, m.DB.Rebind("SELECT * FROM hydra_client WHERE id=?"), id); err != nil { + return nil, errors.Wrap(err, "") + } else if err := json.Unmarshal(d.Client, &c); err != nil { + return nil, errors.Wrap(err, "") + } + + return &c, nil +} + +func (m *SQLManager) GetClient(id string) (fosite.Client, error) { + return m.GetConcreteClient(id) +} + +func (m *SQLManager) UpdateClient(c *Client) error { + o, err := m.GetClient(c.ID) + if err != nil { + return err + } + + if c.Secret == "" { + c.Secret = string(o.GetHashedSecret()) + } else { + h, err := m.Hasher.Hash([]byte(c.Secret)) + if err != nil { + return errors.Wrap(err, "") + } + c.Secret = string(h) + } + if err := mergo.Merge(c, o); err != nil { + return errors.Wrap(err, "") + } + + b, err := json.Marshal(c) + if err != nil { + return errors.Wrap(err, "") + } + + if _, err := m.DB.NamedExec(`UPDATE hydra_client SET id=:id, client=:client WHERE id=:id`, &sqlData{ + ID: c.ID, + Client: b, + }); err != nil { + return errors.Wrap(err, "") + } + return nil +} + +func (m *SQLManager) Authenticate(id string, secret []byte) (*Client, error) { + c, err := m.GetConcreteClient(id) + if err != nil { + return nil, errors.Wrap(err, "") + } + + if err := m.Hasher.Compare(c.GetHashedSecret(), secret); err != nil { + return nil, errors.Wrap(err, "") + } + + return c, nil +} + +func (m *SQLManager) CreateClient(c *Client) error { + b, err := json.Marshal(c) + if err != nil { + return errors.Wrap(err, "") + } + + if _, err := m.DB.NamedExec(`INSERT INTO hydra_client (id, client, version) VALUES (:id, :client, :version)`, &sqlData{ + ID: c.ID, + Client: b, + Version: 0, + }); err != nil { + return errors.Wrap(err, "") + } + return nil +} + +func (m *SQLManager) DeleteClient(id string) error { + if _, err := m.DB.Exec(m.DB.Rebind(`DELETE FROM hydra_client WHERE id=?`), id); err != nil { + return errors.Wrap(err, "") + } + return nil +} + +func (m *SQLManager) GetClients() (clients map[string]Client, err error) { + var d = []sqlData{} + clients = make(map[string]Client) + + if err := m.DB.Select(&d, "SELECT * FROM hydra_client"); err != nil { + return nil, errors.Wrap(err, "") + } + + for _, k := range d { + var c Client + if err := json.Unmarshal(k.Client, &c); err != nil { + return nil, errors.Wrap(err, "") + } + clients[k.ID] = c + } + return clients, nil +} diff --git a/client/manager_test.go b/client/manager_test.go index cb403045b63..0020bb4d5ef 100644 --- a/client/manager_test.go +++ b/client/manager_test.go @@ -11,6 +11,7 @@ import ( "os" "time" + "github.com/jmoiron/sqlx" "github.com/julienschmidt/httprouter" "github.com/ory-am/dockertest" "github.com/ory-am/fosite" @@ -21,6 +22,7 @@ import ( "github.com/ory-am/ladon" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/context" ) @@ -63,8 +65,77 @@ func init() { } var rethinkManager *RethinkManager +var containers = []dockertest.ContainerID{} func TestMain(m *testing.M) { + defer func() { + for _, c := range containers { + c.KillRemove() + } + }() + + connectPG() + connectToRethinkDB() + connectMySQL() + + retCode := m.Run() + os.Exit(retCode) +} +func connectMySQL() { + var db *sqlx.DB + c, err := dockertest.ConnectToMySQL(15, time.Second, func(url string) bool { + var err error + db, err = sqlx.Open("mysql", url) + if err != nil { + log.Printf("Got error in mysql connector: %s", err) + return false + } + return db.Ping() == nil + }) + + if err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + containers = append(containers, c) + s := &SQLManager{DB: db, Hasher: &fosite.BCrypt{WorkFactor: 4}} + + if err = s.CreateSchemas(); err != nil { + log.Fatalf("Could not create postgres schema: %v", err) + } + + clientManagers["mysql"] = s + containers = append(containers, c) +} + +func connectPG() { + var db *sqlx.DB + c, err := dockertest.ConnectToPostgreSQL(15, time.Second, func(url string) bool { + var err error + db, err = sqlx.Open("postgres", url) + if err != nil { + log.Printf("Got error in postgres connector: %s", err) + return false + } + return db.Ping() == nil + }) + + if err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + + containers = append(containers, c) + s := &SQLManager{DB: db, Hasher: &fosite.BCrypt{WorkFactor: 4}} + + if err = s.CreateSchemas(); err != nil { + log.Fatalf("Could not create postgres schema: %v", err) + } + + clientManagers["postgres"] = s + containers = append(containers, c) +} + +func connectToRethinkDB() { var session *r.Session var err error @@ -92,17 +163,13 @@ func TestMain(m *testing.M) { time.Sleep(100 * time.Millisecond) return true }) - if session != nil { - defer session.Close() - } + if err != nil { log.Fatalf("Could not connect to database: %s", err) } - clientManagers["rethink"] = rethinkManager - retCode := m.Run() - c.KillRemove() - os.Exit(retCode) + containers = append(containers, c) + clientManagers["rethink"] = rethinkManager } func TestAuthenticateClient(t *testing.T) { @@ -180,12 +247,12 @@ func TestColdStartRethinkManager(t *testing.T) { time.Sleep(time.Second / 2) rethinkManager.Clients = make(map[string]Client) - assert.Nil(t, rethinkManager.ColdStart()) + require.Nil(t, rethinkManager.ColdStart()) c1, err := rethinkManager.GetClient("foo") - assert.Nil(t, err) + require.Nil(t, err) c2, err := rethinkManager.GetClient("bar") - assert.Nil(t, err) + require.Nil(t, err) assert.NotEqual(t, c1, c2) assert.Equal(t, "foo", c1.GetID())