2010-11-13

A First-Principles DNS Client

Ever wondered how nslookup works? I'm a protocol junkie, and Domain Name (RFC 1035) has it all. It's a simple protocol with a highly structured message format, so the rich machinery of OCaml makes it easy to build a simple but relatively complete client. The code is long for a blog post (a bit less than 500 lines), so I'll show the highlights and leave the rest for you to use as you see fit. Since I'm not constrained by presenting a literate bottom-up program, I'll start with a couple of examples. A simple query is straightforward:

# query_dns "10.0.0.1" (query ~q_type:`A 0 "www.nytimes.com");;
- : dns_record =
{id = 0; detail = 33152;
 question = [{q_name = "www.nytimes.com"; q_type = `A; q_class = `IN}];
 answer =
  [{rr_name = "www.nytimes.com"; rr_type = `A; rr_class = `IN; rr_ttl = 120l;
    rr_rdata = `Address 164.107.65.0}];
 authority =
  [{rr_name = "www.nytimes.com"; rr_type = `NS; rr_class = `IN; rr_ttl = 60l;
    rr_rdata = `Domain "nss1.sea1.nytimes.com"};
   {rr_name = "www.nytimes.com"; rr_type = `NS; rr_class = `IN; rr_ttl = 60l;
    rr_rdata = `Domain "nss1.lga2.nytimes.com"}];
 additional =
  [{rr_name = "nss1.lga2.nytimes.com"; rr_type = `A; rr_class = `IN;
    rr_ttl = 60l; rr_rdata = `Address 164.107.65.0}]}

(I installed a top-level pretty printer to show the inet_addresses). As another example, here's how to query the servers to which to send mail for a domain:

(* [mail_servers server domain] asks the name server at address [server] for
   the MX records of [domain] and returns the (preference, exchange) pairs
   found in the answer section, sorted by increasing preference.
   Exceptions raised by [query_dns] propagate to the caller. *)
