From a6447cfac39b45c93ffae33a85573bdc64b5c8ad Mon Sep 17 00:00:00 2001 From: Karan Sharma Date: Tue, 2 Jul 2024 13:57:32 +0530 Subject: [PATCH] feat: Refactor DNS resolvers for concurrent lookups This commit significantly improves the performance of DNS lookups by implementing concurrent query execution across all resolver types. It also refactors the common lookup logic into a shared function to reduce code duplication and improve maintainability. Performance improvements: - Reduced lookup time for multiple queries by ~78% (from 1.356s to 0.297s for a sample query with 10 record types) - Improved CPU utilization (from 1% to 4%) indicating better resource use --- cmd/doggo/cli.go | 44 +++++++++++++++++-------------- pkg/resolvers/classic.go | 9 +++++-- pkg/resolvers/common.go | 42 ++++++++++++++++++++++++++++++ pkg/resolvers/dnscrypt.go | 10 ++++--- pkg/resolvers/doh.go | 9 +++++-- pkg/resolvers/doq.go | 7 ++++- pkg/resolvers/resolver.go | 2 +- web/handlers.go | 55 +++++++++++++++++++++++++++++---------- 8 files changed, 136 insertions(+), 42 deletions(-) create mode 100644 pkg/resolvers/common.go diff --git a/cmd/doggo/cli.go b/cmd/doggo/cli.go index 4933243..685a38b 100644 --- a/cmd/doggo/cli.go +++ b/cmd/doggo/cli.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "os" + "sync" "time" "github.com/knadh/koanf/providers/posflag" @@ -85,6 +86,13 @@ func main() { os.Exit(0) } + var ( + wg sync.WaitGroup + mu sync.Mutex + allResponses []resolvers.Response + allErrors []error + ) + queryFlags := resolvers.QueryFlags{ AA: k.Bool("aa"), AD: k.Bool("ad"), @@ -94,9 +102,24 @@ func main() { DO: k.Bool("do"), } - responses, responseErrors := resolveQueries(&app, queryFlags) + for _, resolver := range app.Resolvers { + wg.Add(1) + go func(r resolvers.Resolver) { + defer wg.Done() + responses, err := r.Lookup(app.Questions, queryFlags) + mu.Lock() + if err != nil { + allErrors = append(allErrors, err) + } else { + allResponses = append(allResponses, responses...) + } + mu.Unlock() + }(resolver) + } - outputResults(&app, responses, responseErrors) + wg.Wait() + + outputResults(&app, allResponses, allErrors) os.Exit(0) } @@ -166,23 +189,6 @@ func loadNameservers(app *app.App, args []string) { app.QueryFlags.QNames = append(app.QueryFlags.QNames, qn...) } -func resolveQueries(app *app.App, flags resolvers.QueryFlags) ([]resolvers.Response, []error) { - var responses []resolvers.Response - var responseErrors []error - - for _, q := range app.Questions { - for _, rslv := range app.Resolvers { - resp, err := rslv.Lookup(q, flags) - if err != nil { - responseErrors = append(responseErrors, err) - } - responses = append(responses, resp) - } - } - - return responses, responseErrors -} - func outputResults(app *app.App, responses []resolvers.Response, responseErrors []error) { if app.QueryFlags.ShowJSON { outputJSON(app.Logger, responses, responseErrors) diff --git a/pkg/resolvers/classic.go b/pkg/resolvers/classic.go index 1c6b93f..832f1a5 100644 --- a/pkg/resolvers/classic.go +++ b/pkg/resolvers/classic.go @@ -59,7 +59,7 @@ func NewClassicResolver(server string, classicOpts ClassicResolverOpts, resolver // Lookup takes a dns.Question and sends them to DNS Server. // It parses the Response from the server in a custom output format. -func (r *ClassicResolver) Lookup(question dns.Question, flags QueryFlags) (Response, error) { +func (r *ClassicResolver) query(question dns.Question, flags QueryFlags) (Response, error) { var ( rsp Response messages = prepareMessages(question, flags, r.resolverOptions.Ndots, r.resolverOptions.SearchList) @@ -93,7 +93,7 @@ func (r *ClassicResolver) Lookup(question dns.Question, flags QueryFlags) (Respo r.client.Net = "tcp" } r.resolverOptions.Logger.Debug("Response truncated; retrying now", "protocol", r.client.Net) - return r.Lookup(question, flags) + return r.query(question, flags) } // Pack questions in output. @@ -119,3 +119,8 @@ func (r *ClassicResolver) Lookup(question dns.Question, flags QueryFlags) (Respo } return rsp, nil } + +// Lookup implements the Resolver interface +func (r *ClassicResolver) Lookup(questions []dns.Question, flags QueryFlags) ([]Response, error) { + return ConcurrentLookup(questions, flags, r.query, r.resolverOptions.Logger) +} diff --git a/pkg/resolvers/common.go b/pkg/resolvers/common.go new file mode 100644 index 0000000..3068538 --- /dev/null +++ b/pkg/resolvers/common.go @@ -0,0 +1,42 @@ +package resolvers + +import ( + "log/slog" + "sync" + + "github.com/miekg/dns" +) + +// QueryFunc represents the signature of a query function +type QueryFunc func(question dns.Question, flags QueryFlags) (Response, error) + +// ConcurrentLookup performs concurrent DNS lookups +func ConcurrentLookup(questions []dns.Question, flags QueryFlags, queryFunc QueryFunc, logger *slog.Logger) ([]Response, error) { + var wg sync.WaitGroup + responses := make([]Response, len(questions)) + errors := make([]error, len(questions)) + + for i, q := range questions { + wg.Add(1) + go func(i int, q dns.Question) { + defer wg.Done() + resp, err := queryFunc(q, flags) + responses[i] = resp + errors[i] = err + }(i, q) + } + + wg.Wait() + + // Collect non-nil responses and handle errors + var validResponses []Response + for i, resp := range responses { + if errors[i] != nil { + logger.Error("error in lookup", "error", errors[i]) + } else { + validResponses = append(validResponses, resp) + } + } + + return validResponses, nil +} diff --git a/pkg/resolvers/dnscrypt.go b/pkg/resolvers/dnscrypt.go index 0905858..2e27fcd 100644 --- a/pkg/resolvers/dnscrypt.go +++ b/pkg/resolvers/dnscrypt.go @@ -40,9 +40,13 @@ func NewDNSCryptResolver(server string, dnscryptOpts DNSCryptResolverOpts, resol }, nil } -// Lookup takes a dns.Question and sends them to DNS Server. -// It parses the Response from the server in a custom output format. -func (r *DNSCryptResolver) Lookup(question dns.Question, flags QueryFlags) (Response, error) { +// Lookup implements the Resolver interface +func (r *DNSCryptResolver) Lookup(questions []dns.Question, flags QueryFlags) ([]Response, error) { + return ConcurrentLookup(questions, flags, r.query, r.resolverOptions.Logger) +} + +// query performs a single DNS query +func (r *DNSCryptResolver) query(question dns.Question, flags QueryFlags) (Response, error) { var ( rsp Response messages = prepareMessages(question, flags, r.resolverOptions.Ndots, r.resolverOptions.SearchList) diff --git a/pkg/resolvers/doh.go b/pkg/resolvers/doh.go index d595190..b88b550 100644 --- a/pkg/resolvers/doh.go +++ b/pkg/resolvers/doh.go @@ -46,9 +46,9 @@ func NewDOHResolver(server string, resolverOpts Options) (Resolver, error) { }, nil } -// Lookup takes a dns.Question and sends them to DNS Server. +// query takes a dns.Question and sends them to DNS Server. // It parses the Response from the server in a custom output format. -func (r *DOHResolver) Lookup(question dns.Question, flags QueryFlags) (Response, error) { +func (r *DOHResolver) query(question dns.Question, flags QueryFlags) (Response, error) { var ( rsp Response messages = prepareMessages(question, flags, r.resolverOptions.Ndots, r.resolverOptions.SearchList) @@ -123,3 +123,8 @@ func (r *DOHResolver) Lookup(question dns.Question, flags QueryFlags) (Response, } return rsp, nil } + +// Lookup implements the Resolver interface +func (r *DOHResolver) Lookup(questions []dns.Question, flags QueryFlags) ([]Response, error) { + return ConcurrentLookup(questions, flags, r.query, r.resolverOptions.Logger) +} diff --git a/pkg/resolvers/doq.go b/pkg/resolvers/doq.go index e4d8bc1..39e72a0 100644 --- a/pkg/resolvers/doq.go +++ b/pkg/resolvers/doq.go @@ -34,9 +34,14 @@ func NewDOQResolver(server string, resolverOpts Options) (Resolver, error) { }, nil } +// Lookup implements the Resolver interface +func (r *DOQResolver) Lookup(questions []dns.Question, flags QueryFlags) ([]Response, error) { + return ConcurrentLookup(questions, flags, r.query, r.resolverOptions.Logger) +} + // Lookup takes a dns.Question and sends them to DNS Server. // It parses the Response from the server in a custom output format. -func (r *DOQResolver) Lookup(question dns.Question, flags QueryFlags) (Response, error) { +func (r *DOQResolver) query(question dns.Question, flags QueryFlags) (Response, error) { var ( rsp Response messages = prepareMessages(question, flags, r.resolverOptions.Ndots, r.resolverOptions.SearchList) diff --git a/pkg/resolvers/resolver.go b/pkg/resolvers/resolver.go index 2c83900..b6f55ba 100644 --- a/pkg/resolvers/resolver.go +++ b/pkg/resolvers/resolver.go @@ -28,7 +28,7 @@ type Options struct { // Client. Different types of providers can load // a DNS Resolver satisfying this interface. type Resolver interface { - Lookup(dns.Question, QueryFlags) (Response, error) + Lookup([]dns.Question, QueryFlags) ([]Response, error) } // Response represents a custom output format diff --git a/web/handlers.go b/web/handlers.go index 099a6a6..4904e49 100644 --- a/web/handlers.go +++ b/web/handlers.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "sync" "time" "github.com/mr-karan/doggo/internal/app" @@ -41,14 +42,14 @@ func handleLookup(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if err != nil { app.Logger.Error("error reading request body", "error", err) - sendErrorResponse(w, fmt.Sprintf("Invalid JSON payload"), http.StatusBadRequest, nil) + sendErrorResponse(w, "Invalid JSON payload", http.StatusBadRequest, nil) return } // Prepare query flags. var qFlags models.QueryFlags if err := json.Unmarshal(b, &qFlags); err != nil { app.Logger.Error("error unmarshalling payload", "error", err) - sendErrorResponse(w, fmt.Sprintf("Invalid JSON payload"), http.StatusBadRequest, nil) + sendErrorResponse(w, "Invalid JSON payload", http.StatusBadRequest, nil) return } @@ -60,14 +61,14 @@ func handleLookup(w http.ResponseWriter, r *http.Request) { app.PrepareQuestions() if len(app.Questions) == 0 { - sendErrorResponse(w, fmt.Sprintf("Missing field `query`."), http.StatusBadRequest, nil) + sendErrorResponse(w, "Missing field `query`.", http.StatusBadRequest, nil) return } // Load Nameservers. if err := app.LoadNameservers(); err != nil { app.Logger.Error("error loading nameservers", "error", err) - sendErrorResponse(w, fmt.Sprintf("Error looking up for records."), http.StatusInternalServerError, nil) + sendErrorResponse(w, "Error looking up for records.", http.StatusInternalServerError, nil) return } @@ -96,19 +97,45 @@ func handleLookup(w http.ResponseWriter, r *http.Request) { RD: true, } - var responses []resolvers.Response - for _, q := range app.Questions { - for _, rslv := range app.Resolvers { - resp, err := rslv.Lookup(q, queryFlags) + // ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + // defer cancel() + + var ( + wg sync.WaitGroup + mu sync.Mutex + allResponses []resolvers.Response + allErrors []error + ) + + for _, resolver := range app.Resolvers { + wg.Add(1) + go func(r resolvers.Resolver) { + defer wg.Done() + responses, err := r.Lookup(app.Questions, queryFlags) + mu.Lock() if err != nil { - app.Logger.Error("error looking up DNS records", "error", err) - sendErrorResponse(w, "Error looking up for records.", http.StatusInternalServerError, nil) - return + allErrors = append(allErrors, err) + } else { + allResponses = append(allResponses, responses...) } - responses = append(responses, resp) - } + mu.Unlock() + }(resolver) + } + + wg.Wait() + + if len(allErrors) > 0 { + app.Logger.Error("errors looking up DNS records", "errors", allErrors) + sendErrorResponse(w, "Error looking up for records.", http.StatusInternalServerError, nil) + return } - sendResponse(w, http.StatusOK, responses) + + if len(allResponses) == 0 { + sendErrorResponse(w, "No records found.", http.StatusNotFound, nil) + return + } + + sendResponse(w, http.StatusOK, allResponses) } // wrap is a middleware that wraps HTTP handlers and injects the "app" context.