Skip to content

Commit

Permalink
reimplement dns
Browse files Browse the repository at this point in the history
1. each RR type supports multiple records
2. add configuration sys.dns.hosts and sys.dns.resolv_conf
   to specify hosts and resolv.conf file paths
3. support hosts file
  • Loading branch information
findstr committed Jun 17, 2023
1 parent a4dac56 commit 45b998a
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 94 deletions.
4 changes: 2 additions & 2 deletions lualib/http/stream.lua
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ function M.accept(scheme, fd)
end

function M.connect(scheme, host, port)
local ip = dns.resolve(host, "A")
local ip = dns.lookup(host, dns.A)
assert(ip, host)
ip = format("%s:%s", ip, port)
local ip = format("%s:%s", ip, port)
local fd = scheme_connect[scheme](ip, nil, host)
if not fd then
return nil
Expand Down
2 changes: 1 addition & 1 deletion lualib/http2/stream.lua
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ local function client_dispatch(ch)
end

local function handshake_as_client(ch, host, port)
local ip = dns.resolve(host, "A")
local ip = dns.lookup(host, dns.A)
assert(ip, host)
local addr = format("%s:%s", ip, port)
local fd = tls_connect(addr, nil, host, "h2")
Expand Down
251 changes: 171 additions & 80 deletions lualib/sys/dns.lua
Original file line number Diff line number Diff line change
@@ -1,26 +1,44 @@
local core = require "sys.core"
local socket = require "sys.socket"
local assert = assert
local pairs = pairs
local sub = string.sub
local concat = table.concat
local pack = string.pack
local unpack = string.unpack
local format = string.format
local match = string.match
local gmatch = string.gmatch
local setmetatable = setmetatable
local timenow = core.monotonicsec
local maxinteger = math.maxinteger

local dns = {}
local A = 1
local CNAME = 5
local AAAA = 28
local name_cache = {}
local wait_coroutine = {}
local weakmt = {__mode = "kv"}
local session = 0
local dns_server
local connectfd

local timenow = core.monotonicsec
local resolv_conf = core.envget("sys.dns.resolv_conf") or "/etc/resolv.conf"
local hosts = core.envget("sys.dns.hosts") or "/etc/hosts"

local name_cache = {}
local wait_coroutine = {}

local RR_A = 1
local RR_CNAME = 5
local RR_AAAA = 28
local RR_SRV = 33

local function guesstype(str)
if str:match("^[%d%.]+$") then
return "A"
end

if str:find(":") then
return "AAAA"
end

return "NAME"
end