let mail_servers server domain =
  let res = query_dns server (query ~q_type:`MX 0 domain) in
  (* Select on the payload itself so that the match is total: an answer
     record that does not carry an `Exchange payload is skipped instead of
     raising Match_failure. fold_right keeps the records in answer order. *)
  let exchanges =
    List.fold_right (fun rr acc ->
      match rr.rr_rdata with
      | `Exchange (p, d) -> (p, d) :: acc
      | `Hostinfo _ | `Domain _ | `Mailbox _ | `Data _ | `Text _
      | `Authority _ | `Address _ | `Services _ -> acc)
      res.answer []
  in
  List.sort (fun (p, _) (p', _) -> compare p p') exchanges

And here it is in action:

# mail_servers "10.0.0.1" "google.com";;
- : (int * domain_name) list =
[(100, "google.com.s9a1.psmtp.com"); (200, "google.com.s9a2.psmtp.com");
 (300, "google.com.s9b1.psmtp.com"); (400, "google.com.s9b2.psmtp.com")]

In total, I think I've put 20 hours into this little project, including this write-up. The networking part is very simple, though it took me a number of tries to get it right (I'm rusty):

(* UDP port of the "domain" service, looked up in the system services
   database. Unix.getservbyname raises Not_found when the database has no
   such entry; in that case fall back to 53, the well-known DNS port, so
   that merely loading the program cannot fail on a minimal system. *)
let dns_port =
  try (Unix.getservbyname "domain" "udp").Unix.s_port
  with Not_found -> 53

(* [query_dns addr q] serializes the query record [q], sends it over UDP to
   the name server whose numeric address is [addr], waits for one datagram
   in reply and returns it parsed as a dns_record.
   Raises Failure "Parse error" when the reply cannot be parsed. Exceptions
   from Unix.inet_addr_of_string (malformed [addr]) and from the socket calls
   are not handled here.
   NOTE(review): [unwind] is defined elsewhere; it presumably applies
   [~protect] to the socket on both normal and exceptional exit -- confirm. *)
let query_dns addr q =
  let len = 4096 in
  let buf  = String.create len in
  let msg  = Writer.run (write_dns_record q) in
  (* Resolve the peer address BEFORE creating the socket: a malformed [addr]
     makes inet_addr_of_string raise, and a socket created earlier would leak
     because it is not yet under the protection of [unwind]. *)
  let peer = Unix.ADDR_INET (Unix.inet_addr_of_string addr, dns_port) in
  let sock = Unix.socket Unix.PF_INET Unix.SOCK_DGRAM 0 in
  unwind ~protect:Unix.close (fun sock ->
    (* One-second receive timeout so a missing reply cannot block forever. *)
    Unix.setsockopt_float sock Unix.SO_RCVTIMEO  1.;
    let _ = Unix.sendto sock msg 0 (String.length msg) [] peer in
    let cnt, _ = Unix.recvfrom sock buf 0 len [] in
    (* Only the first [cnt] bytes of the buffer belong to the reply. *)
    match Parser.run parse_dns_record (String.sub buf 0 cnt) with
    | Some dns -> dns
    | None     -> failwith "Parse error"
  ) sock

It is important to set the receive timeout for the datagram socket, otherwise the read blocks indefinitely whenever no reply arrives (UDP gives no delivery guarantee). The function takes the address of the Domain Name server to query and a dns_record containing the query. It is serialized, sent to the server, a response is received and parsed. For that, I use a parsing monad and a writer monoid to structure the handling of messages (the slogan is "parsing is monadic, pretty-printing is monoidal"):

module Parser = struct
  (* A cursor pairs the complete message buffer with the current offset
     into it. *)
  type cursor = string * int

  (* A parser maps a cursor to an optional result ([None] on failure) paired
     with the cursor at which parsing should continue. *)
  type 'a parser = Parser of (cursor -> 'a option * cursor)

  (* Monad instance for parsers: [Some x] is success, [None] is failure. *)
  include Monad (struct
    type 'a t = 'a parser

    (* Succeed with [x] without moving the cursor. *)
    let return x = Parser (fun cur -> (Some x, cur))

    (* Run the first parser. On success hand its result to [f] and run the
       parser it yields from the advanced cursor; on failure report [None]
       together with the cursor as it was before the attempt. *)
    let bind f (Parser first) = Parser (fun start ->
      match first start with
      | (Some x, next) -> let Parser rest = f x in rest next
      | (None  , _   ) -> (None, start))

    (* Always fail, leaving the cursor untouched. *)
    let fail = Parser (fun cur -> (None, cur))
  end)

This is a simple deterministic parsing monad using option instead of list. The only quirk is that Domain Name messages use back-references for data compression, so I need full random-access to the message buffer with positioning information in the form of a cursor. Running a parser is done by providing it with a buffer:

  (* [run p str] applies [p] to [str] starting at offset 0 and returns its
     optional result, discarding the final cursor. *)
  let run (Parser p) str = fst (p (str, 0))

The use of positioning information means that the parser is also a state monad:

  (* [tell] succeeds with the current offset and leaves the cursor as is. *)
  let tell = Parser (fun cur -> (Some (snd cur), cur))
  (* [seek off] repositions the cursor at offset [off] of the same buffer.
     Fails, leaving the cursor untouched, unless 0 <= off <= buffer length. *)
  let seek off = Parser (fun (str, _ as cur) ->
    if off < 0 || off > String.length str
    then (None, cur)
    else (Some (), (str, off)) )

  (* … code omitted … *)
end

I deal with low-level entities like bytes (properly octets, but let's not quibble) and big-endian integers. Parsing these is straightforward:

(* Aliases naming the low-level wire entities; they add no type safety. *)
type int16  = int     (* two octets read big-endian into an int *)
type byte   = int     (* one octet, as produced by int_of_char  *)
type bytes  = string  (* raw octet sequence                     *)

(* Parser for a single octet, returned as its integer code. *)
let byte : byte Parser.t =
  let open Parser in
  fmap int_of_char next

(* [bytes cnt] is the parser Parser.take applied to [cnt], i.e. a run of
   [cnt] octets returned as a string. *)
let bytes cnt : bytes Parser.t = Parser.(take cnt)

(* Parser for a 16-bit integer stored most-significant octet first. *)
let int16 : int16 Parser.t =
  Parser.(
    byte >>= fun msb ->
    byte >>= fun lsb ->
    return ((msb lsl 8) lor lsb))

(* Parser for a 32-bit integer stored most-significant octet first. *)
let int32 : int32 Parser.t =
  let open Parser in
  (* Shift the accumulator one octet to the left and append [octet]. *)
  let push acc octet =
    Int32.logor (Int32.shift_left acc 8) (Int32.of_int octet) in
  byte >>= fun b0 ->
  byte >>= fun b1 ->
  byte >>= fun b2 ->
  byte >>= fun b3 ->
  return (push (push (push (Int32.of_int b0) b1) b2) b3)

By the way, I love the new syntax for modules in OCaml 3.12! Now labels are length-prefixed sequences of bytes, and domain names are sequences of labels:

(* A label is one dot-free component of a domain name. *)
type label        = bytes
(* A domain name in textual form: its labels joined by ".". *)
type domain_name  = bytes

(* [write_bytes str] writes [str] as a character string: one length octet
   followed by the octets of [str].
   Raises Invalid_argument "write_bytes" if [str] is longer than 255 octets,
   since the length would not fit in the prefix. *)
let write_bytes (str : bytes) =
  let len = String.length str in
  if len <= 255
  then Writer.(byte len ++ bytes str)
  else invalid_arg "write_bytes"

(* Parser for a length-prefixed string: one length octet, then that many
   octets. *)
let parse_bytes =
  let open Parser in
  byte >>= fun len -> bytes len

(* [write_label lbl] writes one label of a domain name: a length octet
   followed by the octets of [lbl].
   Raises Invalid_argument "write_label" if [lbl] is empty or longer than 63
   octets. An empty label would be emitted as a single zero octet, which is
   exactly the terminator of a domain name, so a name with consecutive dots
   would be silently truncated on the wire; parse_label likewise only
   accepts lengths strictly between 0 and 64. *)
let write_label (lbl : label) =
  let len = String.length lbl in
  if len = 0 || len > 63 then invalid_arg "write_label" else
  Writer.(byte len ++ bytes lbl)

(* Parser for one label. The length octet is first checked through [ensure]
   to lie strictly between 0 and 64; the label is then read as a
   length-prefixed string. *)
let parse_label : label Parser.t =
  let open Parser in
  let valid_length = byte >>= fun len -> guard (0 < len && len < 64) in
  ensure valid_length >> byte >>= bytes

The RFC specifies that labels are limited to 63 octets in length, and domain names can use back-references for reducing the size of a packet. Thus a domain name is terminated by the null label (corresponding to the top-level domain "."), or by a pointer to another label in the packet:

(* Regular expression matching a literal ".", the label separator. *)
let re_dot = Str.regexp "[.]"

(* [split_domain name] breaks the textual domain name [name] into the list
   of its labels at every ".". *)
let split_domain (name : domain_name) : string list = Str.split re_dot name

(* [write_domain_name name] writes every label of [name] in order, followed
   by the zero octet that terminates the name. No compression is applied. *)
let write_domain_name (name : domain_name) =
  let terminated labels = Writer.(sequence write_label labels ++ byte 0) in
  terminated (split_domain name)

(* Parser for a possibly compressed domain name, returned in dotted form.
   A name is a run of labels ended either by a zero octet or by a two-octet
   pointer (top two bits set) whose low 14 bits give the offset at which the
   rest of the name is found. After following a pointer the cursor is
   restored to just past it.

   The parser fails if a pointer does not refer to an offset strictly
   before the start of the label run it terminates. Messages arrive from the
   network and cannot be trusted: a pointer to itself, or into the run being
   read, would otherwise make the recursion loop forever. Since every jump
   lands strictly earlier in the buffer, the recursion is bounded. *)
let parse_domain_name : domain_name Parser.t =
  let open Parser in
  let rec labels () =
    tell                       >>= fun start ->
    sequence parse_label       >>= fun ls  ->
    byte                       >>= fun n   ->
    if n = 0 then return ls else
    guard (n land 0xc0 = 0xc0) >>
    byte                       >>= fun m   ->
    let off = ((n land 0x3f) lsl 8) lor m in
    (* Only strictly backward references are accepted (see above). *)
    guard (off < start)        >>
    tell                       >>= fun pos ->
    seek off                   >>
    labels ()                  >>= fun ls' ->
    seek pos                   >>
    return (ls @ ls')
  in fmap (String.concat ".") $ labels ()

Writing a domain name is straightforward, as I don't do compression. Reading a domain name is done recursively. Note that parse_label fails (via ensure) whenever the label is empty or overlong. The first case corresponds to the top-level domain, which ends the sequence; the second case corresponds to a 16-bit pointer which is signalled by its two most-significant bits set. To read the back-reference I record the position in the packet via tell, seek to the referenced position, parse the label sequence recursively (which might trigger yet more back-references), restore the original position and return all the results. For example, consider the following complete response (don't mind the names yet):

id        0000  \000\000
detail    0002  \132\000
qdcount   0004  \000\001
ancount   0006  \000\002
nscount   0008  \000\000
arcount   0010  \000\000
q_name    0012  \009_services\007_dns-sd\004_udp
          0035  \005local\000
q_type    0042  \000\012
q_class   0044  \000\001
rr_name   0046  \192\012
rr_type   0048  \000\012
rr_class  0050  \000\001
rr_ttl    0052  \000\000\000\010
rr_dlen   0056  \000\020
rr_rdata  0058  \012_workstation
          0071  \004_tcp\192\035
rr_name   0078  \192\012
rr_type   0080  \000\012
rr_class  0082  \000\001
rr_ttl    0084  \000\000\000\010
rr_dlen   0088  \000\008
rr_rdata  0090  \005_http\192\071
          0098

The domain name at offset 0090 (_http) refers back to 0071 (_tcp) which in turn refers back to 0035 (local), which spells _http._tcp.local. Now Domain Name is basically a distributed database system which replies to queries with responses containing zero or more resources. The types of resources are spelled out in detail in the RFC:

(* Resource record types. A polymorphic variant so that q_type can extend
   it. *)
type rr_type = [
| `A | `NS | `MD | `MF | `CNAME | `SOA | `MB | `MG
| `MR | `NULL | `WKS | `PTR | `HINFO | `MINFO | `MX | `TXT
]

Why polymorphic variants? Because query types are a superset of resource types:

(* Query types: every resource record type plus the query-only ones. *)
type q_type = [ rr_type | `AXFR | `MAILB | `MAILA | `ANY ]

This lets me reuse the code for converting back and forth between labels and protocol integers. In turn, each standard resource has its particular format:

(* Decoded payload (RDATA) of a resource record. The trailing comments name
   the record types for which parse_resource produces each case. *)
type resource = [
| `Hostinfo  of string * string                  (* HINFO: cpu, os *)
| `Domain    of domain_name                      (* MB MD MF MG MR NS CNAME PTR *)
| `Mailbox   of domain_name * domain_name        (* MINFO *)
| `Exchange  of int * domain_name                (* MX: preference, exchange *)
| `Data      of bytes                            (* NULL *)
| `Text      of bytes list                       (* TXT *)
(* SOA: mname, rname, serial, refresh, retry, expire, minimum *)
| `Authority of domain_name * domain_name * int32 * int32 * int32 * int32 * int32
| `Address   of Unix.inet_addr                   (* A *)
| `Services  of int32 * byte * bytes             (* WKS: address, protocol, bitmap *)
]

