Skip to content

Commit

Permalink
Merge pull request #90 from telekom-mms/feature/add-token-to-socket-a…
Browse files Browse the repository at this point in the history
…pi-msg

Add token to Unix Socket API message
  • Loading branch information
hwipl committed May 17, 2024
2 parents 4aa83af + 56ae1f1 commit dfd1f20
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 66 deletions.
43 changes: 36 additions & 7 deletions internal/api/message.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
package api

import (
"crypto/rand"
"encoding/base64"
"encoding/binary"
"errors"
"io"
)

const (
// MaxPayloadLength is the maximum allowed length of a message payload.
MaxPayloadLength = 2097152
// TokenLength is the length of the message token in bytes.
TokenLength = 16
)

var (
// token is the message token.
token [TokenLength]byte
)

// Message types.
Expand All @@ -24,6 +31,7 @@ const (
type Header struct {
Type uint16
Length uint32
Token [TokenLength]byte
}

// Message is an API message.
Expand All @@ -34,13 +42,11 @@ type Message struct {

// NewMessage returns a new message with type t and payload p.
func NewMessage(t uint16, p []byte) *Message {
if len(p) > MaxPayloadLength {
return nil
}
return &Message{
Header: Header{
Type: t,
Length: uint32(len(p)),
Token: token,
},
Value: p,
}
Expand Down Expand Up @@ -69,8 +75,8 @@ func ReadMessage(r io.Reader) (*Message, error) {
if h.Type == TypeNone || h.Type >= TypeUndefined {
return nil, errors.New("invalid message type")
}
if h.Length > MaxPayloadLength {
return nil, errors.New("invalid message length")
if h.Token != token {
return nil, errors.New("invalid message token")
}

// read payload
Expand Down Expand Up @@ -107,3 +113,26 @@ func WriteMessage(w io.Writer, m *Message) error {

return nil
}

// GetToken generates and returns the message token as string. This should be
// used once on the server side before the server is started. Token must be
// passed to the client side.
func GetToken() (string, error) {
_, err := rand.Read(token[:])
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(token[:]), nil
}

// SetToken sets the message token from string. This should be used on the
// client side before sending requests to the server. Token must match token on
// the server side.
func SetToken(s string) error {
b, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
return err
}
copy(token[:], b)
return nil
}
47 changes: 37 additions & 10 deletions internal/api/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"bytes"
"encoding/base64"
"errors"
"log"
"reflect"
Expand All @@ -24,12 +25,6 @@ func TestNewMessage(t *testing.T) {
t.Errorf("got %d, want %d", msg.Type, typ)
}
}

// invalid payload length
p := [MaxPayloadLength + 1]byte{}
if NewMessage(TypeOK, p[:]) != nil {
t.Error("should not create message with invalid payload length")
}
}

// TestNewOK tests NewOK.
Expand Down Expand Up @@ -62,11 +57,11 @@ func TestReadMessageErrors(t *testing.T) {
// invalid type
{Header: Header{Type: TypeUndefined}},

// invalid length
{Header: Header{Type: TypeOK, Length: MaxPayloadLength + 1}},

// short message
{Header: Header{Type: TypeOK, Length: MaxPayloadLength}},
{Header: Header{Type: TypeOK, Length: 4096}},

// invalid token
{Header: Header{Type: TypeOK, Token: [16]byte{1}}},
} {
if err := WriteMessage(buf, msg); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -132,3 +127,35 @@ func TestReadWriteMessage(t *testing.T) {
t.Errorf("got %v, want %v", got, want)
}
}

// TestGetSetToken tests GetToken and SetToken.
func TestGetSetToken(t *testing.T) {
// reset token after tests
defer func() { token = [TokenLength]byte{} }()

// get new test token
testToken, err := GetToken()
if err != nil {
t.Fatal(err)
}
s := base64.RawURLEncoding.EncodeToString(token[:])
if testToken != s {
t.Fatal("encoded token should match internal token")
}

// set token
if err := SetToken(testToken); err != nil {
t.Fatal(err)
}

// check token
s = base64.RawURLEncoding.EncodeToString(token[:])
if s != testToken {
t.Fatal("internal token should match encoded token")
}

// setting invalid token
if err := SetToken("not a valid encoded token!"); err == nil {
t.Fatal("invalid token should return error")
}
}
15 changes: 2 additions & 13 deletions internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package daemon

import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net"
"reflect"
Expand Down Expand Up @@ -344,13 +342,6 @@ func (d *Daemon) updateVPNConfig(request *api.Request) {
return
}

// check token
if configUpdate.Token != d.token {
log.Error("Daemon got invalid token in vpn config update")
request.Error("invalid token in config update message")
return
}

// handle config update for vpn (dis)connect
if configUpdate.Reason == "disconnect" {
d.updateVPNConfigDown()
Expand Down Expand Up @@ -508,13 +499,11 @@ func (d *Daemon) cleanup(ctx context.Context) {

// initToken creates the daemon token for client authentication.
func (d *Daemon) initToken() error {
// TODO: is this good enough for us?
b := make([]byte, 16)
_, err := rand.Read(b)
token, err := api.GetToken()
if err != nil {
return err
}
d.token = base64.RawURLEncoding.EncodeToString(b)
d.token = token
return nil
}

Expand Down
12 changes: 4 additions & 8 deletions internal/daemon/vpnconfigupdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,20 @@ import (
// VPNConfigUpdate is a VPN configuration update.
type VPNConfigUpdate struct {
Reason string
Token string
Config *vpnconfig.Config
}

// Valid returns whether the config update is valid.
func (c *VPNConfigUpdate) Valid() bool {
switch c.Reason {
case "disconnect":
// token must be valid and config nil
if c.Token == "" || c.Config != nil {
// config must be nil
if c.Config != nil {
return false
}
case "connect":
// token and config must be valid
if c.Token == "" || c.Config == nil {
return false
}
if !c.Config.Valid() {
// config must be valid
if c.Config == nil || !c.Config.Valid() {
return false
}
default:
Expand Down
8 changes: 2 additions & 6 deletions internal/daemon/vpnconfigupdate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ func TestVPNConfigUpdateValid(t *testing.T) {
// test invalid disconnect
u = NewVPNConfigUpdate()
u.Reason = "disconnect"
u.Config = vpnconfig.New()

got = u.Valid()
want = false
if got != want {
t.Errorf("got %t, want %t", got, want)
}

// test invalid connect, no token and no config
// test invalid connect, no config
u = NewVPNConfigUpdate()
u.Reason = "connect"

Expand All @@ -42,7 +43,6 @@ func TestVPNConfigUpdateValid(t *testing.T) {
// test invalid connect, invalid config
u = NewVPNConfigUpdate()
u.Reason = "connect"
u.Token = "some test token"
u.Config = vpnconfig.New()
u.Config.Device.Name = "name is too long for a network device"

Expand All @@ -55,7 +55,6 @@ func TestVPNConfigUpdateValid(t *testing.T) {
// test valid disconnect
u = NewVPNConfigUpdate()
u.Reason = "disconnect"
u.Token = "some test token"

got = u.Valid()
want = true
Expand All @@ -66,7 +65,6 @@ func TestVPNConfigUpdateValid(t *testing.T) {
// test valid connect
u = NewVPNConfigUpdate()
u.Reason = "connect"
u.Token = "some test token"
u.Config = vpnconfig.New()

got = u.Valid()
Expand All @@ -87,13 +85,11 @@ func TestVPNConfigUpdateJSON(t *testing.T) {
// valid disconnect
u = NewVPNConfigUpdate()
u.Reason = "disconnect"
u.Token = "some test token"
updates = append(updates, u)

// valid connect
u = NewVPNConfigUpdate()
u.Reason = "connect"
u.Token = "some test token"
u.Config = vpnconfig.New()
updates = append(updates, u)

Expand Down
23 changes: 6 additions & 17 deletions internal/vpncscript/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestRunClient(t *testing.T) {
return confUpdate
}

// test with maximum payload length
// test with varying payload lengths
server = api.NewServer(config)
go func() {
for r := range server.Requests() {
Expand All @@ -79,23 +79,12 @@ func TestRunClient(t *testing.T) {
if err := server.Start(); err != nil {
t.Fatal(err)
}
if err := runClient(sockfile, getConfUpdate(api.MaxPayloadLength)); err != nil {
t.Fatal(err)
}
server.Stop()

// test with more than maximum payload length
server = api.NewServer(config)
go func() {
for r := range server.Requests() {
r.Close()
for _, length := range []int{
2048, 4096, 8192, 65536, 2097152,
} {
if err := runClient(sockfile, getConfUpdate(length)); err != nil {
t.Errorf("length %d returned error: %v", length, err)
}
}()
if err := server.Start(); err != nil {
t.Fatal(err)
}
if err := runClient(sockfile, getConfUpdate(api.MaxPayloadLength+1)); err == nil {
t.Fatal("too long message should return error")
}
server.Stop()
}
6 changes: 6 additions & 0 deletions internal/vpncscript/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"

log "github.com/sirupsen/logrus"
"github.com/telekom-mms/oc-daemon/internal/api"
"github.com/telekom-mms/oc-daemon/internal/daemon"
)

Expand Down Expand Up @@ -47,6 +48,11 @@ func run(args []string) error {
socketFile = e.socketFile
}

// set token from environemt
if err := api.SetToken(e.token); err != nil {
return fmt.Errorf("VPNCScript could not set token: %w", err)
}

printDebugEnvironment()
log.WithField("env", e).Debug("VPNCScript parsed environment")

Expand Down
6 changes: 6 additions & 0 deletions internal/vpncscript/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ func TestRun(t *testing.T) {
t.Errorf("help should return ErrHelp, got: %v", err)
}

// test with invalid token
t.Setenv("oc_daemon_token", "this is not a valid encoded token!")
if err := run([]string{"test"}); err == nil {
t.Errorf("invalid token should return error")
}

// prepare environment with not existing sockfile
os.Clearenv()
sockfile := filepath.Join(t.TempDir(), "sockfile")
Expand Down
1 change: 0 additions & 1 deletion internal/vpncscript/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ func createConfig(env *env) (*vpnconfig.Config, error) {
func createConfigUpdate(env *env) (*daemon.VPNConfigUpdate, error) {
update := daemon.NewVPNConfigUpdate()
update.Reason = env.reason
update.Token = env.token
if env.reason == "connect" {
c, err := createConfig(env)
if err != nil {
Expand Down
4 changes: 0 additions & 4 deletions internal/vpncscript/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ func TestCreateConfigUpdate(t *testing.T) {

// create expected values based on test environment
reason := "connect"
token := "some token"
config := &vpnconfig.Config{
Gateway: net.IPv4(10, 1, 1, 1),
PID: 12345,
Expand Down Expand Up @@ -129,9 +128,6 @@ func TestCreateConfigUpdate(t *testing.T) {
if got.Reason != reason {
t.Errorf("got %s, want %s", got.Reason, reason)
}
if got.Token != token {
t.Errorf("got %s, want %s", got.Token, token)
}
if !reflect.DeepEqual(got.Config, config) {
t.Errorf("got:\n%#v\nwant:\n%#v", got.Config, config)
}
Expand Down

0 comments on commit dfd1f20

Please sign in to comment.