package main

import "net"
import "strings"
import "encoding/hex"
import "encoding/binary"

//import "time"
import "errors"
import "log"
import "os"

// DNSHeader is the fixed 12-byte header at the start of every DNS/mDNS
// message. Fields are listed in wire order; buildHeader and parseHeader
// encode/decode each one as a big-endian 16-bit integer.
type DNSHeader struct {
	ID      uint16 // message identifier
	Flags   uint16 // raw 16-bit flags word (not decoded into individual bits here)
	QCount  uint16 // number of entries in the question section
	ACount  uint16 // number of records in the answer section
	NSCount uint16 // number of records in the authority section
	ARCount uint16 // number of records in the additional section
}

// MDNSAnswer is the result of parseAnswer: the decoded message header plus
// the owner name and RDATA of the single answer record in an mDNS response.
type MDNSAnswer struct {
	Header *DNSHeader // decoded message header
	Domain string     // owner name of the answer record, dot-separated
	Data   net.IP     // raw RDATA bytes of the record, exposed as an IP
}

// MDNSQuery is the result of parseQuery: the decoded message header plus the
// single question of an mDNS query message.
type MDNSQuery struct {
	Header  *DNSHeader // decoded message header
	Domain  string     // queried name, dot-separated
	Type    uint16     // QTYPE of the question
	Class   uint16     // QCLASS of the question, as decoded by parseQuery
	Unicast bool       // unicast-response (QU) flag, as decoded by parseQuery from QCLASS
}

// AAAA is the DNS record type code for an IPv6 address record. It is written
// as QTYPE in queries (buildQuery) and as TYPE in answers (buildResponse).
const AAAA uint16 = 28

// buildHeader serializes header into the 12-byte DNS wire format: the six
// fields, in declaration order, each as a big-endian (network order) uint16.
func buildHeader(header DNSHeader) []byte {
	fields := [...]uint16{
		header.ID,
		header.Flags,
		header.QCount,
		header.ACount,
		header.NSCount,
		header.ARCount,
	}
	wire := make([]byte, 2*len(fields))
	for idx, v := range fields {
		binary.BigEndian.PutUint16(wire[2*idx:], v)
	}
	return wire
}

// buildDomain encodes a dotted domain name into DNS wire format: each label
// is written as a one-byte length followed by its bytes, and the name ends
// with a zero-length root label.
//
// Empty labels (from a trailing dot such as "blarg.local.", doubled dots, or
// an empty domain) are skipped. Emitting them would write a zero length byte,
// which terminates the name early and corrupts whatever follows it in the
// packet.
//
// NOTE: labels are assumed to be at most 63 bytes long; longer labels cannot
// be represented in DNS and are not rejected here.
func buildDomain(domain string) []byte {
	// Every non-empty label costs len(label)+1 bytes, plus one terminator.
	out := make([]byte, 0, len(domain)+2)
	for _, label := range strings.Split(domain, ".") {
		if label == "" {
			continue
		}
		out = append(out, byte(len(label)))
		out = append(out, label...)
	}
	return append(out, 0)
}

// buildQuery assembles a one-question mDNS query packet asking for the AAAA
// record of domain. The top bit of QCLASS (the unicast-response bit) is set,
// with class 1 in the remaining bits.
func buildQuery(domain string) []byte {
	var packet []byte
	packet = append(packet, buildHeader(DNSHeader{QCount: 1})...)
	packet = append(packet, buildDomain(domain)...)
	packet = append(packet,
		0, byte(AAAA), // QTYPE=AAAA
		0x80, 1,       // Unicast=true QCLASS=1
	)
	return packet
}

// buildResponse assembles an mDNS response packet carrying a single AAAA
// answer record that maps domain to ip.
//
// The header sets QR (this is a response) and AA (authoritative answer);
// RFC 6762 §18.2 and §18.4 require both in mDNS responses, and receivers that
// follow the RFC ignore messages without QR set. The record uses class 1
// with the cache-flush bit clear and a TTL of one second.
//
// ip is converted with To16, so an IPv4 address is sent in IPv4-mapped IPv6
// form. RDLENGTH always matches the number of address bytes actually written,
// so an invalid (nil) ip yields an empty but well-formed record rather than a
// truncated packet.
func buildResponse(domain string, ip net.IP) []byte {
	const (
		flagQR = 1 << 15 // QR: message is a response
		flagAA = 1 << 10 // AA: authoritative answer
	)
	rdata := ip.To16() // 16 bytes, or nil for an invalid ip

	payload := buildHeader(DNSHeader{Flags: flagQR | flagAA, ACount: 1})
	payload = append(payload, buildDomain(domain)...)
	payload = append(payload, 0, byte(AAAA))       // TYPE=AAAA
	payload = append(payload, 0, 1)                // FLUSH=false, RRCLASS=1
	payload = append(payload, 0, 0, 0, 1)          // TTL: cache for 1 second
	payload = append(payload, 0, byte(len(rdata))) // RDLENGTH: 16, or 0 for an invalid ip
	return append(payload, rdata...)
}