--[[
ID:16
Expand Down Expand Up @@ -60,11 +78,6 @@ local function QNAME(name, n)
end

local function question(name, typ)
if typ == "AAAA" then
typ = AAAA
else
typ = A
end
session = session % 65535 + 1
local ID = session
--[[ FLAG
Expand Down Expand Up @@ -132,37 +145,104 @@ local function readname(dat, i)
return concat(tbl), i
end

local function answer(dat, pos, n)
local parser = {
[RR_A] = function(dat, pos)
local d1, d2, d3, d4 = unpack(">I1I1I1I1", dat, pos)
return format("%d.%d.%d.%d", d1, d2, d3, d4)
end,
[RR_AAAA] = function(dat, pos)
local x1, x2, x3, x4, x5, x6, x7, x8 =
unpack(">I2I2I2I2I2I2I2I2", dat, pos)
return format(
"%02x:%02x:%02x:%02x:%02x:%02x:%02x:%02x",
x1, x2, x3, x4, x5, x6, x7, x8
)
end,
[RR_CNAME] = function(dat, pos)
return readname(dat, pos)
end,
[RR_SRV] = function(dat, pos)
local priority, weight, port = unpack(">I2I2I2", dat, pos)
local target = readname(dat, pos + 6)
return {
priority = priority,
weight = weight,
port = port,
target = target,
}
end
}

local newrrmt = {__index = function(t, k)
local v = {
TTL = maxinteger
}
t[k] = v
return v
end}

local answers = setmetatable({}, {__index = function(t, k)
local v = setmetatable({}, newrrmt)
t[k] = v
return v
end})

local function merge_answers()
for name, rrs in pairs(answers) do
answers[name] = nil
local dst = name_cache[name]
if not dst then
setmetatable(rrs, nil)
name_cache[name] = rrs
else
for k, v in pairs(rrs) do
dst[k] = v
end
end
end
end

local function answer(dat, start, n)
local now = timenow()
for i = 1, n do
local src, pos = readname(dat, pos)
local qtype, qclass, ttl, rdlen, pos
= unpack(">I2I2I4I2", dat, pos)
if qtype == A then
local d1, d2, d3, d4 =
unpack(">I1I1I1I1", dat, pos)
name_cache[src] = {
TTL = timenow() + ttl,
TYPE = "A",
A = format("%d.%d.%d.%d", d1, d2, d3, d4),
}
elseif qtype == CNAME then
local cname = readname(dat, pos)
name_cache[src] = {
TTL = timenow() + ttl,
TYPE = "CNAME",
CNAME = cname,
}
elseif qtype == AAAA then
local x1, x2, x3, x4, x5, x6, x7, x8 =
unpack(">I2I2I2I2I2I2I2I2", dat, pos)
name_cache[src] = {
TTL = timenow() + ttl,
TYPE = "AAAA",
AAAA = format("%x:%x:%x:%x:%x%x:%x:%x",
x1, x2, x3, x4, x5, x6, x7, x8),
}
local name, pos = readname(dat, start)
local qtype, qclass, ttl, rdlen, pos = unpack(">I2I2I4I2", dat, pos)
local parse = parser[qtype]
if parse then
local rr = answers[name][qtype]
rr[#rr + 1] = parse(dat, pos)
ttl = now + ttl
if rr.TTL > ttl then
rr.TTL = ttl
end
end
start = pos + rdlen
end
merge_answers()
end


do --parse hosts
local f<close> = io.open(hosts)
for line in f:lines() do
local ip, names = line:match("^%s*([%[%]%x%.%:]+)%s+([^#;]+)")
if not ip or not names then
goto continue
end

local typename = guesstype(ip)
if typename == "NAME" then
goto continue
end

for name in names:gmatch("%S+") do
name = name:lower()
local rr = answers[name][typename]
rr[#rr + 1] = ip
end
::continue::
end
merge_answers()
end

local function callback(msg, _)
Expand Down Expand Up @@ -206,7 +286,7 @@ end

local function connectserver()
if not dns_server then
local f = io.open("/etc/resolv.conf", "r")
local f<close> = io.open(resolv_conf, "r")
for l in f:lines() do
dns_server = l:match("^%s*nameserver%s+([^%s]+)")
if dns_server then
Expand Down Expand Up @@ -249,66 +329,77 @@ local function query(name, typ, timeout)
end
end

local function lookup(name)
local d
local function findcache(name, qtype, deep)
local now = timenow()
for i = 1, 100 do
d = name_cache[name]
if not d then
return nil, name
local rrs = name_cache[name]
if not rrs then
return nil, nil
end
if d.TTL < now then
name_cache[name] = nil
return nil, name
local rr = rrs[qtype]
if rr and rr.TTL >= now then
return rr, nil
end
if d.TYPE == "CNAME" then
name = d.CNAME
local cname = rrs[RR_CNAME]
if cname and cname.TTL >= now then
name = cname[1]
else
return d
return nil, nil
end
end
return nil, name
return nil, nil
end

local function isname(name)
local right = name:match("([%x])", #name)
if right then
return false

local function resolve(name, qtype, timeout, deep)
if deep > 100 then
return nil
end
return true
local rr, cname = findcache(name, qtype)
if not rr and not cname then
local res = query(name, qtype, timeout)
if not res then
return nil
end
rr, cname = findcache(name, qtype)
end
if cname then
return resolve(cname, qtype, timeout, deep + 1)
end
return rr
end

function dns.resolve(name, typ, timeout)
if not isname(name) then
local dns = {
A = RR_A,
AAAA = RR_AAAA,
SRV = RR_SRV,
}

function dns.lookup(name, qtype, timeout)
if guesstype(name) ~= "NAME" then
return name
end
local d , cname = lookup(name)
if not d then
for i = 1, 100 do
local res = query(cname, typ, timeout)
if not res then
return
end
d, cname = lookup(cname)
if not cname then
goto FIND
end
end
return
local rr = resolve(name, qtype, timeout, 1)
if not rr then
return nil
end
::FIND::
if typ then
return d[typ], typ
else
return d.A or d.AAAA, d.TYPE
return rr[1]
end

function dns.resolve(name, qtype, timeout)
if guesstype(name) ~= "NAME" then
return {name}
end
return resolve(name, qtype, timeout, 1)
end

function dns.server(ip)
dns_server = ip
end

dns.isname = isname
function dns.isname(name)
return guesstype(name) == "NAME"
end

return dns

41 changes: 33 additions & 8 deletions test/testdns.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,38 @@ local dns = require "sys.dns"
local testaux = require "testaux"

return function()
local ip = dns.resolve("smtp.sina.com.cn")
testaux.assertneq(ip, nil, "dns resolve ip")
local fd = socket.connect(string.format("%s:%s", ip, 25))
testaux.assertneq(fd, nil, "dns resolve ip validate")
local l = socket.readline(fd)
testaux.assertneq(l, nil, "dns resolve ip validate")
local f = l:find("220")
testaux.assertneq(f, nil, "dns resolve ip validate")
print("testA")
local ip = dns.resolve("test.silly.gotocoding.com", dns.A)
table.sort(ip, function(a, b)
return a < b
end)
testaux.asserteq(#ip, 2, "multi ip")
testaux.asserteq(ip[1], "127.0.0.1", "ip1")
testaux.asserteq(ip[2], "127.0.0.2", "ip2")
print("testAAAA")
local ip = dns.resolve("test.silly.gotocoding.com", dns.AAAA)
table.sort(ip, function(a, b)
return a < b
end)
testaux.asserteq(#ip, 2, "multi ip")
testaux.asserteq(ip[1], "00:00:00:00:00:00:00:01", "ip1")
testaux.asserteq(ip[2], "00:00:00:00:00:00:00:02", "ip2")
print("testSRV")
local ip = dns.resolve("_rpc._tcp.gotocoding.com", dns.SRV)
testaux.asserteq(#ip, 2, "multi srv")
table.sort(ip, function(a, b)
return a.priority < b.priority
end)
testaux.asserteq(ip[1].priority, 0, "srv1")
testaux.asserteq(ip[1].weight, 5, "srv1")
testaux.asserteq(ip[1].port, 5060, "srv1")
testaux.asserteq(ip[1].target, "test2.silly.gotocoding.com", "srv1")
testaux.asserteq(ip[2].priority, 1, "srv1")
testaux.asserteq(ip[2].weight, 6, "srv1")
testaux.asserteq(ip[2].port, 5061, "srv1")
testaux.asserteq(ip[2].target, "test1.silly.gotocoding.com", "srv1")
print("test guess")
testaux.asserteq(dns.isname("wwww.gotocoding.com"), true, "name")
testaux.asserteq(dns.isname("127.0.0.1"), false, "name")
end

Loading

0 comments on commit 45b998a

Please sign in to comment.