blob: 81efa7fb539f10d6db5d8c519e53732f65293f8d [file] [log] [blame]
package main
import "net"
import "strings"
import "encoding/hex"
import "encoding/binary"
//import "time"
import "errors"
import "log"
import "os"
// DNSHeader is the fixed 12-byte header that begins every DNS/mDNS message.
// On the wire the six fields are big-endian uint16 values in declaration order.
type DNSHeader struct {
	// ID is the transaction identifier; Flags is the packed flags word
	// (QR, opcode, AA, TC, RD, RA, Z, RCODE per RFC 1035 §4.1.1).
	ID, Flags uint16
	// Entry counts for the question, answer, authority and additional sections.
	QCount, ACount, NSCount, ARCount uint16
}
// MDNSAnswer holds the single answer record decoded from an mDNS response.
type MDNSAnswer struct {
	// Header is the header of the message the record was read from.
	Header *DNSHeader
	// Domain is the record's owner name in dotted form.
	Domain string
	// Data is the record's raw RDATA; for an AAAA record that is the
	// 16-byte IPv6 address.
	Data net.IP
}
// MDNSQuery is a decoded single-question mDNS query.
type MDNSQuery struct {
	// Header is the header of the query message.
	Header *DNSHeader
	// Domain is the queried name in dotted form.
	Domain string
	// Type and Class are the question's QTYPE and QCLASS.
	Type, Class uint16
	// Unicast reports whether the question asked for a unicast response
	// (the mDNS "QU" flag carried in the QCLASS field).
	Unicast bool
}
// AAAA is the DNS resource-record type code for an IPv6 host address (RFC 3596).
const AAAA = uint16(28)
// buildHeader serializes header into the 12-byte DNS wire format: six
// big-endian uint16 values in field order.
func buildHeader(header DNSHeader) []byte {
	words := [...]uint16{
		header.ID,
		header.Flags,
		header.QCount,
		header.ACount,
		header.NSCount,
		header.ARCount,
	}
	wire := make([]byte, 2*len(words))
	for idx, w := range words {
		binary.BigEndian.PutUint16(wire[2*idx:], w)
	}
	return wire
}
// buildDomain encodes a dotted domain name into DNS wire format: each label
// is prefixed with its length byte and the name ends with a zero byte.
//
// A single trailing dot (fully-qualified form, e.g. "blarg.local.") is
// accepted. Previously it produced an empty label followed by a second zero
// byte, which corrupted whatever field followed the name. The empty name
// (and ".") encode as the root name, a lone zero byte.
//
// NOTE: labels are assumed to be ASCII and at most 63 bytes; longer labels
// are not rejected and will produce an invalid length byte.
func buildDomain(domain string) []byte {
	domain = strings.TrimSuffix(domain, ".")
	// Labels plus their length bytes take len(domain)+1 bytes, plus the terminator.
	out := make([]byte, 0, len(domain)+2)
	if domain != "" {
		for _, label := range strings.Split(domain, ".") {
			out = append(out, byte(len(label)))
			out = append(out, label...)
		}
	}
	return append(out, 0)
}
// buildQuery builds a one-question mDNS query asking for the AAAA record of
// domain, with the unicast-response (QU) bit set in the class field.
func buildQuery(domain string) []byte {
	msg := buildHeader(DNSHeader{QCount: 1})
	msg = append(msg, buildDomain(domain)...)
	// QTYPE = AAAA.
	msg = append(msg, 0, byte(AAAA))
	// QCLASS = 1 (IN) with its top bit set: Unicast=true.
	msg = append(msg, 1<<7, 1)
	return msg
}
// buildResponse builds an mDNS response containing a single AAAA answer
// record that maps domain to ip.
//
// The header sets QR=1 (response) and AA=1 (authoritative), as RFC 6762
// §18.2 and §18.4 require for mDNS responses. With QR=0, receivers treat
// the packet as a query. The record uses class IN without the cache-flush
// bit and a TTL of 1 second.
//
// The address is written in its 16-byte form (ip.To16()) and RDLENGTH is
// taken from the bytes actually written, so RDATA and RDLENGTH always agree.
// A nil ip yields an empty RDATA.
func buildResponse(domain string, ip net.IP) []byte {
	const flagsResponseAuthoritative = 0x8400 // QR=1, AA=1
	header := buildHeader(DNSHeader{Flags: flagsResponseAuthoritative, ACount: 1})
	payload := append(header, buildDomain(domain)...)
	payload = append(payload, 0, byte(AAAA)) // TYPE=AAAA
	payload = append(payload, 0, 1)          // FLUSH=false, RRCLASS=1 (IN)
	payload = append(payload, 0, 0, 0, 1)    // TTL: cache for 1 second
	rdata := ip.To16()
	payload = append(payload, byte(len(rdata)>>8), byte(len(rdata))) // RDLENGTH
	payload = append(payload, rdata...)
	return payload
}
// parseShort reads a big-endian uint16 from the front of buff and returns it
// together with the remaining bytes. It fails if buff has fewer than 2 bytes.
func parseShort(buff []byte) (uint16, []byte, error) {
	if len(buff) >= 2 {
		value := binary.BigEndian.Uint16(buff)
		return value, buff[2:], nil
	}
	return 0, nil, errors.New("buffer to short for uint16")
}
// parseHeader decodes the fixed 12-byte DNS header at the front of buff and
// returns it along with the bytes that follow it.
func parseHeader(buff []byte) (*DNSHeader, []byte, error) {
	if len(buff) < 12 {
		return nil, nil, errors.New("buffer to short for DNS header")
	}
	h := new(DNSHeader)
	// Header words in wire order.
	fields := []*uint16{&h.ID, &h.Flags, &h.QCount, &h.ACount, &h.NSCount, &h.ARCount}
	rest := buff
	for _, f := range fields {
		// Cannot fail: the length check above guarantees all 12 bytes.
		*f, rest, _ = parseShort(rest)
	}
	return h, rest, nil
}
// parseDomain decodes an uncompressed DNS name from the front of buff.
//
// The name is a sequence of length-prefixed labels ending in a zero byte.
// Labels are joined with "." and the root name decodes to "". It returns
// the name and the bytes following the terminating zero.
//
// buff comes straight off the network, so every length byte is
// bounds-checked. A label or terminator running past the end of buff is an
// error rather than a slice-out-of-range panic. A length byte with either
// of its two top bits set is also rejected: that marks a compression
// pointer (RFC 1035 §4.1.4) or reserved label type, and this parser does
// not follow pointers.
func parseDomain(buff []byte) (string, []byte, error) {
	var labels []string
	i := 0
	for i < len(buff) && buff[i] != 0 {
		size := int(buff[i])
		if size&0xC0 != 0 {
			return "", nil, errors.New("compressed or extended domain labels are not supported")
		}
		i++
		if i+size > len(buff) {
			return "", nil, errors.New("domain string was invalid")
		}
		labels = append(labels, string(buff[i:i+size]))
		i += size
	}
	// Ran off the end without seeing the terminating zero byte.
	if i >= len(buff) {
		return "", nil, errors.New("domain string was invalid")
	}
	return strings.Join(labels, "."), buff[i+1:], nil
}
// parseQuery decodes an mDNS query message carrying exactly one question.
//
// It returns an error if buff is shorter than a DNS header, if the header's
// QR bit marks the message as a response, if the question count is not 1,
// if the question name is malformed, or if QTYPE/QCLASS are missing.
//
// In mDNS the top bit of QCLASS is the unicast-response ("QU") flag and the
// low 15 bits are the class (RFC 6762 §5.4). The previous code read the QU
// flag from bit 0 and shifted the class right by one, which decoded every
// real query wrongly (e.g. 0x8001 became Class=0x4000).
func parseQuery(buff []byte) (*MDNSQuery, error) {
	const (
		flagQR    = 0x8000 // header Flags: message is a response
		bitQU     = 0x8000 // QCLASS: unicast response requested
		classMask = 0x7FFF // QCLASS: the class proper
	)
	header, rest, err := parseHeader(buff)
	if err != nil {
		return nil, err
	}
	if header.Flags&flagQR != 0 {
		return nil, errors.New("header indicates that this is a response, not a query")
	}
	if header.QCount != 1 {
		return nil, errors.New("header indicates that this is not a query")
	}
	domain, rest, err := parseDomain(rest)
	if err != nil {
		return nil, err
	}
	if len(rest) < 4 {
		return nil, errors.New("buffer too short for query")
	}
	// Both reads are guaranteed to succeed by the length check above.
	qtype, rest, _ := parseShort(rest)
	qclass, _, _ := parseShort(rest)
	return &MDNSQuery{
		Header:  header,
		Domain:  domain,
		Type:    qtype,
		Class:   qclass & classMask,
		Unicast: qclass&bitQU != 0,
	}, nil
}
// parseAnswer decodes an mDNS response that carries exactly one answer
// record and returns that record's owner name and raw RDATA.
//
// The record's fixed part after the name is TYPE(2) CLASS(2) TTL(4)
// RDLENGTH(2), i.e. 10 bytes, followed by RDLENGTH bytes of data. TYPE,
// CLASS and TTL are skipped. The RDLENGTH value comes from the network and
// is checked against the bytes actually present, so a truncated or hostile
// packet yields an error instead of a slice-out-of-range panic. Data is
// copied out of buff, so the caller may reuse buff afterwards.
func parseAnswer(buff []byte) (*MDNSAnswer, error) {
	header, rest, err := parseHeader(buff)
	if err != nil {
		return nil, err
	}
	if header.ACount != 1 {
		return nil, errors.New("header indicates that this is not an answer")
	}
	domain, rest, err := parseDomain(rest)
	if err != nil {
		return nil, err
	}
	const fixedLen = 10 // TYPE + CLASS + TTL + RDLENGTH
	if len(rest) < fixedLen {
		return nil, errors.New("buffer too short for answer")
	}
	// Skip TYPE, CLASS and TTL; the length check makes this read infallible.
	rdLength, rest, _ := parseShort(rest[8:])
	if int(rdLength) > len(rest) {
		return nil, errors.New("answer data is truncated")
	}
	data := make(net.IP, rdLength)
	copy(data, rest)
	return &MDNSAnswer{
		Header: header,
		Domain: domain,
		Data:   data,
	}, nil
}
// mdnsQuery sends a one-shot mDNS AAAA query for domain to ff02::fb on the
// given port (zone iface), then blocks until a matching answer arrives.
//
// Known limitations:
//  1. domain is assumed to be ASCII with short labels (see buildDomain).
//  2. iface must be specified by the caller, which seems wrong; it is used
//     as the IPv6 zone of the multicast address.
//  3. There is no timeout: if no matching answer ever arrives, this blocks
//     forever. Errors from sending or receiving on the socket are returned.
//
// According to the standard the port should be 5353, but any port is
// allowed here.
//
// NOTE(review): the socket is bound to the multicast group address itself
// and never joins the group; whether replies reach it is OS-dependent —
// confirm on the target platform.
//
// TODO: Make this async in some way that allows the user to handle timeouts.
func mdnsQuery(logger *log.Logger, domain string, port int, iface string) (*MDNSAnswer, error) {
	addr := &net.UDPAddr{IP: net.ParseIP("ff02::fb"), Port: port, Zone: iface}
	conn, err := net.ListenUDP("udp6", addr)
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	if _, err := conn.WriteTo(buildQuery(domain), addr); err != nil {
		return nil, err
	}
	// mDNS messages can be up to 9000 bytes (RFC 6762 §17). The buffer is
	// reused safely because parseAnswer copies everything it returns.
	buff := make([]byte, 9000)
	for {
		n, _, err := conn.ReadFrom(buff)
		if err != nil {
			return nil, err
		}
		packet := buff[:n]
		response, err := parseAnswer(packet)
		if err != nil {
			logger.Printf("Error: %s\nData:\n%s", err, hex.Dump(packet))
			continue
		}
		// DNS names compare case-insensitively.
		if strings.EqualFold(response.Domain, domain) {
			logger.Printf("Received answer: %s -> %s\n", domain, response.Data.String())
			return response, nil
		}
	}
}
func mdnsServer(logger *log.Logger, domain string, port int, selfip string) error {
selfIP := net.ParseIP(selfip)
ip := net.ParseIP("ff02::fb")
addr := net.UDPAddr{IP: ip, Port: 5353}
conn, err := net.ListenMulticastUDP("udp6", nil, &addr)
if err != nil {
return err
}
for {
buff := make([]byte, 1024)
n, addr, err := conn.ReadFrom(buff)
if err != nil {
logger.Printf("Error: %s\n", err)
continue
}
buff = buff[:n]
query, err := parseQuery(buff)
if err != nil {
logger.Printf("Error: %s\nData:\n%s", err, hex.Dump(buff))
} else {
logger.Printf("%s wants to know %s", addr, query.Domain)
if domain == query.Domain {
response := buildResponse(domain, selfIP)
_, err := conn.WriteTo(response, addr)
logger.Printf("%s was told %s\n", addr, selfIP)
if err != nil {
logger.Printf("Error: %s\nData:\n%s", err, hex.Dump(buff))
}
}
}
}
}
// main runs the mDNS responder for "blarg.local" on the standard port,
// answering with a fixed link-local IPv6 address. If the responder cannot
// start, the error is logged and the process exits non-zero instead of
// silently succeeding.
func main() {
	flags := log.Ltime | log.Lmicroseconds | log.LUTC
	logger := log.New(os.Stdout, "", flags)
	if err := mdnsServer(logger, "blarg.local", 5353, "fe80::508c:b874:8526:40bd"); err != nil {
		logger.Fatalf("mdns server: %s", err)
	}
	//for {
	//mdnsQuery(logger, "blarg.local", 5353, "eth0")
	//time.Sleep(5 * 1000 * 1000 * 1000) // Sleep 5 seconds
	//}
}