// parseShort decodes a big-endian uint16 from the front of buff and returns
// it together with the bytes that follow it. It returns an error (and a nil
// remainder) when buff holds fewer than two bytes.
func parseShort(buff []byte) (uint16, []byte, error) {
	if len(buff) < 2 {
		return 0, nil, errors.New("buffer too short for uint16")
	}
	return binary.BigEndian.Uint16(buff), buff[2:], nil
}

// parseHeader decodes the fixed 12-byte DNS header at the start of buff.
// It returns the header and the bytes that follow it (the question and record
// sections), or an error if buff is shorter than a header.
func parseHeader(buff []byte) (*DNSHeader, []byte, error) {
	const headerLen = 12
	if len(buff) < headerLen {
		return nil, nil, errors.New("buffer too short for DNS header")
	}

	// The length check above guarantees every fixed-offset read is in range,
	// so the fields are decoded directly rather than via parseShort.
	be := binary.BigEndian
	header := &DNSHeader{
		ID:      be.Uint16(buff[0:2]),
		Flags:   be.Uint16(buff[2:4]),
		QCount:  be.Uint16(buff[4:6]),
		ACount:  be.Uint16(buff[6:8]),
		NSCount: be.Uint16(buff[8:10]),
		ARCount: be.Uint16(buff[10:12]),
	}
	return header, buff[headerLen:], nil
}

// parseDomain decodes an uncompressed DNS name (a sequence of
// length-prefixed labels ending in a zero byte) from the front of buff.
// It returns the labels joined with "." and the bytes after the terminator.
//
// buff comes straight off the network, so every label length is
// bounds-checked. A malformed name yields an error instead of a slice panic.
// Compression pointers (label bytes with either of the top two bits set) are
// rejected: resolving them needs the whole message, which is not available
// here.
func parseDomain(buff []byte) (string, []byte, error) {
	var labels []string
	i := 0
	for {
		if i >= len(buff) {
			// Ran off the end without finding the zero-length root label.
			return "", nil, errors.New("domain string was invalid")
		}
		size := int(buff[i])
		if size == 0 {
			return strings.Join(labels, "."), buff[i+1:], nil
		}
		if size&0xC0 != 0 {
			return "", nil, errors.New("domain name compression is not supported")
		}
		i++
		if i+size > len(buff) {
			return "", nil, errors.New("domain string was invalid")
		}
		labels = append(labels, string(buff[i:i+size]))
		i += size
	}
}

// parseQuery decodes an mDNS message that carries exactly one question.
//
// It returns an error if the header cannot be read, if the header does not
// announce exactly one question, or if the question is truncated. Anything
// after the question is ignored.
//
// The top bit of the 16-bit QCLASS field is the mDNS unicast-response (QU)
// bit (RFC 6762 §5.4). It is reported in Unicast and masked out of Class.
// buildQuery sets it in the same position.
func parseQuery(buff []byte) (*MDNSQuery, error) {
	const unicastBit = 1 << 15

	header, rest, err := parseHeader(buff)
	if err != nil {
		return nil, err
	}
	if header.QCount != 1 {
		return nil, errors.New("header indicates that this is not a query")
	}
	domain, rest, err := parseDomain(rest)
	if err != nil {
		return nil, err
	}
	if len(rest) < 4 {
		return nil, errors.New("buffer too short for query")
	}
	qtype, rest, err := parseShort(rest)
	if err != nil {
		return nil, err
	}
	qclass, _, err := parseShort(rest)
	if err != nil {
		return nil, err
	}
	return &MDNSQuery{
		Header:  header,
		Domain:  domain,
		Type:    qtype,
		Class:   qclass &^ unicastBit,
		Unicast: qclass&unicastBit != 0,
	}, nil
}

// parseAnswer decodes an mDNS response whose header announces exactly one
// answer record, and returns that record's owner name and RDATA.
//
// The answer record is expected immediately after the header; QCount is not
// checked, so a message that also carries a question section will be
// misparsed. TYPE, CLASS and TTL are skipped without validation, so the RDATA
// is returned as an IP whatever the record type is. All returned data is
// copied out of buff.
//
// Errors are returned for a short header, ACount != 1, a malformed name, or
// when the buffer is shorter than the record's fixed fields or its RDLENGTH.
func parseAnswer(buff []byte) (*MDNSAnswer, error) {
	// TYPE(2) + CLASS(2) + TTL(4) come before RDLENGTH(2) in a resource record.
	const skipped = 8

	header, rest, err := parseHeader(buff)
	if err != nil {
		return nil, err
	}
	if header.ACount != 1 {
		return nil, errors.New("header indicates that this is not an answer")
	}
	domain, rest, err := parseDomain(rest)
	if err != nil {
		return nil, err
	}
	if len(rest) < skipped+2 {
		return nil, errors.New("buffer too short for answer")
	}
	dataLength, rest, err := parseShort(rest[skipped:])
	if err != nil {
		return nil, err
	}
	// RDLENGTH is attacker-controlled; never slice past what was received.
	if len(rest) < int(dataLength) {
		return nil, errors.New("buffer too short for answer data")
	}
	data := make(net.IP, dataLength)
	copy(data, rest)
	return &MDNSAnswer{Header: header, Domain: domain, Data: data}, nil
}

