diff --git a/dnscrypt.go b/dnscrypt.go index 8d7813d..584fe63 100644 --- a/dnscrypt.go +++ b/dnscrypt.go @@ -76,7 +76,6 @@ type CertInfo struct { type ServerInfo struct { SecretKey [32]byte // Client secret key PublicKey [32]byte // Client public key - Proto string // Protocol ("udp" or "tcp") ServerPublicKey ed25519.PublicKey // Server public key ServerAddress string // Server IP address ProviderName string // Provider name @@ -115,12 +114,6 @@ func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ServerInfo, time.Durat curve25519.ScalarBaseMult(&serverInfo.PublicKey, &serverInfo.SecretKey) // Set the provider properties - proto := c.Proto - if proto == "" { - proto = "udp" - } - - serverInfo.Proto = proto serverInfo.ServerPublicKey = stamp.ServerPk serverInfo.ServerAddress = stamp.ServerAddrStr serverInfo.ProviderName = stamp.ProviderName @@ -129,7 +122,7 @@ func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ServerInfo, time.Durat } // Fetch the certificate and validate it - certInfo, rtt, err := serverInfo.fetchCurrentDNSCryptCert(c.Timeout) + certInfo, rtt, err := serverInfo.fetchCurrentDNSCryptCert(c.Proto, c.Timeout) if err != nil { return nil, rtt, err @@ -145,7 +138,11 @@ func (c *Client) DialStamp(stamp dnsstamps.ServerStamp) (*ServerInfo, time.Durat func (c *Client) Exchange(m *dns.Msg, s *ServerInfo) (*dns.Msg, time.Duration, error) { now := time.Now() - conn, err := net.Dial(s.Proto, s.ServerAddress) + network := c.Proto + if network == "" { + network = "udp" + } + conn, err := net.Dial(network, s.ServerAddress) if err != nil { return nil, 0, err } @@ -173,7 +170,7 @@ func (c *Client) ExchangeConn(m *dns.Msg, s *ServerInfo, conn net.Conn) (*dns.Ms return nil, 0, err } - encryptedQuery, clientNonce, err := s.encrypt(query) + encryptedQuery, clientNonce, err := s.encrypt(c.Proto, query) if err != nil { return nil, 0, err } @@ -259,14 +256,14 @@ func (c *Client) adjustPayloadSize(msg *dns.Msg) { } } -func (s *ServerInfo) fetchCurrentDNSCryptCert(timeout time.Duration) (CertInfo, time.Duration, error) { +func (s *ServerInfo) fetchCurrentDNSCryptCert(proto string, timeout time.Duration) (CertInfo, time.Duration, error) { if len(s.ServerPublicKey) != ed25519.PublicKeySize { return CertInfo{}, 0, errors.New("invalid public key length") } query := new(dns.Msg) query.SetQuestion(s.ProviderName, dns.TypeTXT) - client := dns.Client{Net: s.Proto, UDPSize: uint16(maxDNSUDPPacketSize), Timeout: timeout} + client := dns.Client{Net: proto, UDPSize: uint16(maxDNSUDPPacketSize), Timeout: timeout} in, rtt, err := client.Exchange(query, s.ServerAddress) if err != nil { return CertInfo{}, 0, err @@ -306,7 +303,7 @@ func (s *ServerInfo) fetchCurrentDNSCryptCert(timeout time.Duration) (CertInfo, return certInfo, rtt, nil } -func (s *ServerInfo) encrypt(packet []byte) (encrypted []byte, clientNonce []byte, err error) { +func (s *ServerInfo) encrypt(proto string, packet []byte) (encrypted []byte, clientNonce []byte, err error) { nonce, clientNonce := make([]byte, nonceSize), make([]byte, halfNonceSize) rand.Read(clientNonce) copy(nonce, clientNonce) @@ -316,12 +313,12 @@ func (s *ServerInfo) encrypt(packet []byte) (encrypted []byte, clientNonce []byt publicKey = &s.PublicKey minQuestionSize := queryOverhead + len(packet) - if s.Proto == "udp" { - minQuestionSize = max(minUDPQuestionSize, minQuestionSize) - } else { + if proto == "tcp" { var xpad [1]byte rand.Read(xpad[:]) minQuestionSize += int(xpad[0]) + } else { + minQuestionSize = max(minUDPQuestionSize, minQuestionSize) } paddedLength := min(maxDNSUDPPacketSize, (max(minQuestionSize, queryOverhead)+63) & ^63)