The types are dictated by the protocol, and they in turn dictate how to format and write resource data. The first is straightforward:

(* [write_resource r] serializes the payload [r] of a resource record in the
   wire format of its case; no length prefix is written.
   Raises Invalid_argument "write_resource" if an `Address is not an IPv4
   address, and whatever write_bytes / write_domain_name raise on over-long
   components. *)
let write_resource =
  let open Writer in function
  | `Hostinfo (cpu, os) ->
       write_bytes       cpu
    ++ write_bytes       os
  | `Domain name ->
       write_domain_name name
  | `Mailbox  (rmbx, embx) ->
       write_domain_name rmbx
    ++ write_domain_name embx
  | `Exchange (pref, exch) ->
       int16             pref
    ++ write_domain_name exch
  | `Data data ->
       bytes             data
  | `Text texts ->
    sequence write_bytes texts
  | `Authority (mname, rname, serial, refresh, retry, expire, minimum) ->
       write_domain_name mname
    ++ write_domain_name rname
    ++ int32             serial
    ++ int32             refresh
    ++ int32             retry
    ++ int32             expire
    ++ int32             minimum
  | `Address addr ->
    (* Unix.inet_addr is abstract and is not represented as a boxed int32,
       so it must not be reinterpreted with Obj.magic. Go through its
       dotted-quad text and emit the four octets in network order. *)
    let quad =
      try
        Scanf.sscanf (Unix.string_of_inet_addr addr) "%d.%d.%d.%d"
          (fun a b c d -> (a, b, c, d))
      with Scanf.Scan_failure _ | Failure _ | End_of_file ->
        (* Not in dotted-quad form, e.g. an IPv6 address. *)
        invalid_arg "write_resource"
    in
    let (a, b, c, d) = quad in
    byte a ++ byte b ++ byte c ++ byte d
  | `Services (addr, proto, bmap) ->
       int32             addr
    ++ byte              proto
    ++ bytes             bmap

(the use of Obj.magic to coerce addresses back and forth is relatively unperilous). Parsing requires knowing the resource type to decode the payload. Note that Writer.sequence formats a list by formatting each element in turn while Parser.sequence repeatedly applies a parser until it fails and returns the list of intermediate results:

(* [parse_resource rr_type rr_dlen] is the parser for the payload of a
   resource record of type [rr_type] whose declared payload length is
   [rr_dlen] octets. The result is the [resource] case matching the type.
   The length is needed for the payloads that are not self-delimiting
   (NULL, TXT and WKS). *)
let parse_resource rr_type rr_dlen =
  let open Parser in match rr_type with
  | `HINFO ->
    parse_bytes                >>= fun cpu ->
    parse_bytes                >>= fun os  ->
    return (`Hostinfo (cpu, os))
  | `MB | `MD | `MF | `MG | `MR | `NS
  | `CNAME | `PTR ->
    parse_domain_name          >>= fun name ->
    return (`Domain name)
  | `MINFO ->
    parse_domain_name          >>= fun rmailbx ->
    parse_domain_name          >>= fun emailbx ->
    return (`Mailbox (rmailbx, emailbx))
  | `MX ->
    int16                      >>= fun preference ->
    parse_domain_name          >>= fun exchange   ->
    return (`Exchange (preference, exchange))
  | `NULL ->
    bytes rr_dlen              >>= fun data ->
    return (`Data data)
  | `TXT ->
    (* Read strings only up to the end of this payload. An unbounded
       sequence would keep going into the records that follow, since almost
       any octets can be read as a length-prefixed string. *)
    tell                       >>= fun start ->
    let stop = start + rr_dlen in
    let rec strings () =
      tell                     >>= fun pos ->
      if pos >= stop then
        (* The last string must end exactly where the payload ends. *)
        (guard (pos = stop) >> return [])
      else
        (parse_bytes           >>= fun txt  ->
         strings ()            >>= fun rest ->
         return (txt :: rest))
    in
    (strings ()                >>= fun texts ->
     return (`Text texts))
  | `SOA ->
    parse_domain_name          >>= fun mname   ->
    parse_domain_name          >>= fun rname   ->
    int32                      >>= fun serial  ->
    int32                      >>= fun refresh ->
    int32                      >>= fun retry   ->
    int32                      >>= fun expire  ->
    int32                      >>= fun minimum ->
    return (`Authority (mname, rname, serial, refresh, retry, expire, minimum))
  | `A ->
    (* Build the address from its four octets. A boxed int32 does not have
       the representation of Unix.inet_addr, so coercing one into the other
       with Obj.magic yields a meaningless address. *)
    byte                       >>= fun a ->
    byte                       >>= fun b ->
    byte                       >>= fun c ->
    byte                       >>= fun d ->
    return (`Address
      (Unix.inet_addr_of_string (Printf.sprintf "%d.%d.%d.%d" a b c d)))
  | `WKS ->
    (* Address (4 octets) and protocol (1 octet) precede the bitmap; a
       shorter payload is malformed and would give a negative count. *)
    guard (rr_dlen >= 5)       >>
    int32                      >>= fun addr   ->
    byte                       >>= fun proto  ->
    bytes (rr_dlen-5)          >>= fun bitmap ->
    return (`Services (addr, proto, bitmap))

There's nothing to it, really, as monadic notation makes the code read straight. A resource is described by a rsrc_record, which can be written and parsed fairly simply:

(* A resource record, as found in the answer, authority and additional
   sections of a message. *)
type rsrc_record = {
  rr_name  : domain_name;  (* name the record is about     *)
  rr_type  : rr_type;      (* determines the rr_rdata case *)
  rr_class : rr_class;
  rr_ttl   : int32;
  rr_rdata : resource;     (* decoded payload              *)
}

(* [write_rsrc_record r] serializes the resource record [r]: name, type,
   class, TTL, then the payload preceded by its 16-bit length. *)
let write_rsrc_record r =
  let open Writer in
  (* The payload is rendered to a string up front because its length has to
     be written before the payload itself. *)
  let payload = run (write_resource r.rr_rdata) in
  let { rr_name = name; rr_type = kind; rr_class = cls; rr_ttl = ttl; _ } = r in
     write_domain_name name
  ++ int16 (int_of_rr_type  kind)
  ++ int16 (int_of_rr_class cls)
  ++ int32 ttl
  ++ int16 (String.length payload)
  ++ bytes payload

(* Parser for one resource record: name, type, class, TTL and payload
   length, followed by the payload decoded according to the type. *)
let parse_rsrc_record =
  Parser.(
    parse_domain_name           >>= fun name  ->
    fmap rr_type_of_int  int16  >>= fun kind  ->
    fmap rr_class_of_int int16  >>= fun cls   ->
    int32                       >>= fun ttl   ->
    int16                       >>= fun dlen  ->
    parse_resource kind dlen    >>= fun rdata ->
    return { rr_name  = name;
             rr_type  = kind;
             rr_class = cls;
             rr_ttl   = ttl;
             rr_rdata = rdata; })

Note that since the resource payload is prefixed by its length I have to write it out-of-band to know its length. This is the only instance of buffer copying in the implementation. A question is handled similarly:

(* One entry of the question section of a message. *)
type question = {
  q_name  : domain_name;  (* name being asked about *)
  q_type  : q_type;       (* kind of record wanted  *)
  q_class : q_class;
}
 
(* [write_question q] serializes the question [q]: name, then type and class
   as 16-bit integers. *)
let write_question q =
  let open Writer in
  let { q_name = name; q_type = kind; q_class = cls } = q in
     write_domain_name name
  ++ int16 (int_of_q_type  kind)
  ++ int16 (int_of_q_class cls)

(* Parser for one question: name, then type and class read as 16-bit
   integers and converted to their symbolic form. *)
let parse_question =
  Parser.(
    parse_domain_name         >>= fun name ->
    fmap q_type_of_int  int16 >>= fun kind ->
    fmap q_class_of_int int16 >>= fun cls  ->
    return { q_name = name; q_type = kind; q_class = cls; })

A Domain Name record has a serial number, a bit field with options and response codes and a sequence of questions followed by various sequences of resources as answers:

(* A complete message, query or response. The four section counts of the
   wire header are not stored: they are derived from the list lengths when
   writing and drive the parsing when reading. *)
type dns_record = {
  id         : int16;             (* serial number of the message             *)
  detail     : int16;             (* bit field of options and response codes  *)
  question   : question    list;
  answer     : rsrc_record list;
  authority  : rsrc_record list;
  additional : rsrc_record list;
}

(* [write_dns_record d] serializes the message [d]: the id and flag words,
   the four section counts, then the sections in the order question, answer,
   authority, additional. *)
let write_dns_record d =
  let open Writer in
  (* A section count is the length of its list as a 16-bit integer. *)
  let count section = int16 (List.length section) in
     int16 d.id
  ++ int16 d.detail
  ++ count d.question
  ++ count d.answer
  ++ count d.authority
  ++ count d.additional
  ++ sequence write_question    d.question
  ++ sequence write_rsrc_record d.answer
  ++ sequence write_rsrc_record d.authority
  ++ sequence write_rsrc_record d.additional

(* Parser for a complete message: the six 16-bit header words, then each
   section read with exactly as many entries as its header count says. *)
let parse_dns_record =
  Parser.(
    int16                       >>= fun id      ->
    int16                       >>= fun detail  ->
    int16                       >>= fun n_quest ->
    int16                       >>= fun n_answ  ->
    int16                       >>= fun n_auth  ->
    int16                       >>= fun n_addl  ->
    repeat n_quest parse_question    >>= fun question   ->
    repeat n_answ  parse_rsrc_record >>= fun answer     ->
    repeat n_auth  parse_rsrc_record >>= fun authority  ->
    repeat n_addl  parse_rsrc_record >>= fun additional ->
    return { id         = id;
             detail     = detail;
             question   = question;
             answer     = answer;
             authority  = authority;
             additional = additional; })

Very high-level, straight-line code thanks to the structuring power of the algebraic structures! A query involves asking for a resource identified by domain name. If you analyze the bit field, you'll notice that I ask for a recursive query:

(* [query ?q_type id q_name] builds a message holding a single question of
   class `IN about [q_name], of type [q_type] (default `A), with serial
   number [id] and empty answer, authority and additional sections.
   Raises Invalid_argument "query" unless [id] fits in 16 unsigned bits. *)
let query ?(q_type=`A) id q_name =
  if id < 0 || id > 0xffff then invalid_arg "query" else
  let asked = { q_name; q_type; q_class = `IN; } in
  (* The only flag set is the one requesting a recursive query. *)
  let flags = 0b0_0000_0010_000_0000 in
  { id;
    detail     = flags;
    question   = [ asked ];
    answer     = [];
    authority  = [];
    additional = []; }

That's it! Grab the code and play with it; I've put it in the Public Domain so that you can use it for whatever purpose you want.