Skip to content

Commit

Permalink
added url parsing util since stdlib is very finnicky when managing gr…
Browse files Browse the repository at this point in the history
…pc, http, and unix
  • Loading branch information
brennanjl committed Nov 30, 2023
1 parent bb3bc34 commit d8cbcc4
Show file tree
Hide file tree
Showing 17 changed files with 281 additions and 173 deletions.
3 changes: 2 additions & 1 deletion cmd/internal/display/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ func BindOutputFormatFlag(cmd *cobra.Command) {
// BindSilenceFlag binds the silence flag to the passed command.
// If bound, the command will silence logs.
// If true, display commands will not print to stdout or stderr.
// The flag will be bound to all subcommands of the given command.
func BindSilenceFlag(cmd *cobra.Command) {
cmd.Flags().BoolP("silence", "S", false, "Silence logs")
cmd.PersistentFlags().BoolP("silence", "S", false, "Silence logs")
}

// ShouldSilence returns the value of the silence flag
Expand Down
9 changes: 5 additions & 4 deletions cmd/kwil-admin/cmds/common/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ import (
// BindRPCFlags binds the RPC flags to the given command.
// This includes an rpcserver flag, and the TLS flags.
// These flags can be used to create an admin service client.
// The flags will be bound to all subcommands of the given command.
func BindRPCFlags(cmd *cobra.Command) {
cmd.Flags().StringP("rpcserver", "s", "127.0.0.1:50151", "admin RPC server address (either unix or tcp) [default: unix:///tmp/kwil_admin.sock]")
cmd.PersistentFlags().StringP("rpcserver", "s", "unix:///tmp/kwil_admin.sock", "admin RPC server address (either unix or tcp) [default: unix:///tmp/kwil_admin.sock]")

cmd.Flags().String("authrpc-cert", "", "kwild's TLS certificate")
cmd.Flags().String("tlskey", "auth.key", "kwil-admin's TLS key file to establish a mTLS (authenticated) connection [default: auth.key]")
cmd.Flags().String("tlscert", "auth.cert", "kwil-admin's TLS certificate file for server to authenticate us [default: auth.cert]")
cmd.PersistentFlags().String("authrpc-cert", "", "kwild's TLS certificate")
cmd.PersistentFlags().String("tlskey", "auth.key", "kwil-admin's TLS key file to establish a mTLS (authenticated) connection [default: auth.key]")
cmd.PersistentFlags().String("tlscert", "auth.cert", "kwil-admin's TLS certificate file for server to authenticate us [default: auth.cert]")
}

// GetRPCServerFlag returns the RPC flag from the given command.
Expand Down
4 changes: 2 additions & 2 deletions cmd/kwil-admin/cmds/validators/join-status.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

"github.com/kwilteam/kwil-db/cmd/internal/display"
"github.com/kwilteam/kwil-db/cmd/kwil-admin/cmds/common"
"github.com/kwilteam/kwil-db/core/client"
"github.com/kwilteam/kwil-db/core/rpc/client"
"github.com/kwilteam/kwil-db/internal/validators"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -44,7 +44,7 @@ func joinStatusCmd() *cobra.Command {
data, err := clt.JoinStatus(ctx, pubkeyBts)
if err != nil {
if errors.Is(err, client.ErrNotFound) {
return errors.New("no active join request for that validator")
return display.PrintErr(cmd, errors.New("no active join request for that validator"))
}
return err
}
Expand Down
5 changes: 4 additions & 1 deletion cmd/kwil-admin/cmds/validators/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ func (r *respValSets) MarshalText() ([]byte, error) {
var msg bytes.Buffer
msg.WriteString("Current validator set:\n")
for i, v := range r.Data {
msg.WriteString(fmt.Sprintf("% 3d. %s\n", i, v))
msg.WriteString(fmt.Sprintf("% 3d. %s", i, v))
if i != len(r.Data)-1 {
msg.WriteString("\n")
}
}

return msg.Bytes(), nil
Expand Down
28 changes: 10 additions & 18 deletions cmd/kwild/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ import (
"errors"
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"strings"
"syscall"
"time"

Expand Down Expand Up @@ -46,6 +44,7 @@ import (
functionpb "github.com/kwilteam/kwil-db/core/rpc/protobuf/function/v0"
txpb "github.com/kwilteam/kwil-db/core/rpc/protobuf/tx/v1"
"github.com/kwilteam/kwil-db/core/rpc/transport"
"github.com/kwilteam/kwil-db/core/utils/url"

abciTypes "github.com/cometbft/cometbft/abci/types"
cmtEd "github.com/cometbft/cometbft/crypto/ed25519"
Expand Down Expand Up @@ -417,25 +416,18 @@ func buildHealthSvc(d *coreDependencies) *healthsvc.Server {
}

func buildAdminService(d *coreDependencies, closer *closeFuncs, admsvc admpb.AdminServiceServer, txsvc txpb.TxServiceServer) *kwilgrpc.Server {
addr := d.cfg.AppCfg.AdminListenAddress
// if listen address does not have unix:// or tcp:// prefix, assume tcp://
if !strings.HasPrefix(d.cfg.AppCfg.AdminListenAddress, "unix://") && !strings.HasPrefix(d.cfg.AppCfg.AdminListenAddress, "tcp://") {
addr = "tcp://" + addr
}

// parse AdminListenAddress to see if it is tcp or unix
u, err := url.Parse(addr)
u, err := url.ParseURL(d.cfg.AppCfg.AdminListenAddress)
if err != nil {
failBuild(err, "failed to build grpc server")
failBuild(err, "failed to build admin service")
}

switch u.Scheme {
default:
failBuild(err, "unknown admin service protocol "+u.Scheme)
case "tcp":
failBuild(err, "unknown admin service protocol "+u.Scheme.String())
case url.TCP:

// if tcp, we need to set up TLS
lis, err := net.Listen("tcp", ":"+u.Port())
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", u.Port))
if err != nil {
failBuild(err, "failed to build grpc server")
}
Expand Down Expand Up @@ -487,9 +479,9 @@ func buildAdminService(d *coreDependencies, closer *closeFuncs, admsvc admpb.Adm
grpc_health_v1.RegisterHealthServer(grpcServer, buildHealthSvc(d))

return grpcServer
case "unix":
case url.UNIX:
// if unix, we need to set up unix socket
err = os.MkdirAll(filepath.Dir(u.Path), 0755) // ensure parent dir exists
err = os.MkdirAll(filepath.Dir(u.Target), 0755) // ensure parent dir exists
if err != nil {
failBuild(err, "failed to build grpc server")
}
Expand All @@ -499,11 +491,11 @@ func buildAdminService(d *coreDependencies, closer *closeFuncs, admsvc admpb.Adm
// suggested approach here (not sure about this for obvious reasons):
// https://gist.github.com/hakobe/6f70d69b8c5243117787fd488ae7fbf2

err = syscall.Unlink(u.Path)
err = syscall.Unlink(u.Target)
if err != nil && !os.IsNotExist(err) {
failBuild(err, "failed to build grpc server")
}
lis, err := net.Listen("unix", u.Path)
lis, err := net.Listen("unix", u.Target)
if err != nil {
failBuild(err, "failed to build grpc server")
}
Expand Down
13 changes: 8 additions & 5 deletions core/adminclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ package adminclient
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"

"github.com/kwilteam/kwil-db/core/log"
admingrpc "github.com/kwilteam/kwil-db/core/rpc/client/admin/grpc"
txGrpc "github.com/kwilteam/kwil-db/core/rpc/client/user/grpc"
"github.com/kwilteam/kwil-db/core/rpc/transport"
"github.com/kwilteam/kwil-db/core/types"
adminTypes "github.com/kwilteam/kwil-db/core/types/admin"
"github.com/kwilteam/kwil-db/core/utils/url"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
Expand Down Expand Up @@ -82,7 +83,7 @@ func New(ctx context.Context, target string, opts ...AdminClientOpt) (*AdminClie
log: log.NewNoOp(),
}

parsedTarget, err := url.Parse(target)
parsedTarget, err := url.ParseURL(target)
if err != nil {
return nil, err
}
Expand All @@ -94,7 +95,7 @@ func New(ctx context.Context, target string, opts ...AdminClientOpt) (*AdminClie
dialOpts := []grpc.DialOption{}

switch parsedTarget.Scheme {
case "tcp", "": // default to grpc
case url.TCP: // default to grpc
if c.kwildCertFile != "" || c.clientKeyFile != "" || c.clientCertFile != "" {
// tcp + tls

Expand All @@ -108,14 +109,16 @@ func New(ctx context.Context, target string, opts ...AdminClientOpt) (*AdminClie
// tcp + no tls
dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
case "unix":
case url.UNIX:
dialOpts = append(dialOpts, grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return net.Dial("unix", s)
}), grpc.WithTransportCredentials(insecure.NewCredentials()))
default:
return nil, fmt.Errorf("unknown scheme %q", parsedTarget.Scheme)
}

// we dial a normal grpc connection, and then wrap it with the services
conn, err := grpc.DialContext(ctx, target, dialOpts...)
conn, err := grpc.DialContext(ctx, parsedTarget.Target, dialOpts...)
if err != nil {
return nil, err
}
Expand Down
5 changes: 0 additions & 5 deletions core/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/url"
Expand Down Expand Up @@ -39,10 +38,6 @@ type TxClient interface {
EstimateCost(ctx context.Context, tx *transactions.Transaction) (*big.Int, error)
}

var (
ErrNotFound = errors.New("not found")
)

// Client is a Kwil client that can interact with the main public Kwil RPC.
type Client struct {
txClient TxClient
Expand Down
19 changes: 10 additions & 9 deletions core/rpc/client/admin/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"time"

"github.com/kwilteam/kwil-db/core/rpc/client"
admpb "github.com/kwilteam/kwil-db/core/rpc/protobuf/admin/v0"
"github.com/kwilteam/kwil-db/core/types"
adminTypes "github.com/kwilteam/kwil-db/core/types/admin"
Expand All @@ -26,7 +27,7 @@ func NewAdminClient(conn *grpc.ClientConn) *GrpcAdminClient {
func (c *GrpcAdminClient) Version(ctx context.Context) (string, error) {
resp, err := c.client.Version(ctx, &admpb.VersionRequest{})
if err != nil {
return "", err
return "", client.ConvertGRPCErr(err)
}
return resp.VersionString, nil
}
Expand All @@ -47,7 +48,7 @@ func convertNodeInfo(ni *admpb.NodeInfo) *adminTypes.NodeInfo {
func (c *GrpcAdminClient) Status(ctx context.Context) (*adminTypes.Status, error) {
resp, err := c.client.Status(ctx, &admpb.StatusRequest{})
if err != nil {
return nil, err
return nil, client.ConvertGRPCErr(err)
}
return &adminTypes.Status{
Node: convertNodeInfo(resp.Node),
Expand All @@ -68,7 +69,7 @@ func (c *GrpcAdminClient) Status(ctx context.Context) (*adminTypes.Status, error
func (c *GrpcAdminClient) Peers(ctx context.Context) ([]*adminTypes.PeerInfo, error) {
resp, err := c.client.Peers(ctx, &admpb.PeersRequest{})
if err != nil {
return nil, err
return nil, client.ConvertGRPCErr(err)
}
peers := make([]*adminTypes.PeerInfo, len(resp.Peers))
for i, pbPeer := range resp.Peers {
Expand All @@ -86,7 +87,7 @@ func (c *GrpcAdminClient) Peers(ctx context.Context) ([]*adminTypes.PeerInfo, er
func (c *GrpcAdminClient) Approve(ctx context.Context, publicKey []byte) ([]byte, error) {
resp, err := c.client.Approve(ctx, &admpb.ApproveRequest{Pubkey: publicKey})
if err != nil {
return nil, err
return nil, client.ConvertGRPCErr(err)
}
return resp.TxHash, nil
}
Expand All @@ -96,7 +97,7 @@ func (c *GrpcAdminClient) Approve(ctx context.Context, publicKey []byte) ([]byte
func (c *GrpcAdminClient) Join(ctx context.Context) ([]byte, error) {
resp, err := c.client.Join(ctx, &admpb.JoinRequest{})
if err != nil {
return nil, err
return nil, client.ConvertGRPCErr(err)
}
return resp.TxHash, nil
}
Expand All @@ -106,7 +107,7 @@ func (c *GrpcAdminClient) Join(ctx context.Context) ([]byte, error) {
func (c *GrpcAdminClient) Leave(ctx context.Context) ([]byte, error) {
resp, err := c.client.Leave(ctx, &admpb.LeaveRequest{})
if err != nil {
return nil, err
return nil, client.ConvertGRPCErr(err)
}
return resp.TxHash, nil
}
Expand All @@ -116,7 +117,7 @@ func (c *GrpcAdminClient) Leave(ctx context.Context) ([]byte, error) {
func (c *GrpcAdminClient) Remove(ctx context.Context, publicKey []byte) ([]byte, error) {
resp, err := c.client.Remove(ctx, &admpb.RemoveRequest{Pubkey: publicKey})
if err != nil {
return nil, err
return nil, client.ConvertGRPCErr(err)
}
return resp.TxHash, nil
}
Expand All @@ -125,7 +126,7 @@ func (c *GrpcAdminClient) Remove(ctx context.Context, publicKey []byte) ([]byte,
func (c *GrpcAdminClient) JoinStatus(ctx context.Context, pubkey []byte) (*types.JoinRequest, error) {
resp, err := c.client.JoinStatus(ctx, &admpb.JoinStatusRequest{Pubkey: pubkey})
if err != nil {
return nil, err
return nil, client.ConvertGRPCErr(err)
}

total := len(resp.ApprovedValidators) + len(resp.PendingValidators)
Expand All @@ -151,7 +152,7 @@ func (c *GrpcAdminClient) JoinStatus(ctx context.Context, pubkey []byte) (*types
func (c *GrpcAdminClient) ListValidators(ctx context.Context) ([]*types.Validator, error) {
resp, err := c.client.ListValidators(ctx, &admpb.ListValidatorsRequest{})
if err != nil {
return nil, err
return nil, client.ConvertGRPCErr(err)
}
validators := make([]*types.Validator, len(resp.Validators))
for i, v := range resp.Validators {
Expand Down
28 changes: 27 additions & 1 deletion core/rpc/client/error.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
package client

import "errors"
import (
"errors"
"fmt"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var (
ErrInvalidSignature = errors.New("invalid signature")
// ErrUnauthorized is returned when the client is not authenticated
// It is the equivalent of http status code 401
ErrUnauthorized = errors.New("unauthorized")
ErrNotFound = errors.New("not found")
)

// convertErr will convert the error to a known type, if possible.
// It is expected that the error is from a gRPC call.
func ConvertGRPCErr(err error) error {
statusError, ok := status.FromError(err)
if !ok {
return fmt.Errorf("unrecognized error: %w", err)
}

switch statusError.Code() {
case codes.OK:
// this should never happen?
return fmt.Errorf("unexpected OK status code returned error")
case codes.NotFound:
return ErrNotFound
}

return fmt.Errorf("%v (%d)", statusError.Message(), statusError.Code())
}
7 changes: 7 additions & 0 deletions core/utils/url/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package url

import "errors"

var (
ErrUnknownScheme = errors.New("unknown scheme")
)
Loading

0 comments on commit d8cbcc4

Please sign in to comment.