Skip to content

Commit

Permalink
refactor: fix loading of nameservers
Browse files Browse the repository at this point in the history
fixes #106
  • Loading branch information
mr-karan committed Jul 1, 2024
1 parent db64c43 commit f1959e1
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 193 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v4
uses: actions/setup-go@v5
with:
go-version: "1.22"
- name: Log in to the Container registry
Expand Down
7 changes: 5 additions & 2 deletions cmd/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ func handleLookup(w http.ResponseWriter, r *http.Request) {
}

// Load Nameservers.
err = app.LoadNameservers()
if err != nil {
if err := app.LoadNameservers(); err != nil {
app.Logger.WithError(err).Error("error loading nameservers")
sendErrorResponse(w, fmt.Sprintf("Error looking up for records."), http.StatusInternalServerError, nil)
return
}

app.Logger.WithField("nameservers", app.Nameservers).Debug("Loaded nameservers")

// Load Resolvers.
rslvrs, err := resolvers.LoadResolvers(resolvers.Options{
Nameservers: app.Nameservers,
Expand All @@ -91,6 +92,8 @@ func handleLookup(w http.ResponseWriter, r *http.Request) {
}
app.Resolvers = rslvrs

app.Logger.WithField("resolvers", app.Resolvers).Debug("Loaded resolvers")

var responses []resolvers.Response
for _, q := range app.Questions {
for _, rslv := range app.Resolvers {
Expand Down
215 changes: 113 additions & 102 deletions cmd/doggo/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,113 +16,47 @@ import (
)

var (
// Version and date of the build. This is injected at build-time.
buildVersion = "unknown"
buildDate = "unknown"
logger = utils.InitLogger()
k = koanf.New(".")
)

func main() {
// Initialize app.
app := app.New(logger, buildVersion)
f := setupFlags()

// Configure Flags.
f := flag.NewFlagSet("config", flag.ContinueOnError)

// Custom Help Text.
f.Usage = renderCustomHelp

// Query Options.
f.StringSliceP("query", "q", []string{}, "Domain name to query")
f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)")
f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)")
f.StringSliceP("nameserver", "n", []string{}, "Address of the nameserver to send packets to")
f.BoolP("reverse", "x", false, "Performs a DNS Lookup for an IPv4 or IPv6 address. Sets the query type and class to PTR and IN respectively.")

// Resolver Options
f.Int("timeout", 5, "Sets the timeout for a query to T seconds. The default timeout is 5 seconds.")
f.Bool("search", true, "Use the search list provided in resolv.conf. It sets the `ndots` parameter as well unless overridden by `ndots` flag.")
f.Int("ndots", -1, "Specify the ndots parameter. Default value is taken from resolv.conf and fallbacks to 1 if ndots statement is missing in resolv.conf")
f.BoolP("ipv4", "4", false, "Use IPv4 only")
f.BoolP("ipv6", "6", false, "Use IPv6 only")
f.String("strategy", "all", "Strategy to query nameservers in resolv.conf file (`all`, `random`, `first`)")
f.String("tls-hostname", "", "Provide a hostname for doing verification of the certificate if the provided DoT nameserver is an IP")
f.Bool("skip-hostname-verification", false, "Skip TLS Hostname Verification")

// Output Options
f.BoolP("json", "J", false, "Set the output format as JSON")
f.Bool("short", false, "Short output format")
f.Bool("time", false, "Display how long it took for the response to arrive")
f.Bool("color", true, "Show colored output")
f.Bool("debug", false, "Enable debug mode")

f.Bool("version", false, "Show version of doggo")

// Parse and Load Flags.
err := f.Parse(os.Args[1:])
if err != nil {
app.Logger.WithError(err).Error("error parsing flags")
app.Logger.Exit(2)
}
if err = k.Load(posflag.Provider(f, ".", k), nil); err != nil {
app.Logger.WithError(err).Error("error loading flags")
f.Usage()
if err := parseAndLoadFlags(f); err != nil {
app.Logger.WithError(err).Error("Error parsing or loading flags")
app.Logger.Exit(2)
}

// If version flag is set, output version and quit.
if k.Bool("version") {
fmt.Printf("%s - %s\n", buildVersion, buildDate)
app.Logger.Exit(0)
}

// Set log level.
if k.Bool("debug") {
// Set logger level
app.Logger.SetLevel(logrus.DebugLevel)
} else {
app.Logger.SetLevel(logrus.InfoLevel)
}
setupLogging(&app)

// Unmarshall flags to the app.
err = k.Unmarshal("", &app.QueryFlags)
if err != nil {
app.Logger.WithError(err).Error("error loading args")
if err := k.Unmarshal("", &app.QueryFlags); err != nil {
app.Logger.WithError(err).Error("Error loading args")
app.Logger.Exit(2)
}

// Load all `non-flag` arguments
// which will be parsed separately.
nsvrs, qt, qc, qn := loadUnparsedArgs(f.Args())
app.QueryFlags.Nameservers = append(app.QueryFlags.Nameservers, nsvrs...)
app.QueryFlags.QTypes = append(app.QueryFlags.QTypes, qt...)
app.QueryFlags.QClasses = append(app.QueryFlags.QClasses, qc...)
app.QueryFlags.QNames = append(app.QueryFlags.QNames, qn...)

// Check if reverse flag is passed. If it is, then set
// query type as PTR and query class as IN.
// Modify query name like 94.2.0.192.in-addr.arpa if it's an IPv4 address.
// Use IP6.ARPA nibble format otherwise.
loadNameservers(&app, f.Args())

if app.QueryFlags.ReverseLookup {
app.ReverseLookup()
}

// Load fallbacks.
app.LoadFallbacks()

// Load Questions.
app.PrepareQuestions()

// Load Nameservers.
err = app.LoadNameservers()
if err != nil {
app.Logger.WithError(err).Error("error loading nameservers")
if err := app.LoadNameservers(); err != nil {
app.Logger.WithError(err).Error("Error loading nameservers")
app.Logger.Exit(2)
}

// Load Resolvers.
rslvrs, err := resolvers.LoadResolvers(resolvers.Options{
Nameservers: app.Nameservers,
UseIPv4: app.QueryFlags.UseIPv4,
Expand All @@ -136,23 +70,96 @@ func main() {
TLSHostname: app.QueryFlags.TLSHostname,
})
if err != nil {
app.Logger.WithError(err).Error("error loading resolver")
app.Logger.WithError(err).Error("Error loading resolver")
app.Logger.Exit(2)
}
app.Resolvers = rslvrs

// Run the app.
app.Logger.Debug("Starting doggo 🐶")
if len(app.QueryFlags.QNames) == 0 {
f.Usage()
app.Logger.Exit(0)
}

// Resolve Queries.
var (
responses []resolvers.Response
responseErrors []error
)
responses, responseErrors := resolveQueries(&app)

outputResults(&app, responses, responseErrors)

app.Logger.Exit(0)
}

func setupFlags() *flag.FlagSet {
f := flag.NewFlagSet("config", flag.ContinueOnError)
f.Usage = renderCustomHelp

f.StringSliceP("query", "q", []string{}, "Domain name to query")
f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)")
f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)")
f.StringSliceP("nameserver", "n", []string{}, "Address of the nameserver to send packets to")
f.BoolP("reverse", "x", false, "Performs a DNS Lookup for an IPv4 or IPv6 address")

f.Int("timeout", 5, "Sets the timeout for a query to T seconds")
f.Bool("search", true, "Use the search list provided in resolv.conf")
f.Int("ndots", -1, "Specify the ndots parameter")
f.BoolP("ipv4", "4", false, "Use IPv4 only")
f.BoolP("ipv6", "6", false, "Use IPv6 only")
f.String("strategy", "all", "Strategy to query nameservers in resolv.conf file")
f.String("tls-hostname", "", "Hostname for certificate verification")
f.Bool("skip-hostname-verification", false, "Skip TLS Hostname Verification")

f.BoolP("json", "J", false, "Set the output format as JSON")
f.Bool("short", false, "Short output format")
f.Bool("time", false, "Display how long the response took")
f.Bool("color", true, "Show colored output")
f.Bool("debug", false, "Enable debug mode")

f.Bool("version", false, "Show version of doggo")

return f
}

func parseAndLoadFlags(f *flag.FlagSet) error {
if err := f.Parse(os.Args[1:]); err != nil {
return fmt.Errorf("error parsing flags: %w", err)
}
if err := k.Load(posflag.Provider(f, ".", k), nil); err != nil {
return fmt.Errorf("error loading flags: %w", err)
}
return nil
}

func setupLogging(app *app.App) {
if k.Bool("debug") {
app.Logger.SetLevel(logrus.DebugLevel)
} else {
app.Logger.SetLevel(logrus.InfoLevel)
}
}

func loadNameservers(app *app.App, args []string) {
flagNameservers := k.Strings("nameserver")
app.Logger.WithField("flagNameservers", flagNameservers).Debug("Nameservers from -n flag")

unparsedNameservers, qt, qc, qn := loadUnparsedArgs(args)
app.Logger.WithField("unparsedNameservers", unparsedNameservers).Debug("Nameservers from unparsed arguments")

if len(flagNameservers) > 0 {
app.QueryFlags.Nameservers = flagNameservers
} else {
app.QueryFlags.Nameservers = unparsedNameservers
}

app.QueryFlags.QTypes = append(app.QueryFlags.QTypes, qt...)
app.QueryFlags.QClasses = append(app.QueryFlags.QClasses, qc...)
app.QueryFlags.QNames = append(app.QueryFlags.QNames, qn...)

app.Logger.WithField("finalNameservers", app.QueryFlags.Nameservers).Debug("Final nameservers")
}

func resolveQueries(app *app.App) ([]resolvers.Response, []error) {
var responses []resolvers.Response
var responseErrors []error

for _, q := range app.Questions {
for _, rslv := range app.Resolvers {
resp, err := rslv.Lookup(q)
Expand All @@ -163,33 +170,37 @@ func main() {
}
}

// Output results
if app.QueryFlags.ShowJSON {
jsonOutput := struct {
Responses []resolvers.Response `json:"responses,omitempty"`
Error string `json:"error,omitempty"`
}{
Responses: responses,
}

if len(responseErrors) > 0 {
jsonOutput.Error = responseErrors[0].Error()
}
return responses, responseErrors
}

jsonData, err := json.MarshalIndent(jsonOutput, "", " ")
if err != nil {
app.Logger.WithError(err).Error("Error marshaling JSON")
app.Logger.Exit(1)
}
fmt.Println(string(jsonData))
func outputResults(app *app.App, responses []resolvers.Response, responseErrors []error) {
if app.QueryFlags.ShowJSON {
outputJSON(responses, responseErrors)
} else {
if len(responseErrors) > 0 {
app.Logger.WithError(responseErrors[0]).Error("Error looking up DNS records")
app.Logger.Exit(9)
}
app.Output(responses)
}
}

// Quitting.
app.Logger.Exit(0)
func outputJSON(responses []resolvers.Response, responseErrors []error) {
jsonOutput := struct {
Responses []resolvers.Response `json:"responses,omitempty"`
Error string `json:"error,omitempty"`
}{
Responses: responses,
}

if len(responseErrors) > 0 {
jsonOutput.Error = responseErrors[0].Error()
}

jsonData, err := json.MarshalIndent(jsonOutput, "", " ")
if err != nil {
logger.WithError(err).Error("Error marshaling JSON")
logger.Exit(1)
}
fmt.Println(string(jsonData))
}
17 changes: 8 additions & 9 deletions cmd/doggo/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@ import (
// where `@1.1.1.1` and `AAAA` are "unparsed" args.
// Returns a list of nameserver, queryTypes, queryClasses, queryNames.
func loadUnparsedArgs(args []string) ([]string, []string, []string, []string) {
var ns, qt, qc, qn []string
var nameservers, queryTypes, queryClasses, queryNames []string
for _, arg := range args {
if strings.HasPrefix(arg, "@") {
ns = append(ns, strings.Trim(arg, "@"))
} else if _, ok := dns.StringToType[strings.ToUpper(arg)]; ok {
qt = append(qt, arg)
} else if _, ok := dns.StringToClass[strings.ToUpper(arg)]; ok {
qc = append(qc, arg)
nameservers = append(nameservers, strings.TrimPrefix(arg, "@"))
} else if qt, ok := dns.StringToType[strings.ToUpper(arg)]; ok {
queryTypes = append(queryTypes, dns.TypeToString[qt])
} else if qc, ok := dns.StringToClass[strings.ToUpper(arg)]; ok {
queryClasses = append(queryClasses, dns.ClassToString[qc])
} else {
// if nothing matches, consider it's a query name.
qn = append(qn, arg)
queryNames = append(queryNames, arg)
}
}
return ns, qt, qc, qn
return nameservers, queryTypes, queryClasses, queryNames
}
Loading

0 comments on commit f1959e1

Please sign in to comment.