Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Go version of SQL Commenter #101

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go/sqlcommenter/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
module github.com/google/sqlcommenter/go/sqlcommenter
60 changes: 60 additions & 0 deletions go/sqlcommenter/sqlcommenter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package sqlcommenter

import (
"net/url"
"sort"
"strings"
)

// Values maps a string key to a value for that key to attach to a SQL query
// in a comment. Implements the SQL Commenter spec:
// https://google.github.io/sqlcommenter
type Values map[string]string

// String returns the string representing all values according to the SQL
// Commenter spec.
func (vs Values) String() string {
if len(vs) == 0 {
return ""
}

pairs := make([]string, 0, len(vs))
for k, v := range vs {
if k == "" {
continue
}
pairs = append(pairs, serializeKey(k)+"="+serializeValue(v))
}

if len(pairs) == 0 {
return "" // we might have dropped only empty keys
}

// Spec requires sorted key-value pairs after running the serialization
// algorithm.
sort.Strings(pairs)

return "/*" + strings.Join(pairs, ",") + "*/"
}

// https://google.github.io/sqlcommenter/spec/#key-serialization-algorithm
func serializeKey(s string) string {
esc := urlEncode(s)
return escapeMeta(esc)
}

// https://google.github.io/sqlcommenter/spec/#value-serialization-algorithm
func serializeValue(s string) string {
esc := urlEncode(s)
return `'` + escapeMeta(esc) + `'`
}

func urlEncode(s string) string {
esc := url.QueryEscape(s)
// Go encodes spaces as "+"; use more standard %20.
return strings.Replace(esc, "+", "%20", -1)
}

func escapeMeta(s string) string {
return strings.Replace(s, `'`, `\'`, -1)
}
72 changes: 72 additions & 0 deletions go/sqlcommenter/sqlcommenter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package sqlcommenter

import (
"strings"
"testing"
)

func TestValues_String(t *testing.T) {
tests := []struct {
name string
vs Values
want string
}{
{name: "nil", vs: nil, want: ""},
{name: "empty", vs: Values{}, want: ""},
{name: "empty cast", vs: Values(map[string]string{}), want: ""},
{
name: "drop empty key",
vs: Values(map[string]string{"": "val"}),
want: "",
},
{
name: "one",
vs: Values(map[string]string{"key": "val"}),
want: "/*key='val'*/",
},
{
name: "two",
vs: Values(map[string]string{"a": "1", "b": "2"}),
want: "/*a='1',b='2'*/",
},
{
name: "two reversed",
vs: Values(map[string]string{"b": "2", "a": "1"}), // technically, Go map iteration is random
want: "/*a='1',b='2'*/",
},
{
name: "name=DROP TABLE FOO",
vs: Values(map[string]string{"name": "DROP TABLE FOO"}),
want: "/*name='DROP%20TABLE%20FOO'*/",
},
{
name: `name''=DROP TABLE USERS'`,
vs: Values(map[string]string{"name''": `DROP TABLE USERS'`}),
want: `/*name%27%27='DROP%20TABLE%20USERS%27'*/`,
},
{
name: `exhibit`, // https://google.github.io/sqlcommenter/spec/#sql-commenter-exhibit
vs: Values(map[string]string{
"action": `%2Fparam*d`,
"controller": `index`,
"framework": `spring`,
"traceparent": `00-5bd66ef5095369c7b0d1f8f4bd33716a-c532cb4098ac3dd2-01`,
"tracestate": `congo%3Dt61rcWkgMzE%2Crojo%3D00f067aa0ba902b7`,
}),
want: "/*" + strings.Join([]string{
"action='%252Fparam%2Ad'",
"controller='index'",
"framework='spring'",
"traceparent='00-5bd66ef5095369c7b0d1f8f4bd33716a-c532cb4098ac3dd2-01'",
"tracestate='congo%253Dt61rcWkgMzE%252Crojo%253D00f067aa0ba902b7'",
}, ",") + "*/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.vs.String(); got != tt.want {
t.Errorf("\nwant: %v\ngot: %v", tt.want, got)
}
})
}
}