// mdnsQuery sends a one-shot mDNS AAAA query for domain to the IPv6 mDNS group
// ff02::fb on the given port and zone (interface name) iface. It then reads
// packets until one carries an answer whose name matches domain. DNS names are
// case-insensitive, so the comparison is too.
//
// The standard mDNS port is 5353, but any port is accepted here.
//
// Known limitations:
//  1. domain is assumed to be ASCII with labels of at most 63 bytes.
//  2. iface must be supplied by the caller; it is not discovered automatically.
//  3. There is no timeout. If no matching answer arrives, this blocks forever.
//     It returns early only if sending the query or reading from the socket
//     fails.
//
// NOTE(review): ListenUDP binds to the group address but does not join the
// multicast group, so multicast answers may not be delivered to this socket.
// Confirm on the target platform.
//
// TODO: Make this asynchronous in a way that lets the caller handle timeouts.
func mdnsQuery(logger *log.Logger, domain string, port int, iface string) (*MDNSAnswer, error) {
	addr := net.UDPAddr{IP: net.ParseIP("ff02::fb"), Port: port, Zone: iface}
	conn, err := net.ListenUDP("udp6", &addr)
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	if _, err := conn.WriteTo(buildQuery(domain), &addr); err != nil {
		return nil, err
	}

	// RFC 6762 §17 caps mDNS packets at 9000 bytes. The buffer is reused for
	// every read; parseAnswer copies everything it returns out of it.
	buf := make([]byte, 9000)
	for {
		n, _, err := conn.ReadFrom(buf)
		if err != nil {
			// Retrying would spin forever on a persistent socket error.
			return nil, err
		}
		packet := buf[:n]
		response, err := parseAnswer(packet)
		if err != nil {
			logger.Printf("Error: %s\nData:\n%s", err, hex.Dump(packet))
			continue
		}
		if response.Header.ACount != 1 {
			// This is not an answer to a oneshot mDNS query
			logger.Printf("Non-answer Data:\n%s", hex.Dump(packet))
			continue
		}
		if strings.EqualFold(response.Domain, domain) {
			logger.Printf("Received answer: %s -> %s\n", domain, response.Data.String())
			return response, nil
		}
	}
}

// mdnsServer answers mDNS queries for domain with an AAAA record pointing at
// selfip. It listens on the IPv6 mDNS group ff02::fb at the given port (5353
// is the standard mDNS port), using the system-assigned multicast interface.
// It replies by unicast to the sender of every query whose name matches
// domain; names are compared case-insensitively.
//
// It returns an error only if selfip is not a valid IP address or the
// multicast socket cannot be opened. Otherwise it serves forever, logging
// per-packet errors and carrying on.
func mdnsServer(logger *log.Logger, domain string, port int, selfip string) error {
	selfIP := net.ParseIP(selfip)
	if selfIP == nil {
		// Without a valid address every response would be malformed.
		return errors.New("invalid self IP address: " + selfip)
	}
	group := net.UDPAddr{IP: net.ParseIP("ff02::fb"), Port: port}
	conn, err := net.ListenMulticastUDP("udp6", nil, &group)
	if err != nil {
		return err
	}
	defer conn.Close()

	// RFC 6762 §17 caps mDNS packets at 9000 bytes. The buffer is reused for
	// every read; parseQuery copies the name out of it.
	buf := make([]byte, 9000)
	for {
		n, from, err := conn.ReadFrom(buf)
		if err != nil {
			logger.Printf("Error: %s\n", err)
			continue
		}
		packet := buf[:n]
		query, err := parseQuery(packet)
		if err != nil {
			logger.Printf("Error: %s\nData:\n%s", err, hex.Dump(packet))
			continue
		}
		logger.Printf("%s wants to know %s", from, query.Domain)
		if !strings.EqualFold(query.Domain, domain) {
			continue
		}
		response := buildResponse(domain, selfIP)
		if _, err := conn.WriteTo(response, from); err != nil {
			logger.Printf("Error: %s\nData:\n%s", err, hex.Dump(response))
			continue
		}
		logger.Printf("%s was told %s\n", from, selfIP)
	}
}

// main runs an mDNS responder that answers queries for "blarg.local" with a
// hard-coded link-local IPv6 address. It logs to stdout with UTC,
// microsecond-resolution timestamps, and exits via logger.Fatalf if
// mdnsServer returns an error.
func main() {
	flags := log.Ltime | log.Lmicroseconds | log.LUTC
	logger := log.New(os.Stdout, "", flags)
	if err := mdnsServer(logger, "blarg.local", 5353, "fe80::508c:b874:8526:40bd"); err != nil {
		logger.Fatalf("mdns server: %s", err)
	}
	//for {
	//mdnsQuery(logger, "blarg.local", 5353, "eth0")
	//time.Sleep(5 * 1000 * 1000 * 1000) // Sleep 5 seconds
	//}
}
