Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • ai3/tools/acmeserver
  • godog/acmeserver
  • svp-bot/acmeserver
3 results
Show changes
Showing
with 4029 additions and 3140 deletions
// +build !go1.11 !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd
package dns
import "net"
const supportsReusePort = false
func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
if reuseport {
// TODO(tmthrgd): return an error?
}
return net.Listen(network, addr)
}
func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
if reuseport {
// TODO(tmthrgd): return an error?
}
return net.ListenPacket(network, addr)
}
// +build go1.11
// +build aix darwin dragonfly freebsd linux netbsd openbsd
package dns
import (
"context"
"net"
"syscall"
"golang.org/x/sys/unix"
)
const supportsReusePort = true
func reuseportControl(network, address string, c syscall.RawConn) error {
var opErr error
err := c.Control(func(fd uintptr) {
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
})
if err != nil {
return err
}
return opErr
}
func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
var lc net.ListenConfig
if reuseport {
lc.Control = reuseportControl
}
return lc.Listen(context.Background(), network, addr)
}
func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
var lc net.ListenConfig
if reuseport {
lc.Control = reuseportControl
}
return lc.ListenPacket(context.Background(), network, addr)
}
......@@ -9,21 +9,41 @@
package dns
//go:generate go run msg_generate.go
//go:generate go run compress_generate.go
import (
crand "crypto/rand"
"crypto/rand"
"encoding/binary"
"fmt"
"math/big"
"math/rand"
"strconv"
"sync"
"strings"
)
const (
maxCompressionOffset = 2 << 13 // We have 14 bits for the compression pointer
maxDomainNameWireOctets = 255 // See RFC 1035 section 2.3.4
// This is the maximum number of compression pointers that should occur in a
// semantically valid message. Each label in a domain name must be at least one
// octet and is separated by a period. The root label won't be represented by a
// compression pointer to a compression pointer, hence the -2 to exclude the
// smallest valid root label.
//
// It is possible to construct a valid message that has more compression pointers
// than this, and still doesn't loop, by pointing to a previous pointer. This is
// not something a well written implementation should ever do, so we leave them
// to trip the maximum compression pointer check.
maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2
// This is the maximum length of a domain name in presentation format. The
// maximum wire length of a domain name is 255 octets (see above), with the
// maximum label length being 63. The wire format requires one extra byte over
// the presentation format, reducing the number of octets by 1. Each label in
// the name will be separated by a single period, with each octet in the label
// expanding to at most 4 bytes (\DDD). If all other labels are of the maximum
// length, then the final label can only be 61 octets long to not exceed the
// maximum allowed wire length.
maxDomainNamePresentationLength = 61*4 + 1 + 63*4 + 1 + 63*4 + 1 + 63*4 + 1
)
// Errors defined in this package.
......@@ -46,59 +66,28 @@ var (
ErrRRset error = &Error{err: "bad rrset"}
ErrSecret error = &Error{err: "no secrets defined"}
ErrShortRead error = &Error{err: "short read"}
ErrSig error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated.
ErrSoa error = &Error{err: "no SOA"} // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
ErrTime error = &Error{err: "bad time"} // ErrTime indicates a timing error in TSIG authentication.
ErrTruncated error = &Error{err: "failed to unpack truncated message"} // ErrTruncated indicates that we failed to unpack a truncated message. We unpacked as much as we had so Msg can still be used, if desired.
ErrSig error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated.
ErrSoa error = &Error{err: "no SOA"} // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
ErrTime error = &Error{err: "bad time"} // ErrTime indicates a timing error in TSIG authentication.
)
// Id by default, returns a 16 bits random number to be used as a
// message id. The random provided should be good enough. This being a
// variable the function can be reassigned to a custom function.
// For instance, to make it return a static value:
// Id by default returns a 16-bit random number to be used as a message id. The
// number is drawn from a cryptographically secure random number generator.
// This being a variable the function can be reassigned to a custom function.
// For instance, to make it return a static value for testing:
//
// dns.Id = func() uint16 { return 3 }
var Id = id
var (
idLock sync.Mutex
idRand *rand.Rand
)
// id returns a 16 bits random number to be used as a
// message id. The random provided should be good enough.
func id() uint16 {
idLock.Lock()
if idRand == nil {
// This (partially) works around
// https://github.com/golang/go/issues/11833 by only
// seeding idRand upon the first call to id.
var seed int64
var buf [8]byte
if _, err := crand.Read(buf[:]); err == nil {
seed = int64(binary.LittleEndian.Uint64(buf[:]))
} else {
seed = rand.Int63()
}
idRand = rand.New(rand.NewSource(seed))
var output uint16
err := binary.Read(rand.Reader, binary.BigEndian, &output)
if err != nil {
panic("dns: reading random id failed: " + err.Error())
}
// The call to idRand.Uint32 must be within the
// mutex lock because *rand.Rand is not safe for
// concurrent use.
//
// There is no added performance overhead to calling
// idRand.Uint32 inside a mutex lock over just
// calling rand.Uint32 as the global math/rand rng
// is internally protected by a sync.Mutex.
id := uint16(idRand.Uint32())
idLock.Unlock()
return id
return output
}
// MsgHdr is a a manually-unpacked version of (id, bits).
......@@ -151,7 +140,7 @@ var RcodeToString = map[int]string{
RcodeFormatError: "FORMERR",
RcodeServerFailure: "SERVFAIL",
RcodeNameError: "NXDOMAIN",
RcodeNotImplemented: "NOTIMPL",
RcodeNotImplemented: "NOTIMP",
RcodeRefused: "REFUSED",
RcodeYXDomain: "YXDOMAIN", // See RFC 2136
RcodeYXRrset: "YXRRSET",
......@@ -169,6 +158,39 @@ var RcodeToString = map[int]string{
RcodeBadCookie: "BADCOOKIE",
}
// compressionMap is used to allow a more efficient compression map
// to be used for internal packDomainName calls without changing the
// signature or functionality of public API.
//
// In particular, map[string]uint16 uses 25% less per-entry memory
// than does map[string]int.
type compressionMap struct {
ext map[string]int // external callers
int map[string]uint16 // internal callers
}
func (m compressionMap) valid() bool {
return m.int != nil || m.ext != nil
}
func (m compressionMap) insert(s string, pos int) {
if m.ext != nil {
m.ext[s] = pos
} else {
m.int[s] = uint16(pos)
}
}
func (m compressionMap) find(s string) (int, bool) {
if m.ext != nil {
pos, ok := m.ext[s]
return pos, ok
}
pos, ok := m.int[s]
return int(pos), ok
}
// Domain names are a sequence of counted strings
// split at the dots. They end with a zero-length string.
......@@ -177,143 +199,161 @@ var RcodeToString = map[int]string{
// map needs to hold a mapping between domain names and offsets
// pointing into msg.
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
off1, _, err = packDomainName(s, msg, off, compression, compress)
return
return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
}
func packDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, labels int, err error) {
// special case if msg == nil
lenmsg := 256
if msg != nil {
lenmsg = len(msg)
}
func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
// XXX: A logical copy of this function exists in IsDomainName and
// should be kept in sync with this function.
ls := len(s)
if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
return off, 0, nil
}
// If not fully qualified, error out, but only if msg == nil #ugly
switch {
case msg == nil:
if s[ls-1] != '.' {
s += "."
ls++
}
case msg != nil:
if s[ls-1] != '.' {
return lenmsg, 0, ErrFqdn
}
return off, nil
}
// If not fully qualified, error out.
if !IsFqdn(s) {
return len(msg), ErrFqdn
}
// Each dot ends a segment of the name.
// We trade each dot byte for a length byte.
// Except for escaped dots (\.), which are normal dots.
// There is also a trailing zero.
// Compression
nameoffset := -1
pointer := -1
// Emit sequence of counted strings, chopping at dots.
begin := 0
bs := []byte(s)
roBs, bsFresh, escapedDot := s, true, false
var (
begin int
compBegin int
compOff int
bs []byte
wasDot bool
)
loop:
for i := 0; i < ls; i++ {
if bs[i] == '\\' {
for j := i; j < ls-1; j++ {
bs[j] = bs[j+1]
var c byte
if bs == nil {
c = s[i]
} else {
c = bs[i]
}
switch c {
case '\\':
if off+1 > len(msg) {
return len(msg), ErrBuf
}
ls--
if off+1 > lenmsg {
return lenmsg, labels, ErrBuf
if bs == nil {
bs = []byte(s)
}
// check for \DDD
if i+2 < ls && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
bs[i] = dddToByte(bs[i:])
for j := i + 1; j < ls-2; j++ {
bs[j] = bs[j+2]
}
ls -= 2
if i+3 < ls && isDigit(bs[i+1]) && isDigit(bs[i+2]) && isDigit(bs[i+3]) {
bs[i] = dddToByte(bs[i+1:])
copy(bs[i+1:ls-3], bs[i+4:])
ls -= 3
compOff += 3
} else {
copy(bs[i:ls-1], bs[i+1:])
ls--
compOff++
}
escapedDot = bs[i] == '.'
bsFresh = false
continue
}
if bs[i] == '.' {
if i > 0 && bs[i-1] == '.' && !escapedDot {
wasDot = false
case '.':
if i == 0 && len(s) > 1 {
// leading dots are not legal except for the root zone
return len(msg), ErrRdata
}
if wasDot {
// two dots back to back is not legal
return lenmsg, labels, ErrRdata
return len(msg), ErrRdata
}
if i-begin >= 1<<6 { // top two bits of length must be clear
return lenmsg, labels, ErrRdata
wasDot = true
labelLen := i - begin
if labelLen >= 1<<6 { // top two bits of length must be clear
return len(msg), ErrRdata
}
// off can already (we're in a loop) be bigger than len(msg)
// this happens when a name isn't fully qualified
if off+1 > lenmsg {
return lenmsg, labels, ErrBuf
}
if msg != nil {
msg[off] = byte(i - begin)
}
offset := off
off++
for j := begin; j < i; j++ {
if off+1 > lenmsg {
return lenmsg, labels, ErrBuf
}
if msg != nil {
msg[off] = bs[j]
}
off++
}
if compress && !bsFresh {
roBs = string(bs)
bsFresh = true
if off+1+labelLen > len(msg) {
return len(msg), ErrBuf
}
// Don't try to compress '.'
// We should only compress when compress it true, but we should also still pick
// We should only compress when compress is true, but we should also still pick
// up names that can be used for *future* compression(s).
if compression != nil && roBs[begin:] != "." {
if p, ok := compression[roBs[begin:]]; !ok {
// Only offsets smaller than this can be used.
if offset < maxCompressionOffset {
compression[roBs[begin:]] = offset
}
} else {
if compression.valid() && !isRootLabel(s, bs, begin, ls) {
if p, ok := compression.find(s[compBegin:]); ok {
// The first hit is the longest matching dname
// keep the pointer offset we get back and store
// the offset of the current name, because that's
// where we need to insert the pointer later
// If compress is true, we're allowed to compress this dname
if pointer == -1 && compress {
pointer = p // Where to point to
nameoffset = offset // Where to point from
break
if compress {
pointer = p // Where to point to
break loop
}
} else if off < maxCompressionOffset {
// Only offsets smaller than maxCompressionOffset can be used.
compression.insert(s[compBegin:], off)
}
}
labels++
// The following is covered by the length check above.
msg[off] = byte(labelLen)
if bs == nil {
copy(msg[off+1:], s[begin:i])
} else {
copy(msg[off+1:], bs[begin:i])
}
off += 1 + labelLen
begin = i + 1
compBegin = begin + compOff
default:
wasDot = false
}
escapedDot = false
}
// Root label is special
if len(bs) == 1 && bs[0] == '.' {
return off, labels, nil
if isRootLabel(s, bs, 0, ls) {
return off, nil
}
// If we did compression and we find something add the pointer here
if pointer != -1 {
// We have two bytes (14 bits) to put the pointer in
// if msg == nil, we will never do compression
binary.BigEndian.PutUint16(msg[nameoffset:], uint16(pointer^0xC000))
off = nameoffset + 1
goto End
binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
return off + 2, nil
}
if msg != nil && off < len(msg) {
if off < len(msg) {
msg[off] = 0
}
End:
off++
return off, labels, nil
return off + 1, nil
}
// isRootLabel returns whether s or bs, from off to end, is the root
// label ".".
//
// If bs is nil, s will be checked, otherwise bs will be checked.
func isRootLabel(s string, bs []byte, off, end int) bool {
if bs == nil {
return s[off:end] == "."
}
return end-off == 1 && bs[off] == '.'
}
// Unpack a domain name.
......@@ -330,12 +370,16 @@ End:
// In theory, the pointers are only allowed to jump backward.
// We let them jump anywhere and stop jumping after a while.
// UnpackDomainName unpacks a domain name into a string.
// UnpackDomainName unpacks a domain name into a string. It returns
// the name, the new offset into msg and any error that occurred.
//
// When an error is encountered, the unpacked name will be discarded
// and len(msg) will be returned as the offset.
func UnpackDomainName(msg []byte, off int) (string, int, error) {
s := make([]byte, 0, 64)
s := make([]byte, 0, maxDomainNamePresentationLength)
off1 := 0
lenmsg := len(msg)
maxLen := maxDomainNameWireOctets
budget := maxDomainNameWireOctets
ptr := 0 // number of pointers followed
Loop:
for {
......@@ -354,30 +398,17 @@ Loop:
if off+c > lenmsg {
return "", lenmsg, ErrBuf
}
for j := off; j < off+c; j++ {
switch b := msg[j]; b {
case '.', '(', ')', ';', ' ', '@':
fallthrough
case '"', '\\':
budget -= c + 1 // +1 for the label separator
if budget <= 0 {
return "", lenmsg, ErrLongDomain
}
for _, b := range msg[off : off+c] {
if isDomainNameLabelSpecial(b) {
s = append(s, '\\', b)
// presentation-format \X escapes add an extra byte
maxLen++
default:
if b < 32 || b >= 127 { // unprintable, use \DDD
var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
// presentation-format \DDD escapes add 3 extra bytes
maxLen += 3
} else {
s = append(s, b)
}
} else if b < ' ' || b > '~' {
s = append(s, escapeByte(b)...)
} else {
s = append(s, b)
}
}
s = append(s, '.')
......@@ -396,7 +427,7 @@ Loop:
if ptr == 0 {
off1 = off
}
if ptr++; ptr > 10 {
if ptr++; ptr > maxCompressionPointers {
return "", lenmsg, &Error{err: "too many compression pointers"}
}
// pointer should guarantee that it advances and points forwards at least
......@@ -412,10 +443,7 @@ Loop:
off1 = off
}
if len(s) == 0 {
s = []byte(".")
} else if len(s) >= maxLen {
// error if the name is too long, but don't throw it away
return string(s), lenmsg, ErrLongDomain
return ".", off1, nil
}
return string(s), off1, nil
}
......@@ -429,11 +457,11 @@ func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) {
return offset, nil
}
var err error
for i := range txt {
if len(txt[i]) > len(tmp) {
for _, s := range txt {
if len(s) > len(tmp) {
return offset, ErrBuf
}
offset, err = packTxtString(txt[i], msg, offset, tmp)
offset, err = packTxtString(s, msg, offset, tmp)
if err != nil {
return offset, err
}
......@@ -512,7 +540,7 @@ func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
off = off0
var s string
for off < len(msg) && err == nil {
s, off, err = unpackTxtString(msg, off)
s, off, err = unpackString(msg, off)
if err == nil {
ss = append(ss, s)
}
......@@ -520,43 +548,16 @@ func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
return
}
func unpackTxtString(msg []byte, offset int) (string, int, error) {
if offset+1 > len(msg) {
return "", offset, &Error{err: "overflow unpacking txt"}
}
l := int(msg[offset])
if offset+l+1 > len(msg) {
return "", offset, &Error{err: "overflow unpacking txt"}
}
s := make([]byte, 0, l)
for _, b := range msg[offset+1 : offset+1+l] {
switch b {
case '"', '\\':
s = append(s, '\\', b)
default:
if b < 32 || b > 127 { // unprintable
var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
} else {
s = append(s, b)
}
}
}
offset += 1 + l
return string(s), offset, nil
}
// Helpers for dealing with escaped bytes
func isDigit(b byte) bool { return b >= '0' && b <= '9' }
func dddToByte(s []byte) byte {
_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
}
func dddStringToByte(s string) byte {
_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
}
......@@ -574,19 +575,38 @@ func intToBytes(i *big.Int, length int) []byte {
// PackRR packs a resource record rr into msg[off:].
// See PackDomainName for documentation about the compression.
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress)
if err == nil {
// packRR no longer sets the Rdlength field on the rr, but
// callers might be expecting it so we set it here.
rr.Header().Rdlength = uint16(off1 - headerEnd)
}
return off1, err
}
func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) {
if rr == nil {
return len(msg), &Error{err: "nil rr"}
return len(msg), len(msg), &Error{err: "nil rr"}
}
off1, err = rr.pack(msg, off, compression, compress)
headerEnd, err = rr.Header().packHeader(msg, off, compression, compress)
if err != nil {
return len(msg), err
return headerEnd, len(msg), err
}
// TODO(miek): Not sure if this is needed? If removed we can remove rawmsg.go as well.
if rawSetRdlength(msg, off, off1) {
return off1, nil
off1, err = rr.pack(msg, headerEnd, compression, compress)
if err != nil {
return headerEnd, len(msg), err
}
return off, ErrRdata
rdlength := off1 - headerEnd
if int(uint16(rdlength)) != rdlength { // overflow
return headerEnd, len(msg), ErrRdata
}
// The RDLENGTH field is the last field in the header and we set it here.
binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength))
return headerEnd, off1, nil
}
// UnpackRR unpacks msg[off:] into an RR.
......@@ -602,17 +622,35 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
// UnpackRRWithHeader unpacks the record type specific payload given an existing
// RR_Header.
func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
if newFn, ok := TypeToRR[h.Rrtype]; ok {
rr = newFn()
*rr.Header() = h
} else {
rr = &RFC3597{Hdr: h}
}
if off < 0 || off > len(msg) {
return &h, off, &Error{err: "bad off"}
}
end := off + int(h.Rdlength)
if end < off || end > len(msg) {
return &h, end, &Error{err: "bad rdlength"}
}
if fn, known := typeToUnpack[h.Rrtype]; !known {
rr, off, err = unpackRFC3597(h, msg, off)
} else {
rr, off, err = fn(h, msg, off)
if noRdata(h) {
return rr, off, nil
}
off, err = rr.unpack(msg, off)
if err != nil {
return nil, end, err
}
if off != end {
return &h, end, &Error{err: "bad rdlength"}
}
return rr, off, err
return rr, off, nil
}
// unpackRRslice unpacks msg[off:] into an []RR.
......@@ -630,7 +668,6 @@ func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error)
}
// If offset does not increase anymore, l is a lie
if off1 == off {
l = i
break
}
dst = append(dst, r)
......@@ -693,32 +730,33 @@ func (dns *Msg) Pack() (msg []byte, err error) {
// PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
var compression map[string]int
if dns.Compress {
compression = make(map[string]int) // Compression pointer mappings.
// If this message can't be compressed, avoid filling the
// compression map and creating garbage.
if dns.Compress && dns.isCompressible() {
compression := make(map[string]uint16) // Compression pointer mappings.
return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true)
}
return dns.packBufferWithCompressionMap(buf, compression)
return dns.packBufferWithCompressionMap(buf, compressionMap{}, false)
}
// packBufferWithCompressionMap packs a Msg, using the given buffer buf.
func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression map[string]int) (msg []byte, err error) {
// We use a similar function in tsig.go's stripTsig.
var dh Header
func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) {
if dns.Rcode < 0 || dns.Rcode > 0xFFF {
return nil, ErrRcode
}
if dns.Rcode > 0xF {
// Regular RCODE field is 4 bits
opt := dns.IsEdns0()
if opt == nil {
return nil, ErrExtendedRcode
}
opt.SetExtendedRcode(uint8(dns.Rcode >> 4))
// Set extended rcode unconditionally if we have an opt, this will allow
// resetting the extended rcode bits if they need to.
if opt := dns.IsEdns0(); opt != nil {
opt.SetExtendedRcode(uint16(dns.Rcode))
} else if dns.Rcode > 0xF {
// If Rcode is an extended one and opt is nil, error out.
return nil, ErrExtendedRcode
}
// Convert convenient Msg into wire-like Header.
var dh Header
dh.Id = dns.Id
dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
if dns.Response {
......@@ -746,50 +784,44 @@ func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression map[string]
dh.Bits |= _CD
}
// Prepare variable sized arrays.
question := dns.Question
answer := dns.Answer
ns := dns.Ns
extra := dns.Extra
dh.Qdcount = uint16(len(question))
dh.Ancount = uint16(len(answer))
dh.Nscount = uint16(len(ns))
dh.Arcount = uint16(len(extra))
dh.Qdcount = uint16(len(dns.Question))
dh.Ancount = uint16(len(dns.Answer))
dh.Nscount = uint16(len(dns.Ns))
dh.Arcount = uint16(len(dns.Extra))
// We need the uncompressed length here, because we first pack it and then compress it.
msg = buf
uncompressedLen := compressedLen(dns, false)
uncompressedLen := msgLenWithCompressionMap(dns, nil)
if packLen := uncompressedLen + 1; len(msg) < packLen {
msg = make([]byte, packLen)
}
// Pack it in: header and then the pieces.
off := 0
off, err = dh.pack(msg, off, compression, dns.Compress)
off, err = dh.pack(msg, off, compression, compress)
if err != nil {
return nil, err
}
for i := 0; i < len(question); i++ {
off, err = question[i].pack(msg, off, compression, dns.Compress)
for _, r := range dns.Question {
off, err = r.pack(msg, off, compression, compress)
if err != nil {
return nil, err
}
}
for i := 0; i < len(answer); i++ {
off, err = PackRR(answer[i], msg, off, compression, dns.Compress)
for _, r := range dns.Answer {
_, off, err = packRR(r, msg, off, compression, compress)
if err != nil {
return nil, err
}
}
for i := 0; i < len(ns); i++ {
off, err = PackRR(ns[i], msg, off, compression, dns.Compress)
for _, r := range dns.Ns {
_, off, err = packRR(r, msg, off, compression, compress)
if err != nil {
return nil, err
}
}
for i := 0; i < len(extra); i++ {
off, err = PackRR(extra[i], msg, off, compression, dns.Compress)
for _, r := range dns.Extra {
_, off, err = packRR(r, msg, off, compression, compress)
if err != nil {
return nil, err
}
......@@ -797,28 +829,7 @@ func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression map[string]
return msg[:off], nil
}
// Unpack unpacks a binary message to a Msg structure.
func (dns *Msg) Unpack(msg []byte) (err error) {
var (
dh Header
off int
)
if dh, off, err = unpackMsgHdr(msg, off); err != nil {
return err
}
dns.Id = dh.Id
dns.Response = dh.Bits&_QR != 0
dns.Opcode = int(dh.Bits>>11) & 0xF
dns.Authoritative = dh.Bits&_AA != 0
dns.Truncated = dh.Bits&_TC != 0
dns.RecursionDesired = dh.Bits&_RD != 0
dns.RecursionAvailable = dh.Bits&_RA != 0
dns.Zero = dh.Bits&_Z != 0
dns.AuthenticatedData = dh.Bits&_AD != 0
dns.CheckingDisabled = dh.Bits&_CD != 0
dns.Rcode = int(dh.Bits & 0xF)
func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
// If we are at the end of the message we should return *just* the
// header. This can still be useful to the caller. 9.9.9.9 sends these
// when responding with REFUSED for instance.
......@@ -837,8 +848,6 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
var q Question
q, off, err = unpackQuestion(msg, off)
if err != nil {
// Even if Truncated is set, we only will set ErrTruncated if we
// actually got the questions
return err
}
if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
......@@ -862,16 +871,29 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
// The header counts might have been wrong so we need to update it
dh.Arcount = uint16(len(dns.Extra))
// Set extended Rcode
if opt := dns.IsEdns0(); opt != nil {
dns.Rcode |= opt.ExtendedRcode()
}
if off != len(msg) {
// TODO(miek) make this an error?
// use PackOpt to let people tell how detailed the error reporting should be?
// println("dns: extra bytes in dns packet", off, "<", len(msg))
} else if dns.Truncated {
// Whether we ran into a an error or not, we want to return that it
// was truncated
err = ErrTruncated
}
return err
}
// Unpack unpacks a binary message to a Msg structure.
func (dns *Msg) Unpack(msg []byte) (err error) {
dh, off, err := unpackMsgHdr(msg, 0)
if err != nil {
return err
}
dns.setHdr(dh)
return dns.unpack(dh, msg, off)
}
// Convert a complete message to a string with dig-like output.
......@@ -884,184 +906,155 @@ func (dns *Msg) String() string {
s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", "
s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", "
s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
opt := dns.IsEdns0()
if opt != nil {
// OPT PSEUDOSECTION
s += opt.String() + "\n"
}
if len(dns.Question) > 0 {
s += "\n;; QUESTION SECTION:\n"
for i := 0; i < len(dns.Question); i++ {
s += dns.Question[i].String() + "\n"
for _, r := range dns.Question {
s += r.String() + "\n"
}
}
if len(dns.Answer) > 0 {
s += "\n;; ANSWER SECTION:\n"
for i := 0; i < len(dns.Answer); i++ {
if dns.Answer[i] != nil {
s += dns.Answer[i].String() + "\n"
for _, r := range dns.Answer {
if r != nil {
s += r.String() + "\n"
}
}
}
if len(dns.Ns) > 0 {
s += "\n;; AUTHORITY SECTION:\n"
for i := 0; i < len(dns.Ns); i++ {
if dns.Ns[i] != nil {
s += dns.Ns[i].String() + "\n"
for _, r := range dns.Ns {
if r != nil {
s += r.String() + "\n"
}
}
}
if len(dns.Extra) > 0 {
if len(dns.Extra) > 0 && (opt == nil || len(dns.Extra) > 1) {
s += "\n;; ADDITIONAL SECTION:\n"
for i := 0; i < len(dns.Extra); i++ {
if dns.Extra[i] != nil {
s += dns.Extra[i].String() + "\n"
for _, r := range dns.Extra {
if r != nil && r.Header().Rrtype != TypeOPT {
s += r.String() + "\n"
}
}
}
return s
}
// isCompressible returns whether the msg may be compressible.
func (dns *Msg) isCompressible() bool {
// If we only have one question, there is nothing we can ever compress.
return len(dns.Question) > 1 || len(dns.Answer) > 0 ||
len(dns.Ns) > 0 || len(dns.Extra) > 0
}
// Len returns the message length when in (un)compressed wire format.
// If dns.Compress is true compression it is taken into account. Len()
// is provided to be a faster way to get the size of the resulting packet,
// than packing it, measuring the size and discarding the buffer.
func (dns *Msg) Len() int { return compressedLen(dns, dns.Compress) }
func compressedLenWithCompressionMap(dns *Msg, compression map[string]int) int {
l := 12 // Message header is always 12 bytes
for _, r := range dns.Question {
compressionLenHelper(compression, r.Name, l)
l += r.len()
func (dns *Msg) Len() int {
// If this message can't be compressed, avoid filling the
// compression map and creating garbage.
if dns.Compress && dns.isCompressible() {
compression := make(map[string]struct{})
return msgLenWithCompressionMap(dns, compression)
}
l += compressionLenSlice(l, compression, dns.Answer)
l += compressionLenSlice(l, compression, dns.Ns)
l += compressionLenSlice(l, compression, dns.Extra)
return l
return msgLenWithCompressionMap(dns, nil)
}
// compressedLen returns the message length when in compressed wire format
// when compress is true, otherwise the uncompressed length is returned.
func compressedLen(dns *Msg, compress bool) int {
// We always return one more than needed.
if compress {
compression := map[string]int{}
return compressedLenWithCompressionMap(dns, compression)
}
l := 12 // Message header is always 12 bytes
func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int {
l := headerSize
for _, r := range dns.Question {
l += r.len()
l += r.len(l, compression)
}
for _, r := range dns.Answer {
if r != nil {
l += r.len()
l += r.len(l, compression)
}
}
for _, r := range dns.Ns {
if r != nil {
l += r.len()
l += r.len(l, compression)
}
}
for _, r := range dns.Extra {
if r != nil {
l += r.len()
l += r.len(l, compression)
}
}
return l
}
func compressionLenSlice(lenp int, c map[string]int, rs []RR) int {
initLen := lenp
for _, r := range rs {
if r == nil {
continue
}
// TmpLen is to track len of record at 14bits boudaries
tmpLen := lenp
x := r.len()
// track this length, and the global length in len, while taking compression into account for both.
k, ok, _ := compressionLenSearch(c, r.Header().Name)
if ok {
// Size of x is reduced by k, but we add 1 since k includes the '.' and label descriptor take 2 bytes
// so, basically x:= x - k - 1 + 2
x += 1 - k
}
func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int {
if s == "" || s == "." {
return 1
}
tmpLen += compressionLenHelper(c, r.Header().Name, tmpLen)
k, ok, _ = compressionLenSearchType(c, r)
if ok {
x += 1 - k
escaped := strings.Contains(s, "\\")
if compression != nil && (compress || off < maxCompressionOffset) {
// compressionLenSearch will insert the entry into the compression
// map if it doesn't contain it.
if l, ok := compressionLenSearch(compression, s, off); ok && compress {
if escaped {
return escapedNameLen(s[:l]) + 2
}
return l + 2
}
lenp += x
tmpLen = lenp
tmpLen += compressionLenHelperType(c, r, tmpLen)
}
if escaped {
return escapedNameLen(s) + 1
}
return lenp - initLen
return len(s) + 1
}
// Put the parts of the name in the compression map, return the size in bytes added in payload
func compressionLenHelper(c map[string]int, s string, currentLen int) int {
if currentLen > maxCompressionOffset {
// We won't be able to add any label that could be re-used later anyway
return 0
}
if _, ok := c[s]; ok {
return 0
}
initLen := currentLen
pref := ""
prev := s
lbs := Split(s)
for j := 0; j < len(lbs); j++ {
pref = s[lbs[j]:]
currentLen += len(prev) - len(pref)
prev = pref
if _, ok := c[pref]; !ok {
// If first byte label is within the first 14bits, it might be re-used later
if currentLen < maxCompressionOffset {
c[pref] = currentLen
}
func escapedNameLen(s string) int {
nameLen := len(s)
for i := 0; i < len(s); i++ {
if s[i] != '\\' {
continue
}
if i+3 < len(s) && isDigit(s[i+1]) && isDigit(s[i+2]) && isDigit(s[i+3]) {
nameLen -= 3
i += 3
} else {
added := currentLen - initLen
if j > 0 {
// We added a new PTR
added += 2
}
return added
nameLen--
i++
}
}
return currentLen - initLen
return nameLen
}
// Look for each part in the compression map and returns its length,
// keep on searching so we get the longest match.
// Will return the size of compression found, whether a match has been
// found and the size of record if added in payload
func compressionLenSearch(c map[string]int, s string) (int, bool, int) {
off := 0
end := false
if s == "" { // don't bork on bogus data
return 0, false, 0
}
fullSize := 0
for {
func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) {
for off, end := 0, false; !end; off, end = NextLabel(s, off) {
if _, ok := c[s[off:]]; ok {
return len(s[off:]), true, fullSize + off
return off, true
}
if end {
break
if msgOff+off < maxCompressionOffset {
c[s[off:]] = struct{}{}
}
// Each label descriptor takes 2 bytes, add it
fullSize += 2
off, end = NextLabel(s, off)
}
return 0, false, fullSize + len(s)
return 0, false
}
// Copy returns a new RR which is a deep-copy of r.
func Copy(r RR) RR { r1 := r.copy(); return r1 }
func Copy(r RR) RR { return r.copy() }
// Len returns the length (in octets) of the uncompressed RR in wire format.
func Len(r RR) int { return r.len() }
func Len(r RR) int { return r.len(0, nil) }
// Copy returns a new *Msg which is a deep-copy of dns.
func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
......@@ -1077,40 +1070,27 @@ func (dns *Msg) CopyTo(r1 *Msg) *Msg {
}
rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
var rri int
r1.Answer, rrArr = rrArr[:0:len(dns.Answer)], rrArr[len(dns.Answer):]
r1.Ns, rrArr = rrArr[:0:len(dns.Ns)], rrArr[len(dns.Ns):]
r1.Extra = rrArr[:0:len(dns.Extra)]
if len(dns.Answer) > 0 {
rrbegin := rri
for i := 0; i < len(dns.Answer); i++ {
rrArr[rri] = dns.Answer[i].copy()
rri++
}
r1.Answer = rrArr[rrbegin:rri:rri]
for _, r := range dns.Answer {
r1.Answer = append(r1.Answer, r.copy())
}
if len(dns.Ns) > 0 {
rrbegin := rri
for i := 0; i < len(dns.Ns); i++ {
rrArr[rri] = dns.Ns[i].copy()
rri++
}
r1.Ns = rrArr[rrbegin:rri:rri]
for _, r := range dns.Ns {
r1.Ns = append(r1.Ns, r.copy())
}
if len(dns.Extra) > 0 {
rrbegin := rri
for i := 0; i < len(dns.Extra); i++ {
rrArr[rri] = dns.Extra[i].copy()
rri++
}
r1.Extra = rrArr[rrbegin:rri:rri]
for _, r := range dns.Extra {
r1.Extra = append(r1.Extra, r.copy())
}
return r1
}
func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
off, err := PackDomainName(q.Name, msg, off, compression, compress)
func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
off, err := packDomainName(q.Name, msg, off, compression, compress)
if err != nil {
return off, err
}
......@@ -1151,7 +1131,7 @@ func unpackQuestion(msg []byte, off int) (Question, int, error) {
return q, off, err
}
func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
off, err := packUint16(dh.Id, msg, off)
if err != nil {
return off, err
......@@ -1173,7 +1153,10 @@ func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress
return off, err
}
off, err = packUint16(dh.Arcount, msg, off)
return off, err
if err != nil {
return off, err
}
return off, nil
}
func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
......@@ -1202,5 +1185,23 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
return dh, off, err
}
dh.Arcount, off, err = unpackUint16(msg, off)
return dh, off, err
if err != nil {
return dh, off, err
}
return dh, off, nil
}
// setHdr set the header in the dns using the binary data in dh.
func (dns *Msg) setHdr(dh Header) {
dns.Id = dh.Id
dns.Response = dh.Bits&_QR != 0
dns.Opcode = int(dh.Bits>>11) & 0xF
dns.Authoritative = dh.Bits&_AA != 0
dns.Truncated = dh.Bits&_TC != 0
dns.RecursionDesired = dh.Bits&_RD != 0
dns.RecursionAvailable = dh.Bits&_RA != 0
dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
dns.AuthenticatedData = dh.Bits&_AD != 0
dns.CheckingDisabled = dh.Bits&_CD != 0
dns.Rcode = int(dh.Bits & 0xF)
}
//+build ignore
// msg_generate.go is meant to run with go generate. It will use
// go/{importer,types} to track down all the RR struct types. Then for each type
// it will generate pack/unpack methods based on the struct tags. The generated source is
// written to zmsg.go, and is meant to be checked into git.
package main
import (
"bytes"
"fmt"
"go/format"
"go/importer"
"go/types"
"log"
"os"
"strings"
)
var packageHdr = `
// Code generated by "go run msg_generate.go"; DO NOT EDIT.
package dns
`
// getTypeStruct will take a type and the package scope, and return the
// (innermost) struct if the type is considered a RR type (currently defined as
// those structs beginning with a RR_Header, could be redefined as implementing
// the RR interface). The bool return value indicates if embedded structs were
// resolved.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
st, ok := t.Underlying().(*types.Struct)
if !ok {
return nil, false
}
if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
return st, false
}
if st.Field(0).Anonymous() {
st, _ := getTypeStruct(st.Field(0).Type(), scope)
return st, true
}
return nil, false
}
func main() {
// Import and type-check the package
pkg, err := importer.Default().Import("github.com/miekg/dns")
fatalIfErr(err)
scope := pkg.Scope()
// Collect actual types (*X)
var namedTypes []string
for _, name := range scope.Names() {
o := scope.Lookup(name)
if o == nil || !o.Exported() {
continue
}
if st, _ := getTypeStruct(o.Type(), scope); st == nil {
continue
}
if name == "PrivateRR" {
continue
}
// Check if corresponding TypeX exists
if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
log.Fatalf("Constant Type%s does not exist.", o.Name())
}
namedTypes = append(namedTypes, o.Name())
}
b := &bytes.Buffer{}
b.WriteString(packageHdr)
fmt.Fprint(b, "// pack*() functions\n\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
if err != nil {
return off, err
}
headerEnd := off
`)
for i := 1; i < st.NumFields(); i++ {
o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil {
return off, err
}
`)
}
if _, ok := st.Field(i).Type().(*types.Slice); ok {
switch st.Tag(i) {
case `dns:"-"`: // ignored
case `dns:"txt"`:
o("off, err = packStringTxt(rr.%s, msg, off)\n")
case `dns:"opt"`:
o("off, err = packDataOpt(rr.%s, msg, off)\n")
case `dns:"nsec"`:
o("off, err = packDataNsec(rr.%s, msg, off)\n")
case `dns:"domain-name"`:
o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n")
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
continue
}
switch {
case st.Tag(i) == `dns:"-"`: // ignored
case st.Tag(i) == `dns:"cdomain-name"`:
o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n")
case st.Tag(i) == `dns:"domain-name"`:
o("off, err = PackDomainName(rr.%s, msg, off, compression, false)\n")
case st.Tag(i) == `dns:"a"`:
o("off, err = packDataA(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"aaaa"`:
o("off, err = packDataAAAA(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"uint48"`:
o("off, err = packUint48(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"txt"`:
o("off, err = packString(rr.%s, msg, off)\n")
case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
fallthrough
case st.Tag(i) == `dns:"base32"`:
o("off, err = packStringBase32(rr.%s, msg, off)\n")
case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
fallthrough
case st.Tag(i) == `dns:"base64"`:
o("off, err = packStringBase64(rr.%s, msg, off)\n")
case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
// directly write instead of using o() so we get the error check in the correct place
field := st.Field(i).Name()
fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
if rr.%s != "-" {
off, err = packStringHex(rr.%s, msg, off)
if err != nil {
return off, err
}
}
`, field, field)
continue
case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
fallthrough
case st.Tag(i) == `dns:"hex"`:
o("off, err = packStringHex(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"octet"`:
o("off, err = packStringOctet(rr.%s, msg, off)\n")
case st.Tag(i) == "":
switch st.Field(i).Type().(*types.Basic).Kind() {
case types.Uint8:
o("off, err = packUint8(rr.%s, msg, off)\n")
case types.Uint16:
o("off, err = packUint16(rr.%s, msg, off)\n")
case types.Uint32:
o("off, err = packUint32(rr.%s, msg, off)\n")
case types.Uint64:
o("off, err = packUint64(rr.%s, msg, off)\n")
case types.String:
o("off, err = packString(rr.%s, msg, off)\n")
default:
log.Fatalln(name, st.Field(i).Name())
}
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
}
// We have packed everything, only now we know the rdlength of this RR
fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)")
fmt.Fprintln(b, "return off, nil }\n")
}
fmt.Fprint(b, "// unpack*() functions\n\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
fmt.Fprintf(b, "rr := new(%s)\n", name)
fmt.Fprint(b, "rr.Hdr = h\n")
fmt.Fprint(b, `if noRdata(h) {
return rr, off, nil
}
var err error
rdStart := off
_ = rdStart
`)
for i := 1; i < st.NumFields(); i++ {
o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil {
return rr, off, err
}
`)
}
// size-* are special, because they reference a struct member we should use for the length.
if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
structMember := structMember(st.Tag(i))
structTag := structTag(st.Tag(i))
switch structTag {
case "hex":
fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
case "base32":
fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
case "base64":
fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
fmt.Fprint(b, `if err != nil {
return rr, off, err
}
`)
continue
}
if _, ok := st.Field(i).Type().(*types.Slice); ok {
switch st.Tag(i) {
case `dns:"-"`: // ignored
case `dns:"txt"`:
o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
case `dns:"opt"`:
o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
case `dns:"nsec"`:
o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
case `dns:"domain-name"`:
o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
continue
}
switch st.Tag(i) {
case `dns:"-"`: // ignored
case `dns:"cdomain-name"`:
fallthrough
case `dns:"domain-name"`:
o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
case `dns:"a"`:
o("rr.%s, off, err = unpackDataA(msg, off)\n")
case `dns:"aaaa"`:
o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
case `dns:"uint48"`:
o("rr.%s, off, err = unpackUint48(msg, off)\n")
case `dns:"txt"`:
o("rr.%s, off, err = unpackString(msg, off)\n")
case `dns:"base32"`:
o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"base64"`:
o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"hex"`:
o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"octet"`:
o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
case "":
switch st.Field(i).Type().(*types.Basic).Kind() {
case types.Uint8:
o("rr.%s, off, err = unpackUint8(msg, off)\n")
case types.Uint16:
o("rr.%s, off, err = unpackUint16(msg, off)\n")
case types.Uint32:
o("rr.%s, off, err = unpackUint32(msg, off)\n")
case types.Uint64:
o("rr.%s, off, err = unpackUint64(msg, off)\n")
case types.String:
o("rr.%s, off, err = unpackString(msg, off)\n")
default:
log.Fatalln(name, st.Field(i).Name())
}
default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
}
// If we've hit len(msg) we return without error.
if i < st.NumFields()-1 {
fmt.Fprintf(b, `if off == len(msg) {
return rr, off, nil
}
`)
}
}
fmt.Fprintf(b, "return rr, off, err }\n\n")
}
// Generate typeToUnpack map
fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
for _, name := range namedTypes {
if name == "RFC3597" {
continue
}
fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
}
fmt.Fprintln(b, "}\n")
// gofmt
res, err := format.Source(b.Bytes())
if err != nil {
b.WriteTo(os.Stderr)
log.Fatal(err)
}
// write result
f, err := os.Create("zmsg.go")
fatalIfErr(err)
defer f.Close()
f.Write(res)
}
// structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string.
func structMember(s string) string {
fields := strings.Split(s, ":")
if len(fields) == 0 {
return ""
}
f := fields[len(fields)-1]
// f should have a closing "
if len(f) > 1 {
return f[:len(f)-1]
}
return f
}
// structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
func structTag(s string) string {
fields := strings.Split(s, ":")
if len(fields) < 2 {
return ""
}
return fields[1][len("\"size-"):]
}
func fatalIfErr(err error) {
if err != nil {
log.Fatal(err)
}
}
......@@ -6,7 +6,8 @@ import (
"encoding/binary"
"encoding/hex"
"net"
"strconv"
"sort"
"strings"
)
// helper functions called from the generated zmsg.go
......@@ -25,12 +26,13 @@ func unpackDataA(msg []byte, off int) (net.IP, int, error) {
}
func packDataA(a net.IP, msg []byte, off int) (int, error) {
// It must be a slice of 4, even if it is 16, we encode only the first 4
if off+net.IPv4len > len(msg) {
return len(msg), &Error{err: "overflow packing a"}
}
switch len(a) {
case net.IPv4len, net.IPv6len:
// It must be a slice of 4, even if it is 16, we encode only the first 4
if off+net.IPv4len > len(msg) {
return len(msg), &Error{err: "overflow packing a"}
}
copy(msg[off:], a.To4())
off += net.IPv4len
case 0:
......@@ -51,12 +53,12 @@ func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) {
}
func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) {
if off+net.IPv6len > len(msg) {
return len(msg), &Error{err: "overflow packing aaaa"}
}
switch len(aaaa) {
case net.IPv6len:
if off+net.IPv6len > len(msg) {
return len(msg), &Error{err: "overflow packing aaaa"}
}
copy(msg[off:], aaaa)
off += net.IPv6len
case 0:
......@@ -99,14 +101,14 @@ func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte,
return hdr, off, msg, err
}
// pack packs an RR header, returning the offset to the end of the header.
// packHeader packs an RR header, returning the offset to the end of the header.
// See PackDomainName for documentation about the compression.
func (hdr RR_Header) pack(msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
func (hdr RR_Header) packHeader(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
if off == len(msg) {
return off, nil
}
off, err = PackDomainName(hdr.Name, msg, off, compression, compress)
off, err := packDomainName(hdr.Name, msg, off, compression, compress)
if err != nil {
return len(msg), err
}
......@@ -122,7 +124,7 @@ func (hdr RR_Header) pack(msg []byte, off int, compression map[string]int, compr
if err != nil {
return len(msg), err
}
off, err = packUint16(hdr.Rdlength, msg, off)
off, err = packUint16(0, msg, off) // The RDLENGTH field will be set later in packRR.
if err != nil {
return len(msg), err
}
......@@ -177,14 +179,14 @@ func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) {
if off+1 > len(msg) {
return 0, len(msg), &Error{err: "overflow unpacking uint8"}
}
return uint8(msg[off]), off + 1, nil
return msg[off], off + 1, nil
}
func packUint8(i uint8, msg []byte, off int) (off1 int, err error) {
if off+1 > len(msg) {
return len(msg), &Error{err: "overflow packing uint8"}
}
msg[off] = byte(i)
msg[off] = i
return off + 1, nil
}
......@@ -223,8 +225,8 @@ func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) {
return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"}
}
// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
i = uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5]))
i = uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5])
off += 6
return i, off, nil
}
......@@ -264,32 +266,36 @@ func unpackString(msg []byte, off int) (string, int, error) {
return "", off, &Error{err: "overflow unpacking txt"}
}
l := int(msg[off])
if off+l+1 > len(msg) {
off++
if off+l > len(msg) {
return "", off, &Error{err: "overflow unpacking txt"}
}
s := make([]byte, 0, l)
for _, b := range msg[off+1 : off+1+l] {
switch b {
case '"', '\\':
s = append(s, '\\', b)
default:
if b < 32 || b > 127 { // unprintable
var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
} else {
s = append(s, b)
var s strings.Builder
consumed := 0
for i, b := range msg[off : off+l] {
switch {
case b == '"' || b == '\\':
if consumed == 0 {
s.Grow(l * 2)
}
s.Write(msg[off+consumed : off+i])
s.WriteByte('\\')
s.WriteByte(b)
consumed = i + 1
case b < ' ' || b > '~': // unprintable
if consumed == 0 {
s.Grow(l * 2)
}
s.Write(msg[off+consumed : off+i])
s.WriteString(escapeByte(b))
consumed = i + 1
}
}
off += 1 + l
return string(s), off, nil
if consumed == 0 { // no escaping needed
return string(msg[off : off+l]), off + l, nil
}
s.Write(msg[off+consumed : off+l])
return s.String(), off + l, nil
}
func packString(s string, msg []byte, off int) (int, error) {
......@@ -371,6 +377,22 @@ func packStringHex(s string, msg []byte, off int) (int, error) {
return off, nil
}
func unpackStringAny(msg []byte, off, end int) (string, int, error) {
if end > len(msg) {
return "", len(msg), &Error{err: "overflow unpacking anything"}
}
return string(msg[off:end]), end, nil
}
func packStringAny(s string, msg []byte, off int) (int, error) {
if off+len(s) > len(msg) {
return len(msg), &Error{err: "overflow packing anything"}
}
copy(msg[off:off+len(s)], s)
off += len(s)
return off, nil
}
func unpackStringTxt(msg []byte, off int) ([]string, int, error) {
txt, off, err := unpackTxt(msg, off)
if err != nil {
......@@ -391,7 +413,7 @@ func packStringTxt(s []string, msg []byte, off int) (int, error) {
func unpackDataOpt(msg []byte, off int) ([]EDNS0, int, error) {
var edns []EDNS0
Option:
code := uint16(0)
var code uint16
if off+4 > len(msg) {
return nil, len(msg), &Error{err: "overflow unpacking opt"}
}
......@@ -402,79 +424,12 @@ Option:
if off+int(optlen) > len(msg) {
return nil, len(msg), &Error{err: "overflow unpacking opt"}
}
switch code {
case EDNS0NSID:
e := new(EDNS0_NSID)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
case EDNS0SUBNET:
e := new(EDNS0_SUBNET)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
case EDNS0COOKIE:
e := new(EDNS0_COOKIE)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
case EDNS0UL:
e := new(EDNS0_UL)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
case EDNS0LLQ:
e := new(EDNS0_LLQ)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
case EDNS0DAU:
e := new(EDNS0_DAU)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
case EDNS0DHU:
e := new(EDNS0_DHU)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
case EDNS0N3U:
e := new(EDNS0_N3U)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
case EDNS0PADDING:
e := new(EDNS0_PADDING)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
default:
e := new(EDNS0_LOCAL)
e.Code = code
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
e := makeDataOpt(code)
if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
return nil, len(msg), err
}
edns = append(edns, e)
off += int(optlen)
if off < len(msg) {
goto Option
......@@ -486,16 +441,14 @@ Option:
func packDataOpt(options []EDNS0, msg []byte, off int) (int, error) {
for _, el := range options {
b, err := el.pack()
if err != nil || off+3 > len(msg) {
if err != nil || off+4 > len(msg) {
return len(msg), &Error{err: "overflow packing opt"}
}
binary.BigEndian.PutUint16(msg[off:], el.Option()) // Option code
binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
off += 4
if off+len(b) > len(msg) {
copy(msg[off:], b)
off = len(msg)
continue
return len(msg), &Error{err: "overflow packing opt"}
}
// Actual data
copy(msg[off:off+len(b)], b)
......@@ -523,7 +476,7 @@ func unpackDataNsec(msg []byte, off int) ([]uint16, int, error) {
length, window, lastwindow := 0, 0, -1
for off < len(msg) {
if off+2 > len(msg) {
return nsec, len(msg), &Error{err: "overflow unpacking nsecx"}
return nsec, len(msg), &Error{err: "overflow unpacking NSEC(3)"}
}
window = int(msg[off])
length = int(msg[off+1])
......@@ -531,22 +484,21 @@ func unpackDataNsec(msg []byte, off int) ([]uint16, int, error) {
if window <= lastwindow {
// RFC 4034: Blocks are present in the NSEC RR RDATA in
// increasing numerical order.
return nsec, len(msg), &Error{err: "out of order NSEC block"}
return nsec, len(msg), &Error{err: "out of order NSEC(3) block in type bitmap"}
}
if length == 0 {
// RFC 4034: Blocks with no types present MUST NOT be included.
return nsec, len(msg), &Error{err: "empty NSEC block"}
return nsec, len(msg), &Error{err: "empty NSEC(3) block in type bitmap"}
}
if length > 32 {
return nsec, len(msg), &Error{err: "NSEC block too long"}
return nsec, len(msg), &Error{err: "NSEC(3) block too long in type bitmap"}
}
if off+length > len(msg) {
return nsec, len(msg), &Error{err: "overflowing NSEC block"}
return nsec, len(msg), &Error{err: "overflowing NSEC(3) block in type bitmap"}
}
// Walk the bytes in the window and extract the type bits
for j := 0; j < length; j++ {
b := msg[off+j]
for j, b := range msg[off : off+length] {
// Check the bits one by one, and set the type
if b&0x80 == 0x80 {
nsec = append(nsec, uint16(window*256+j*8+0))
......@@ -579,13 +531,45 @@ func unpackDataNsec(msg []byte, off int) ([]uint16, int, error) {
return nsec, off, nil
}
// typeBitMapLen is a helper function which computes the "maximum" length of
// a the NSEC Type BitMap field.
func typeBitMapLen(bitmap []uint16) int {
var l int
var lastwindow, lastlength uint16
for _, t := range bitmap {
window := t / 256
length := (t-window*256)/8 + 1
if window > lastwindow && lastlength != 0 { // New window, jump to the new offset
l += int(lastlength) + 2
lastlength = 0
}
if window < lastwindow || length < lastlength {
// packDataNsec would return Error{err: "nsec bits out of order"} here, but
// when computing the length, we want do be liberal.
continue
}
lastwindow, lastlength = window, length
}
l += int(lastlength) + 2
return l
}
func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
if len(bitmap) == 0 {
return off, nil
}
if off > len(msg) {
return off, &Error{err: "overflow packing nsec"}
}
toZero := msg[off:]
if maxLen := typeBitMapLen(bitmap); maxLen < len(toZero) {
toZero = toZero[:maxLen]
}
for i := range toZero {
toZero[i] = 0
}
var lastwindow, lastlength uint16
for j := 0; j < len(bitmap); j++ {
t := bitmap[j]
for _, t := range bitmap {
window := t / 256
length := (t-window*256)/8 + 1
if window > lastwindow && lastlength != 0 { // New window, jump to the new offset
......@@ -610,6 +594,65 @@ func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
return off, nil
}
func unpackDataSVCB(msg []byte, off int) ([]SVCBKeyValue, int, error) {
var xs []SVCBKeyValue
var code uint16
var length uint16
var err error
for off < len(msg) {
code, off, err = unpackUint16(msg, off)
if err != nil {
return nil, len(msg), &Error{err: "overflow unpacking SVCB"}
}
length, off, err = unpackUint16(msg, off)
if err != nil || off+int(length) > len(msg) {
return nil, len(msg), &Error{err: "overflow unpacking SVCB"}
}
e := makeSVCBKeyValue(SVCBKey(code))
if e == nil {
return nil, len(msg), &Error{err: "bad SVCB key"}
}
if err := e.unpack(msg[off : off+int(length)]); err != nil {
return nil, len(msg), err
}
if len(xs) > 0 && e.Key() <= xs[len(xs)-1].Key() {
return nil, len(msg), &Error{err: "SVCB keys not in strictly increasing order"}
}
xs = append(xs, e)
off += int(length)
}
return xs, off, nil
}
func packDataSVCB(pairs []SVCBKeyValue, msg []byte, off int) (int, error) {
pairs = append([]SVCBKeyValue(nil), pairs...)
sort.Slice(pairs, func(i, j int) bool {
return pairs[i].Key() < pairs[j].Key()
})
prev := svcb_RESERVED
for _, el := range pairs {
if el.Key() == prev {
return len(msg), &Error{err: "repeated SVCB keys are not allowed"}
}
prev = el.Key()
packed, err := el.pack()
if err != nil {
return len(msg), err
}
off, err = packUint16(uint16(el.Key()), msg, off)
if err != nil {
return len(msg), &Error{err: "overflow packing SVCB"}
}
off, err = packUint16(uint16(len(packed)), msg, off)
if err != nil || off+len(packed) > len(msg) {
return len(msg), &Error{err: "overflow packing SVCB"}
}
copy(msg[off:off+len(packed)], packed)
off += len(packed)
}
return off, nil
}
func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
var (
servers []string
......@@ -629,13 +672,141 @@ func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
return servers, off, nil
}
func packDataDomainNames(names []string, msg []byte, off int, compression map[string]int, compress bool) (int, error) {
func packDataDomainNames(names []string, msg []byte, off int, compression compressionMap, compress bool) (int, error) {
var err error
for _, name := range names {
off, err = packDomainName(name, msg, off, compression, compress)
if err != nil {
return len(msg), err
}
}
return off, nil
}
func packDataApl(data []APLPrefix, msg []byte, off int) (int, error) {
var err error
for j := 0; j < len(names); j++ {
off, err = PackDomainName(names[j], msg, off, compression, false && compress)
for i := range data {
off, err = packDataAplPrefix(&data[i], msg, off)
if err != nil {
return len(msg), err
}
}
return off, nil
}
func packDataAplPrefix(p *APLPrefix, msg []byte, off int) (int, error) {
if len(p.Network.IP) != len(p.Network.Mask) {
return len(msg), &Error{err: "address and mask lengths don't match"}
}
var err error
prefix, _ := p.Network.Mask.Size()
addr := p.Network.IP.Mask(p.Network.Mask)[:(prefix+7)/8]
switch len(p.Network.IP) {
case net.IPv4len:
off, err = packUint16(1, msg, off)
case net.IPv6len:
off, err = packUint16(2, msg, off)
default:
err = &Error{err: "unrecognized address family"}
}
if err != nil {
return len(msg), err
}
off, err = packUint8(uint8(prefix), msg, off)
if err != nil {
return len(msg), err
}
var n uint8
if p.Negation {
n = 0x80
}
// trim trailing zero bytes as specified in RFC3123 Sections 4.1 and 4.2.
i := len(addr) - 1
for ; i >= 0 && addr[i] == 0; i-- {
}
addr = addr[:i+1]
adflen := uint8(len(addr)) & 0x7f
off, err = packUint8(n|adflen, msg, off)
if err != nil {
return len(msg), err
}
if off+len(addr) > len(msg) {
return len(msg), &Error{err: "overflow packing APL prefix"}
}
off += copy(msg[off:], addr)
return off, nil
}
func unpackDataApl(msg []byte, off int) ([]APLPrefix, int, error) {
var result []APLPrefix
for off < len(msg) {
prefix, end, err := unpackDataAplPrefix(msg, off)
if err != nil {
return nil, len(msg), err
}
off = end
result = append(result, prefix)
}
return result, off, nil
}
func unpackDataAplPrefix(msg []byte, off int) (APLPrefix, int, error) {
family, off, err := unpackUint16(msg, off)
if err != nil {
return APLPrefix{}, len(msg), &Error{err: "overflow unpacking APL prefix"}
}
prefix, off, err := unpackUint8(msg, off)
if err != nil {
return APLPrefix{}, len(msg), &Error{err: "overflow unpacking APL prefix"}
}
nlen, off, err := unpackUint8(msg, off)
if err != nil {
return APLPrefix{}, len(msg), &Error{err: "overflow unpacking APL prefix"}
}
var ip []byte
switch family {
case 1:
ip = make([]byte, net.IPv4len)
case 2:
ip = make([]byte, net.IPv6len)
default:
return APLPrefix{}, len(msg), &Error{err: "unrecognized APL address family"}
}
if int(prefix) > 8*len(ip) {
return APLPrefix{}, len(msg), &Error{err: "APL prefix too long"}
}
afdlen := int(nlen & 0x7f)
if afdlen > len(ip) {
return APLPrefix{}, len(msg), &Error{err: "APL length too long"}
}
if off+afdlen > len(msg) {
return APLPrefix{}, len(msg), &Error{err: "overflow unpacking APL address"}
}
// Address MUST NOT contain trailing zero bytes per RFC3123 Sections 4.1 and 4.2.
off += copy(ip, msg[off:off+afdlen])
if afdlen > 0 {
last := ip[afdlen-1]
if last == 0 {
return APLPrefix{}, len(msg), &Error{err: "extra APL address bits"}
}
}
ipnet := net.IPNet{
IP: ip,
Mask: net.CIDRMask(int(prefix), 8*len(ip)),
}
return APLPrefix{
Negation: (nlen & 0x80) != 0,
Network: ipnet,
}, off, nil
}
package dns
// Truncate ensures the reply message will fit into the requested buffer
// size by removing records that exceed the requested size.
//
// It will first check if the reply fits without compression and then with
// compression. If it won't fit with compression, Truncate then walks the
// record adding as many records as possible without exceeding the
// requested buffer size.
//
// If the message fits within the requested size without compression,
// Truncate will set the message's Compress attribute to false. It is
// the caller's responsibility to set it back to true if they wish to
// compress the payload regardless of size.
//
// The TC bit will be set if any records were excluded from the message.
// If the TC bit is already set on the message it will be retained.
// TC indicates that the client should retry over TCP.
//
// According to RFC 2181, the TC bit should only be set if not all of the
// "required" RRs can be included in the response. Unfortunately, we have
// no way of knowing which RRs are required so we set the TC bit if any RR
// had to be omitted from the response.
//
// The appropriate buffer size can be retrieved from the requests OPT
// record, if present, and is transport specific otherwise. dns.MinMsgSize
// should be used for UDP requests without an OPT record, and
// dns.MaxMsgSize for TCP requests without an OPT record.
func (dns *Msg) Truncate(size int) {
if dns.IsTsig() != nil {
// To simplify this implementation, we don't perform
// truncation on responses with a TSIG record.
return
}
// RFC 6891 mandates that the payload size in an OPT record
// less than 512 (MinMsgSize) bytes must be treated as equal to 512 bytes.
//
// For ease of use, we impose that restriction here.
if size < MinMsgSize {
size = MinMsgSize
}
l := msgLenWithCompressionMap(dns, nil) // uncompressed length
if l <= size {
// Don't waste effort compressing this message.
dns.Compress = false
return
}
dns.Compress = true
edns0 := dns.popEdns0()
if edns0 != nil {
// Account for the OPT record that gets added at the end,
// by subtracting that length from our budget.
//
// The EDNS(0) OPT record must have the root domain and
// it's length is thus unaffected by compression.
size -= Len(edns0)
}
compression := make(map[string]struct{})
l = headerSize
for _, r := range dns.Question {
l += r.len(l, compression)
}
var numAnswer int
if l < size {
l, numAnswer = truncateLoop(dns.Answer, size, l, compression)
}
var numNS int
if l < size {
l, numNS = truncateLoop(dns.Ns, size, l, compression)
}
var numExtra int
if l < size {
_, numExtra = truncateLoop(dns.Extra, size, l, compression)
}
// See the function documentation for when we set this.
dns.Truncated = dns.Truncated || len(dns.Answer) > numAnswer ||
len(dns.Ns) > numNS || len(dns.Extra) > numExtra
dns.Answer = dns.Answer[:numAnswer]
dns.Ns = dns.Ns[:numNS]
dns.Extra = dns.Extra[:numExtra]
if edns0 != nil {
// Add the OPT record back onto the additional section.
dns.Extra = append(dns.Extra, edns0)
}
}
func truncateLoop(rrs []RR, size, l int, compression map[string]struct{}) (int, int) {
for i, r := range rrs {
if r == nil {
continue
}
l += r.len(l, compression)
if l > size {
// Return size, rather than l prior to this record,
// to prevent any further records being added.
return size, i
}
if l == size {
return l, i + 1
}
}
return l, len(rrs)
}
......@@ -2,53 +2,48 @@ package dns
import (
"crypto/sha1"
"hash"
"encoding/hex"
"strings"
)
type saltWireFmt struct {
Salt string `dns:"size-hex"`
}
// HashName hashes a string (label) according to RFC 5155. It returns the hashed string in uppercase.
func HashName(label string, ha uint8, iter uint16, salt string) string {
saltwire := new(saltWireFmt)
saltwire.Salt = salt
wire := make([]byte, DefaultMsgSize)
n, err := packSaltWire(saltwire, wire)
if ha != SHA1 {
return ""
}
wireSalt := make([]byte, hex.DecodedLen(len(salt)))
n, err := packStringHex(salt, wireSalt, 0)
if err != nil {
return ""
}
wire = wire[:n]
wireSalt = wireSalt[:n]
name := make([]byte, 255)
off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false)
if err != nil {
return ""
}
name = name[:off]
var s hash.Hash
switch ha {
case SHA1:
s = sha1.New()
default:
return ""
}
s := sha1.New()
// k = 0
s.Write(name)
s.Write(wire)
s.Write(wireSalt)
nsec3 := s.Sum(nil)
// k > 0
for k := uint16(0); k < iter; k++ {
s.Reset()
s.Write(nsec3)
s.Write(wire)
s.Write(wireSalt)
nsec3 = s.Sum(nsec3[:0])
}
return toBase32(nsec3)
}
// Cover returns true if a name is covered by the NSEC3 record
// Cover returns true if a name is covered by the NSEC3 record.
func (rr *NSEC3) Cover(name string) bool {
nameHash := HashName(name, rr.Hash, rr.Iterations, rr.Salt)
owner := strings.ToUpper(rr.Hdr.Name)
......@@ -63,8 +58,10 @@ func (rr *NSEC3) Cover(name string) bool {
}
nextHash := rr.NextDomain
if ownerHash == nextHash { // empty interval
return false
// if empty interval found, try cover wildcard hashes so nameHash shouldn't match with ownerHash
if ownerHash == nextHash && nameHash != ownerHash { // empty interval
return true
}
if ownerHash > nextHash { // end of zone
if nameHash > ownerHash { // covered since there is nothing after ownerHash
......@@ -96,11 +93,3 @@ func (rr *NSEC3) Match(name string) bool {
}
return false
}
func packSaltWire(sw *saltWireFmt, msg []byte) (int, error) {
off, err := packStringHex(sw.Salt, msg, 0)
if err != nil {
return off, err
}
return off, nil
}
package dns
import (
"fmt"
"strings"
)
import "strings"
// PrivateRdata is an interface used for implementing "Private Use" RR types, see
// RFC 6895. This allows one to experiment with new RR types, without requesting an
// official type code. Also see dns.PrivateHandle and dns.PrivateHandleRemove.
type PrivateRdata interface {
// String returns the text presentaton of the Rdata of the Private RR.
// String returns the text presentation of the Rdata of the Private RR.
String() string
// Parse parses the Rdata of the private RR.
Parse([]string) error
// Pack is used when packing a private RR into a buffer.
Pack([]byte) (int, error)
// Unpack is used when unpacking a private RR from a buffer.
// TODO(miek): diff. signature than Pack, see edns0.go for instance.
Unpack([]byte) (int, error)
// Copy copies the Rdata.
// Copy copies the Rdata into the PrivateRdata argument.
Copy(PrivateRdata) error
// Len returns the length in octets of the Rdata.
Len() int
......@@ -29,21 +25,8 @@ type PrivateRdata interface {
type PrivateRR struct {
Hdr RR_Header
Data PrivateRdata
}
func mkPrivateRR(rrtype uint16) *PrivateRR {
// Panics if RR is not an instance of PrivateRR.
rrfunc, ok := TypeToRR[rrtype]
if !ok {
panic(fmt.Sprintf("dns: invalid operation with Private RR type %d", rrtype))
}
anyrr := rrfunc()
switch rr := anyrr.(type) {
case *PrivateRR:
return rr
}
panic(fmt.Sprintf("dns: RR is not a PrivateRR, TypeToRR[%d] generator returned %T", rrtype, anyrr))
generator func() PrivateRdata // for copy
}
// Header return the RR header of r.
......@@ -52,97 +35,79 @@ func (r *PrivateRR) Header() *RR_Header { return &r.Hdr }
func (r *PrivateRR) String() string { return r.Hdr.String() + r.Data.String() }
// Private len and copy parts to satisfy RR interface.
func (r *PrivateRR) len() int { return r.Hdr.len() + r.Data.Len() }
func (r *PrivateRR) len(off int, compression map[string]struct{}) int {
l := r.Hdr.len(off, compression)
l += r.Data.Len()
return l
}
func (r *PrivateRR) copy() RR {
// make new RR like this:
rr := mkPrivateRR(r.Hdr.Rrtype)
rr.Hdr = r.Hdr
rr := &PrivateRR{r.Hdr, r.generator(), r.generator}
err := r.Data.Copy(rr.Data)
if err != nil {
panic("dns: got value that could not be used to copy Private rdata")
if err := r.Data.Copy(rr.Data); err != nil {
panic("dns: got value that could not be used to copy Private rdata: " + err.Error())
}
return rr
}
func (r *PrivateRR) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
off, err := r.Hdr.pack(msg, off, compression, compress)
if err != nil {
return off, err
}
headerEnd := off
func (r *PrivateRR) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
n, err := r.Data.Pack(msg[off:])
if err != nil {
return len(msg), err
}
off += n
r.Header().Rdlength = uint16(off - headerEnd)
return off, nil
}
// PrivateHandle registers a private resource record type. It requires
// string and numeric representation of private RR type and generator function as argument.
func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata) {
rtypestr = strings.ToUpper(rtypestr)
TypeToRR[rtype] = func() RR { return &PrivateRR{RR_Header{}, generator()} }
TypeToString[rtype] = rtypestr
StringToType[rtypestr] = rtype
func (r *PrivateRR) unpack(msg []byte, off int) (int, error) {
off1, err := r.Data.Unpack(msg[off:])
off += off1
return off, err
}
typeToUnpack[rtype] = func(h RR_Header, msg []byte, off int) (RR, int, error) {
if noRdata(h) {
return &h, off, nil
func (r *PrivateRR) parse(c *zlexer, origin string) *ParseError {
var l lex
text := make([]string, 0, 2) // could be 0..N elements, median is probably 1
Fetch:
for {
// TODO(miek): we could also be returning _QUOTE, this might or might not
// be an issue (basically parsing TXT becomes hard)
switch l, _ = c.Next(); l.value {
case zNewline, zEOF:
break Fetch
case zString:
text = append(text, l.token)
}
var err error
rr := mkPrivateRR(h.Rrtype)
rr.Hdr = h
}
off1, err := rr.Data.Unpack(msg[off:])
off += off1
if err != nil {
return rr, off, err
}
return rr, off, err
err := r.Data.Parse(text)
if err != nil {
return &ParseError{"", err.Error(), l}
}
setPrivateRR := func(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := mkPrivateRR(h.Rrtype)
rr.Hdr = h
var l lex
text := make([]string, 0, 2) // could be 0..N elements, median is probably 1
Fetch:
for {
// TODO(miek): we could also be returning _QUOTE, this might or might not
// be an issue (basically parsing TXT becomes hard)
switch l = <-c; l.value {
case zNewline, zEOF:
break Fetch
case zString:
text = append(text, l.token)
}
}
return nil
}
err := rr.Data.Parse(text)
if err != nil {
return nil, &ParseError{f, err.Error(), l}, ""
}
func (r *PrivateRR) isDuplicate(r2 RR) bool { return false }
return rr, nil, ""
}
// PrivateHandle registers a private resource record type. It requires
// string and numeric representation of private RR type and generator function as argument.
func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata) {
rtypestr = strings.ToUpper(rtypestr)
typeToparserFunc[rtype] = parserFunc{setPrivateRR, true}
TypeToRR[rtype] = func() RR { return &PrivateRR{RR_Header{}, generator(), generator} }
TypeToString[rtype] = rtypestr
StringToType[rtypestr] = rtype
}
// PrivateHandleRemove removes defenitions required to support private RR type.
// PrivateHandleRemove removes definitions required to support private RR type.
func PrivateHandleRemove(rtype uint16) {
rtypestr, ok := TypeToString[rtype]
if ok {
delete(TypeToRR, rtype)
delete(TypeToString, rtype)
delete(typeToparserFunc, rtype)
delete(StringToType, rtypestr)
delete(typeToUnpack, rtype)
}
return
}
package dns
import "encoding/binary"
// rawSetRdlength sets the rdlength in the header of
// the RR. The offset 'off' must be positioned at the
// start of the header of the RR, 'end' must be the
// end of the RR.
func rawSetRdlength(msg []byte, off, end int) bool {
l := len(msg)
Loop:
for {
if off+1 > l {
return false
}
c := int(msg[off])
off++
switch c & 0xC0 {
case 0x00:
if c == 0x00 {
// End of the domainname
break Loop
}
if off+c > l {
return false
}
off += c
case 0xC0:
// pointer, next byte included, ends domainname
off++
break Loop
}
}
// The domainname has been seen, we at the start of the fixed part in the header.
// Type is 2 bytes, class is 2 bytes, ttl 4 and then 2 bytes for the length.
off += 2 + 2 + 4
if off+2 > l {
return false
}
//off+1 is the end of the header, 'end' is the end of the rr
//so 'end' - 'off+2' is the length of the rdata
rdatalen := end - (off + 2)
if rdatalen > 0xFFFF {
return false
}
binary.BigEndian.PutUint16(msg[off:], uint16(rdatalen))
return true
}
......@@ -12,6 +12,20 @@ var StringToOpcode = reverseInt(OpcodeToString)
// StringToRcode is a map of rcodes to strings.
var StringToRcode = reverseInt(RcodeToString)
func init() {
// Preserve previous NOTIMP typo, see github.com/miekg/dns/issues/733.
StringToRcode["NOTIMPL"] = RcodeNotImplemented
}
// StringToAlgorithm is the reverse of AlgorithmToString.
var StringToAlgorithm = reverseInt8(AlgorithmToString)
// StringToHash is a map of names to hash IDs.
var StringToHash = reverseInt8(HashToString)
// StringToCertType is the reverseof CertTypeToString.
var StringToCertType = reverseInt16(CertTypeToString)
// Reverse a map
func reverseInt8(m map[uint8]string) map[string]uint8 {
n := make(map[string]uint8, len(m))
......
......@@ -15,10 +15,11 @@ func Dedup(rrs []RR, m map[string]RR) []RR {
for _, r := range rrs {
key := normalizedString(r)
keys = append(keys, &key)
if _, ok := m[key]; ok {
if mr, ok := m[key]; ok {
// Shortest TTL wins.
if m[key].Header().Ttl > r.Header().Ttl {
m[key].Header().Ttl = r.Header().Ttl
rh, mrh := r.Header(), mr.Header()
if mrh.Ttl > rh.Ttl {
mrh.Ttl = rh.Ttl
}
continue
}
......
package dns
import (
"bufio"
"fmt"
"io"
"os"
......@@ -10,7 +11,10 @@ import (
)
const maxTok = 2048 // Largest token we can return.
const maxUint16 = 1<<16 - 1
// The maximum depth of $INCLUDE directives supported by the
// ZoneParser API.
const maxIncludeDepth = 7
// Tokinize a RFC 1035 zone file. The tokenizer will normalize it:
// * Add ownernames if they are left blank;
......@@ -75,25 +79,12 @@ func (e *ParseError) Error() (s string) {
}
type lex struct {
token string // text of the token
tokenUpper string // uppercase text of the token
length int // length of the token
err bool // when true, token text has lexer error
value uint8 // value: zString, _BLANK, etc.
line int // line in the file
column int // column in the file
torc uint16 // type or class as parsed in the lexer, we only need to look this up in the grammar
comment string // any comment text seen
}
// Token holds the token that are returned when a zone file is parsed.
type Token struct {
// The scanned resource record when error is not nil.
RR
// When an error occurred, this has the error specifics.
Error *ParseError
// A potential comment positioned after the RR and on the same line.
Comment string
token string // text of the token
err bool // when true, token text has lexer error
value uint8 // value: zString, _BLANK, etc.
torc uint16 // type or class as parsed in the lexer, we only need to look this up in the grammar
line int // line in the file
column int // column in the file
}
// ttlState describes the state necessary to fill in an omitted RR TTL
......@@ -102,11 +93,12 @@ type ttlState struct {
isByDirective bool // isByDirective indicates whether ttl was set by a $TTL directive
}
// NewRR reads the RR contained in the string s. Only the first RR is
// returned. If s contains no RR, return nil with no error. The class
// defaults to IN and TTL defaults to 3600. The full zone file syntax
// like $TTL, $ORIGIN, etc. is supported. All fields of the returned
// RR are set, except RR.Header().Rdlength which is set to 0.
// NewRR reads the RR contained in the string s. Only the first RR is returned.
// If s contains no records, NewRR will return nil with no error.
//
// The class defaults to IN and TTL defaults to 3600. The full zone file syntax
// like $TTL, $ORIGIN, etc. is supported. All fields of the returned RR are
// set, except RR.Header().Rdlength which is set to 0.
func NewRR(s string) (RR, error) {
if len(s) > 0 && s[len(s)-1] != '\n' { // We need a closing newline
return ReadRR(strings.NewReader(s+"\n"), "")
......@@ -114,125 +106,228 @@ func NewRR(s string) (RR, error) {
return ReadRR(strings.NewReader(s), "")
}
// ReadRR reads the RR contained in q.
// ReadRR reads the RR contained in r.
//
// The string file is used in error reporting and to resolve relative
// $INCLUDE directives.
//
// See NewRR for more documentation.
func ReadRR(q io.Reader, filename string) (RR, error) {
defttl := &ttlState{defaultTtl, false}
r := <-parseZoneHelper(q, ".", filename, defttl, 1)
if r == nil {
return nil, nil
}
if r.Error != nil {
return nil, r.Error
}
return r.RR, nil
func ReadRR(r io.Reader, file string) (RR, error) {
zp := NewZoneParser(r, ".", file)
zp.SetDefaultTTL(defaultTtl)
zp.SetIncludeAllowed(true)
rr, _ := zp.Next()
return rr, zp.Err()
}
// ParseZone reads a RFC 1035 style zonefile from r. It returns *Tokens on the
// returned channel, each consisting of either a parsed RR and optional comment
// or a nil RR and an error. The string file is only used
// in error reporting. The string origin is used as the initial origin, as
// if the file would start with an $ORIGIN directive.
// The directives $INCLUDE, $ORIGIN, $TTL and $GENERATE are supported.
// The channel t is closed by ParseZone when the end of r is reached.
// ZoneParser is a parser for an RFC 1035 style zonefile.
//
// Each parsed RR in the zone is returned sequentially from Next. An
// optional comment can be retrieved with Comment.
//
// The directives $INCLUDE, $ORIGIN, $TTL and $GENERATE are all
// supported. Although $INCLUDE is disabled by default.
// Note that $GENERATE's range support up to a maximum of 65535 steps.
//
// Basic usage pattern when reading from a string (z) containing the
// zone data:
//
// for x := range dns.ParseZone(strings.NewReader(z), "", "") {
// if x.Error != nil {
// // log.Println(x.Error)
// } else {
// // Do something with x.RR
// }
// zp := NewZoneParser(strings.NewReader(z), "", "")
//
// for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
// // Do something with rr
// }
//
// Comments specified after an RR (and on the same line!) are returned too:
// if err := zp.Err(); err != nil {
// // log.Println(err)
// }
//
// Comments specified after an RR (and on the same line!) are
// returned too:
//
// foo. IN A 10.0.0.1 ; this is a comment
//
// The text "; this is comment" is returned in Token.Comment. Comments inside the
// RR are discarded. Comments on a line by themselves are discarded too.
func ParseZone(r io.Reader, origin, file string) chan *Token {
return parseZoneHelper(r, origin, file, nil, 10000)
}
// The text "; this is comment" is returned from Comment. Comments inside
// the RR are returned concatenated along with the RR. Comments on a line
// by themselves are discarded.
//
// Callers should not assume all returned data in an Resource Record is
// syntactically correct, e.g. illegal base64 in RRSIGs will be returned as-is.
type ZoneParser struct {
c *zlexer
func parseZoneHelper(r io.Reader, origin, file string, defttl *ttlState, chansize int) chan *Token {
t := make(chan *Token, chansize)
go parseZone(r, origin, file, defttl, t, 0)
return t
}
parseErr *ParseError
func parseZone(r io.Reader, origin, f string, defttl *ttlState, t chan *Token, include int) {
defer func() {
if include == 0 {
close(t)
}
}()
s, cancel := scanInit(r)
c := make(chan lex)
// Start the lexer
go zlexer(s, c)
defer func() {
cancel()
// zlexer can send up to three tokens, the next one and possibly 2 remainders.
// Do a non-blocking read.
_, ok := <-c
_, ok = <-c
_, ok = <-c
if !ok {
// too bad
}
}()
// 6 possible beginnings of a line, _ is a space
// 0. zRRTYPE -> all omitted until the rrtype
// 1. zOwner _ zRrtype -> class/ttl omitted
// 2. zOwner _ zString _ zRrtype -> class omitted
// 3. zOwner _ zString _ zClass _ zRrtype -> ttl/class
// 4. zOwner _ zClass _ zRrtype -> ttl omitted
// 5. zOwner _ zClass _ zString _ zRrtype -> class/ttl (reversed)
// After detecting these, we know the zRrtype so we can jump to functions
// handling the rdata for each of these types.
origin string
file string
defttl *ttlState
h RR_Header
// sub is used to parse $INCLUDE files and $GENERATE directives.
// Next, by calling subNext, forwards the resulting RRs from this
// sub parser to the calling code.
sub *ZoneParser
osFile *os.File
includeDepth uint8
includeAllowed bool
generateDisallowed bool
}
// NewZoneParser returns an RFC 1035 style zonefile parser that reads
// from r.
//
// The string file is used in error reporting and to resolve relative
// $INCLUDE directives. The string origin is used as the initial
// origin, as if the file would start with an $ORIGIN directive.
func NewZoneParser(r io.Reader, origin, file string) *ZoneParser {
var pe *ParseError
if origin != "" {
origin = Fqdn(origin)
if _, ok := IsDomainName(origin); !ok {
t <- &Token{Error: &ParseError{f, "bad initial origin name", lex{}}}
return
pe = &ParseError{file, "bad initial origin name", lex{}}
}
}
return &ZoneParser{
c: newZLexer(r),
parseErr: pe,
origin: origin,
file: file,
}
}
// SetDefaultTTL sets the parsers default TTL to ttl.
func (zp *ZoneParser) SetDefaultTTL(ttl uint32) {
zp.defttl = &ttlState{ttl, false}
}
// SetIncludeAllowed controls whether $INCLUDE directives are
// allowed. $INCLUDE directives are not supported by default.
//
// The $INCLUDE directive will open and read from a user controlled
// file on the system. Even if the file is not a valid zonefile, the
// contents of the file may be revealed in error messages, such as:
//
// /etc/passwd: dns: not a TTL: "root:x:0:0:root:/root:/bin/bash" at line: 1:31
// /etc/shadow: dns: not a TTL: "root:$6$<redacted>::0:99999:7:::" at line: 1:125
func (zp *ZoneParser) SetIncludeAllowed(v bool) {
zp.includeAllowed = v
}
// Err returns the first non-EOF error that was encountered by the
// ZoneParser.
func (zp *ZoneParser) Err() error {
if zp.parseErr != nil {
return zp.parseErr
}
if zp.sub != nil {
if err := zp.sub.Err(); err != nil {
return err
}
}
return zp.c.Err()
}
func (zp *ZoneParser) setParseError(err string, l lex) (RR, bool) {
zp.parseErr = &ParseError{zp.file, err, l}
return nil, false
}
// Comment returns an optional text comment that occurred alongside
// the RR.
func (zp *ZoneParser) Comment() string {
if zp.parseErr != nil {
return ""
}
if zp.sub != nil {
return zp.sub.Comment()
}
return zp.c.Comment()
}
func (zp *ZoneParser) subNext() (RR, bool) {
if rr, ok := zp.sub.Next(); ok {
return rr, true
}
if zp.sub.osFile != nil {
zp.sub.osFile.Close()
zp.sub.osFile = nil
}
if zp.sub.Err() != nil {
// We have errors to surface.
return nil, false
}
zp.sub = nil
return zp.Next()
}
// Next advances the parser to the next RR in the zonefile and
// returns the (RR, true). It will return (nil, false) when the
// parsing stops, either by reaching the end of the input or an
// error. After Next returns (nil, false), the Err method will return
// any error that occurred during parsing.
func (zp *ZoneParser) Next() (RR, bool) {
if zp.parseErr != nil {
return nil, false
}
if zp.sub != nil {
return zp.subNext()
}
// 6 possible beginnings of a line (_ is a space):
//
// 0. zRRTYPE -> all omitted until the rrtype
// 1. zOwner _ zRrtype -> class/ttl omitted
// 2. zOwner _ zString _ zRrtype -> class omitted
// 3. zOwner _ zString _ zClass _ zRrtype -> ttl/class
// 4. zOwner _ zClass _ zRrtype -> ttl omitted
// 5. zOwner _ zClass _ zString _ zRrtype -> class/ttl (reversed)
//
// After detecting these, we know the zRrtype so we can jump to functions
// handling the rdata for each of these types.
st := zExpectOwnerDir // initial state
var h RR_Header
var prevName string
for l := range c {
// Lexer spotted an error already
if l.err == true {
t <- &Token{Error: &ParseError{f, l.token, l}}
return
h := &zp.h
for l, ok := zp.c.Next(); ok; l, ok = zp.c.Next() {
// zlexer spotted an error already
if l.err {
return zp.setParseError(l.token, l)
}
switch st {
case zExpectOwnerDir:
// We can also expect a directive, like $TTL or $ORIGIN
if defttl != nil {
h.Ttl = defttl.ttl
if zp.defttl != nil {
h.Ttl = zp.defttl.ttl
}
h.Class = ClassINET
switch l.value {
case zNewline:
st = zExpectOwnerDir
case zOwner:
h.Name = l.token
name, ok := toAbsoluteName(l.token, origin)
name, ok := toAbsoluteName(l.token, zp.origin)
if !ok {
t <- &Token{Error: &ParseError{f, "bad owner name", l}}
return
return zp.setParseError("bad owner name", l)
}
h.Name = name
prevName = h.Name
st = zExpectOwnerBl
case zDirTTL:
st = zExpectDirTTLBl
......@@ -243,12 +338,12 @@ func parseZone(r io.Reader, origin, f string, defttl *ttlState, t chan *Token, i
case zDirGenerate:
st = zExpectDirGenerateBl
case zRrtpe:
h.Name = prevName
h.Rrtype = l.torc
st = zExpectRdata
case zClass:
h.Name = prevName
h.Class = l.torc
st = zExpectAnyNoClassBl
case zBlank:
// Discard, can happen when there is nothing on the
......@@ -256,297 +351,477 @@ func parseZone(r io.Reader, origin, f string, defttl *ttlState, t chan *Token, i
case zString:
ttl, ok := stringToTTL(l.token)
if !ok {
t <- &Token{Error: &ParseError{f, "not a TTL", l}}
return
return zp.setParseError("not a TTL", l)
}
h.Ttl = ttl
if defttl == nil || !defttl.isByDirective {
defttl = &ttlState{ttl, false}
if zp.defttl == nil || !zp.defttl.isByDirective {
zp.defttl = &ttlState{ttl, false}
}
st = zExpectAnyNoTTLBl
st = zExpectAnyNoTTLBl
default:
t <- &Token{Error: &ParseError{f, "syntax error at beginning", l}}
return
return zp.setParseError("syntax error at beginning", l)
}
case zExpectDirIncludeBl:
if l.value != zBlank {
t <- &Token{Error: &ParseError{f, "no blank after $INCLUDE-directive", l}}
return
return zp.setParseError("no blank after $INCLUDE-directive", l)
}
st = zExpectDirInclude
case zExpectDirInclude:
if l.value != zString {
t <- &Token{Error: &ParseError{f, "expecting $INCLUDE value, not this...", l}}
return
return zp.setParseError("expecting $INCLUDE value, not this...", l)
}
neworigin := origin // There may be optionally a new origin set after the filename, if not use current one
switch l := <-c; l.value {
neworigin := zp.origin // There may be optionally a new origin set after the filename, if not use current one
switch l, _ := zp.c.Next(); l.value {
case zBlank:
l := <-c
l, _ := zp.c.Next()
if l.value == zString {
name, ok := toAbsoluteName(l.token, origin)
name, ok := toAbsoluteName(l.token, zp.origin)
if !ok {
t <- &Token{Error: &ParseError{f, "bad origin name", l}}
return
return zp.setParseError("bad origin name", l)
}
neworigin = name
}
case zNewline, zEOF:
// Ok
default:
t <- &Token{Error: &ParseError{f, "garbage after $INCLUDE", l}}
return
return zp.setParseError("garbage after $INCLUDE", l)
}
if !zp.includeAllowed {
return zp.setParseError("$INCLUDE directive not allowed", l)
}
if zp.includeDepth >= maxIncludeDepth {
return zp.setParseError("too deeply nested $INCLUDE", l)
}
// Start with the new file
includePath := l.token
if !filepath.IsAbs(includePath) {
includePath = filepath.Join(filepath.Dir(f), includePath)
includePath = filepath.Join(filepath.Dir(zp.file), includePath)
}
r1, e1 := os.Open(includePath)
if e1 != nil {
msg := fmt.Sprintf("failed to open `%s'", l.token)
var as string
if !filepath.IsAbs(l.token) {
msg += fmt.Sprintf(" as `%s'", includePath)
as = fmt.Sprintf(" as `%s'", includePath)
}
t <- &Token{Error: &ParseError{f, msg, l}}
return
}
if include+1 > 7 {
t <- &Token{Error: &ParseError{f, "too deeply nested $INCLUDE", l}}
return
msg := fmt.Sprintf("failed to open `%s'%s: %v", l.token, as, e1)
return zp.setParseError(msg, l)
}
parseZone(r1, neworigin, includePath, defttl, t, include+1)
st = zExpectOwnerDir
zp.sub = NewZoneParser(r1, neworigin, includePath)
zp.sub.defttl, zp.sub.includeDepth, zp.sub.osFile = zp.defttl, zp.includeDepth+1, r1
zp.sub.SetIncludeAllowed(true)
return zp.subNext()
case zExpectDirTTLBl:
if l.value != zBlank {
t <- &Token{Error: &ParseError{f, "no blank after $TTL-directive", l}}
return
return zp.setParseError("no blank after $TTL-directive", l)
}
st = zExpectDirTTL
case zExpectDirTTL:
if l.value != zString {
t <- &Token{Error: &ParseError{f, "expecting $TTL value, not this...", l}}
return
return zp.setParseError("expecting $TTL value, not this...", l)
}
if e, _ := slurpRemainder(c, f); e != nil {
t <- &Token{Error: e}
return
if err := slurpRemainder(zp.c); err != nil {
return zp.setParseError(err.err, err.lex)
}
ttl, ok := stringToTTL(l.token)
if !ok {
t <- &Token{Error: &ParseError{f, "expecting $TTL value, not this...", l}}
return
return zp.setParseError("expecting $TTL value, not this...", l)
}
defttl = &ttlState{ttl, true}
zp.defttl = &ttlState{ttl, true}
st = zExpectOwnerDir
case zExpectDirOriginBl:
if l.value != zBlank {
t <- &Token{Error: &ParseError{f, "no blank after $ORIGIN-directive", l}}
return
return zp.setParseError("no blank after $ORIGIN-directive", l)
}
st = zExpectDirOrigin
case zExpectDirOrigin:
if l.value != zString {
t <- &Token{Error: &ParseError{f, "expecting $ORIGIN value, not this...", l}}
return
return zp.setParseError("expecting $ORIGIN value, not this...", l)
}
if e, _ := slurpRemainder(c, f); e != nil {
t <- &Token{Error: e}
if err := slurpRemainder(zp.c); err != nil {
return zp.setParseError(err.err, err.lex)
}
name, ok := toAbsoluteName(l.token, origin)
name, ok := toAbsoluteName(l.token, zp.origin)
if !ok {
t <- &Token{Error: &ParseError{f, "bad origin name", l}}
return
return zp.setParseError("bad origin name", l)
}
origin = name
zp.origin = name
st = zExpectOwnerDir
case zExpectDirGenerateBl:
if l.value != zBlank {
t <- &Token{Error: &ParseError{f, "no blank after $GENERATE-directive", l}}
return
return zp.setParseError("no blank after $GENERATE-directive", l)
}
st = zExpectDirGenerate
case zExpectDirGenerate:
if l.value != zString {
t <- &Token{Error: &ParseError{f, "expecting $GENERATE value, not this...", l}}
return
if zp.generateDisallowed {
return zp.setParseError("nested $GENERATE directive not allowed", l)
}
if errMsg := generate(l, c, t, origin); errMsg != "" {
t <- &Token{Error: &ParseError{f, errMsg, l}}
return
if l.value != zString {
return zp.setParseError("expecting $GENERATE value, not this...", l)
}
st = zExpectOwnerDir
return zp.generate(l)
case zExpectOwnerBl:
if l.value != zBlank {
t <- &Token{Error: &ParseError{f, "no blank after owner", l}}
return
return zp.setParseError("no blank after owner", l)
}
st = zExpectAny
case zExpectAny:
switch l.value {
case zRrtpe:
if defttl == nil {
t <- &Token{Error: &ParseError{f, "missing TTL with no previous value", l}}
return
if zp.defttl == nil {
return zp.setParseError("missing TTL with no previous value", l)
}
h.Rrtype = l.torc
st = zExpectRdata
case zClass:
h.Class = l.torc
st = zExpectAnyNoClassBl
case zString:
ttl, ok := stringToTTL(l.token)
if !ok {
t <- &Token{Error: &ParseError{f, "not a TTL", l}}
return
return zp.setParseError("not a TTL", l)
}
h.Ttl = ttl
if defttl == nil || !defttl.isByDirective {
defttl = &ttlState{ttl, false}
if zp.defttl == nil || !zp.defttl.isByDirective {
zp.defttl = &ttlState{ttl, false}
}
st = zExpectAnyNoTTLBl
default:
t <- &Token{Error: &ParseError{f, "expecting RR type, TTL or class, not this...", l}}
return
return zp.setParseError("expecting RR type, TTL or class, not this...", l)
}
case zExpectAnyNoClassBl:
if l.value != zBlank {
t <- &Token{Error: &ParseError{f, "no blank before class", l}}
return
return zp.setParseError("no blank before class", l)
}
st = zExpectAnyNoClass
case zExpectAnyNoTTLBl:
if l.value != zBlank {
t <- &Token{Error: &ParseError{f, "no blank before TTL", l}}
return
return zp.setParseError("no blank before TTL", l)
}
st = zExpectAnyNoTTL
case zExpectAnyNoTTL:
switch l.value {
case zClass:
h.Class = l.torc
st = zExpectRrtypeBl
case zRrtpe:
h.Rrtype = l.torc
st = zExpectRdata
default:
t <- &Token{Error: &ParseError{f, "expecting RR type or class, not this...", l}}
return
return zp.setParseError("expecting RR type or class, not this...", l)
}
case zExpectAnyNoClass:
switch l.value {
case zString:
ttl, ok := stringToTTL(l.token)
if !ok {
t <- &Token{Error: &ParseError{f, "not a TTL", l}}
return
return zp.setParseError("not a TTL", l)
}
h.Ttl = ttl
if defttl == nil || !defttl.isByDirective {
defttl = &ttlState{ttl, false}
if zp.defttl == nil || !zp.defttl.isByDirective {
zp.defttl = &ttlState{ttl, false}
}
st = zExpectRrtypeBl
case zRrtpe:
h.Rrtype = l.torc
st = zExpectRdata
default:
t <- &Token{Error: &ParseError{f, "expecting RR type or TTL, not this...", l}}
return
return zp.setParseError("expecting RR type or TTL, not this...", l)
}
case zExpectRrtypeBl:
if l.value != zBlank {
t <- &Token{Error: &ParseError{f, "no blank before RR type", l}}
return
return zp.setParseError("no blank before RR type", l)
}
st = zExpectRrtype
case zExpectRrtype:
if l.value != zRrtpe {
t <- &Token{Error: &ParseError{f, "unknown RR type", l}}
return
return zp.setParseError("unknown RR type", l)
}
h.Rrtype = l.torc
st = zExpectRdata
case zExpectRdata:
r, e, c1 := setRR(h, c, origin, f)
if e != nil {
// If e.lex is nil than we have encounter a unknown RR type
// in that case we substitute our current lex token
if e.lex.token == "" && e.lex.value == 0 {
e.lex = l // Uh, dirty
var (
rr RR
parseAsRFC3597 bool
)
if newFn, ok := TypeToRR[h.Rrtype]; ok {
rr = newFn()
*rr.Header() = *h
// We may be parsing a known RR type using the RFC3597 format.
// If so, we handle that here in a generic way.
//
// This is also true for PrivateRR types which will have the
// RFC3597 parsing done for them and the Unpack method called
// to populate the RR instead of simply deferring to Parse.
if zp.c.Peek().token == "\\#" {
parseAsRFC3597 = true
}
t <- &Token{Error: e}
return
} else {
rr = &RFC3597{Hdr: *h}
}
t <- &Token{RR: r, Comment: c1}
st = zExpectOwnerDir
_, isPrivate := rr.(*PrivateRR)
if !isPrivate && zp.c.Peek().token == "" {
// This is a dynamic update rr.
// TODO(tmthrgd): Previously slurpRemainder was only called
// for certain RR types, which may have been important.
if err := slurpRemainder(zp.c); err != nil {
return zp.setParseError(err.err, err.lex)
}
return rr, true
} else if l.value == zNewline {
return zp.setParseError("unexpected newline", l)
}
parseAsRR := rr
if parseAsRFC3597 {
parseAsRR = &RFC3597{Hdr: *h}
}
if err := parseAsRR.parse(zp.c, zp.origin); err != nil {
// err is a concrete *ParseError without the file field set.
// The setParseError call below will construct a new
// *ParseError with file set to zp.file.
// err.lex may be nil in which case we substitute our current
// lex token.
if err.lex == (lex{}) {
return zp.setParseError(err.err, l)
}
return zp.setParseError(err.err, err.lex)
}
if parseAsRFC3597 {
err := parseAsRR.(*RFC3597).fromRFC3597(rr)
if err != nil {
return zp.setParseError(err.Error(), l)
}
}
return rr, true
}
}
// If we get here, we and the h.Rrtype is still zero, we haven't parsed anything, this
// is not an error, because an empty zone file is still a zone file.
return nil, false
}
type zlexer struct {
br io.ByteReader
readErr error
line int
column int
comBuf string
comment string
l lex
cachedL *lex
brace int
quote bool
space bool
commt bool
rrtype bool
owner bool
nextL bool
eol bool // end-of-line
}
func newZLexer(r io.Reader) *zlexer {
br, ok := r.(io.ByteReader)
if !ok {
br = bufio.NewReaderSize(r, 1024)
}
return &zlexer{
br: br,
line: 1,
owner: true,
}
}
// zlexer scans the sourcefile and returns tokens on the channel c.
func zlexer(s *scan, c chan lex) {
var l lex
str := make([]byte, maxTok) // Should be enough for any token
stri := 0 // Offset in str (0 means empty)
com := make([]byte, maxTok) // Hold comment text
comi := 0
quote := false
escape := false
space := false
commt := false
rrtype := false
owner := true
brace := 0
x, err := s.tokenText()
defer close(c)
for err == nil {
l.column = s.position.Column
l.line = s.position.Line
if stri >= maxTok {
func (zl *zlexer) Err() error {
if zl.readErr == io.EOF {
return nil
}
return zl.readErr
}
// readByte returns the next byte from the input
func (zl *zlexer) readByte() (byte, bool) {
if zl.readErr != nil {
return 0, false
}
c, err := zl.br.ReadByte()
if err != nil {
zl.readErr = err
return 0, false
}
// delay the newline handling until the next token is delivered,
// fixes off-by-one errors when reporting a parse error.
if zl.eol {
zl.line++
zl.column = 0
zl.eol = false
}
if c == '\n' {
zl.eol = true
} else {
zl.column++
}
return c, true
}
func (zl *zlexer) Peek() lex {
if zl.nextL {
return zl.l
}
l, ok := zl.Next()
if !ok {
return l
}
if zl.nextL {
// Cache l. Next returns zl.cachedL then zl.l.
zl.cachedL = &l
} else {
// In this case l == zl.l, so we just tell Next to return zl.l.
zl.nextL = true
}
return l
}
func (zl *zlexer) Next() (lex, bool) {
l := &zl.l
switch {
case zl.cachedL != nil:
l, zl.cachedL = zl.cachedL, nil
return *l, true
case zl.nextL:
zl.nextL = false
return *l, true
case l.err:
// Parsing errors should be sticky.
return lex{value: zEOF}, false
}
var (
str [maxTok]byte // Hold string text
com [maxTok]byte // Hold comment text
stri int // Offset in str (0 means empty)
comi int // Offset in com (0 means empty)
escape bool
)
if zl.comBuf != "" {
comi = copy(com[:], zl.comBuf)
zl.comBuf = ""
}
zl.comment = ""
for x, ok := zl.readByte(); ok; x, ok = zl.readByte() {
l.line, l.column = zl.line, zl.column
if stri >= len(str) {
l.token = "token length insufficient for parsing"
l.err = true
c <- l
return
return *l, true
}
if comi >= maxTok {
if comi >= len(com) {
l.token = "comment length insufficient for parsing"
l.err = true
c <- l
return
return *l, true
}
switch x {
case ' ', '\t':
if escape {
escape = false
str[stri] = x
stri++
break
}
if quote {
// Inside quotes this is legal
if escape || zl.quote {
// Inside quotes or escaped this is legal.
str[stri] = x
stri++
escape = false
break
}
if commt {
if zl.commt {
com[comi] = x
comi++
break
}
var retL lex
if stri == 0 {
// Space directly in the beginning, handled in the grammar
} else if owner {
} else if zl.owner {
// If we have a string and its the first, make it an owner
l.value = zOwner
l.token = string(str[:stri])
l.tokenUpper = strings.ToUpper(l.token)
l.length = stri
// escape $... start with a \ not a $, so this will work
switch l.tokenUpper {
switch strings.ToUpper(l.token) {
case "$TTL":
l.value = zDirTTL
case "$ORIGIN":
......@@ -556,259 +831,328 @@ func zlexer(s *scan, c chan lex) {
case "$GENERATE":
l.value = zDirGenerate
}
c <- l
retL = *l
} else {
l.value = zString
l.token = string(str[:stri])
l.tokenUpper = strings.ToUpper(l.token)
l.length = stri
if !rrtype {
if t, ok := StringToType[l.tokenUpper]; ok {
if !zl.rrtype {
tokenUpper := strings.ToUpper(l.token)
if t, ok := StringToType[tokenUpper]; ok {
l.value = zRrtpe
l.torc = t
rrtype = true
} else {
if strings.HasPrefix(l.tokenUpper, "TYPE") {
t, ok := typeToInt(l.token)
if !ok {
l.token = "unknown RR type"
l.err = true
c <- l
return
}
l.value = zRrtpe
rrtype = true
l.torc = t
zl.rrtype = true
} else if strings.HasPrefix(tokenUpper, "TYPE") {
t, ok := typeToInt(l.token)
if !ok {
l.token = "unknown RR type"
l.err = true
return *l, true
}
l.value = zRrtpe
l.torc = t
zl.rrtype = true
}
if t, ok := StringToClass[l.tokenUpper]; ok {
if t, ok := StringToClass[tokenUpper]; ok {
l.value = zClass
l.torc = t
} else {
if strings.HasPrefix(l.tokenUpper, "CLASS") {
t, ok := classToInt(l.token)
if !ok {
l.token = "unknown class"
l.err = true
c <- l
return
}
l.value = zClass
l.torc = t
} else if strings.HasPrefix(tokenUpper, "CLASS") {
t, ok := classToInt(l.token)
if !ok {
l.token = "unknown class"
l.err = true
return *l, true
}
l.value = zClass
l.torc = t
}
}
c <- l
retL = *l
}
stri = 0
if !space && !commt {
zl.owner = false
if !zl.space {
zl.space = true
l.value = zBlank
l.token = " "
l.length = 1
c <- l
if retL == (lex{}) {
return *l, true
}
zl.nextL = true
}
if retL != (lex{}) {
return retL, true
}
owner = false
space = true
case ';':
if escape {
escape = false
if escape || zl.quote {
// Inside quotes or escaped this is legal.
str[stri] = x
stri++
escape = false
break
}
if quote {
// Inside quotes this is legal
str[stri] = x
stri++
break
zl.commt = true
zl.comBuf = ""
if comi > 1 {
// A newline was previously seen inside a comment that
// was inside braces and we delayed adding it until now.
com[comi] = ' ' // convert newline to space
comi++
if comi >= len(com) {
l.token = "comment length insufficient for parsing"
l.err = true
return *l, true
}
}
com[comi] = ';'
comi++
if stri > 0 {
zl.comBuf = string(com[:comi])
l.value = zString
l.token = string(str[:stri])
l.tokenUpper = strings.ToUpper(l.token)
l.length = stri
c <- l
stri = 0
return *l, true
}
commt = true
com[comi] = ';'
comi++
case '\r':
escape = false
if quote {
if zl.quote {
str[stri] = x
stri++
break
}
// discard if outside of quotes
case '\n':
escape = false
// Escaped newline
if quote {
if zl.quote {
str[stri] = x
stri++
break
}
// inside quotes this is legal
if commt {
if zl.commt {
// Reset a comment
commt = false
rrtype = false
stri = 0
zl.commt = false
zl.rrtype = false
// If not in a brace this ends the comment AND the RR
if brace == 0 {
owner = true
owner = true
if zl.brace == 0 {
zl.owner = true
l.value = zNewline
l.token = "\n"
l.tokenUpper = l.token
l.length = 1
l.comment = string(com[:comi])
c <- l
l.comment = ""
comi = 0
break
zl.comment = string(com[:comi])
return *l, true
}
com[comi] = ' ' // convert newline to space
comi++
zl.comBuf = string(com[:comi])
break
}
if brace == 0 {
if zl.brace == 0 {
// If there is previous text, we should output it here
var retL lex
if stri != 0 {
l.value = zString
l.token = string(str[:stri])
l.tokenUpper = strings.ToUpper(l.token)
l.length = stri
if !rrtype {
if t, ok := StringToType[l.tokenUpper]; ok {
if !zl.rrtype {
tokenUpper := strings.ToUpper(l.token)
if t, ok := StringToType[tokenUpper]; ok {
zl.rrtype = true
l.value = zRrtpe
l.torc = t
rrtype = true
}
}
c <- l
retL = *l
}
l.value = zNewline
l.token = "\n"
l.tokenUpper = l.token
l.length = 1
c <- l
stri = 0
commt = false
rrtype = false
owner = true
comi = 0
zl.comment = zl.comBuf
zl.comBuf = ""
zl.rrtype = false
zl.owner = true
if retL != (lex{}) {
zl.nextL = true
return retL, true
}
return *l, true
}
case '\\':
// comments do not get escaped chars, everything is copied
if commt {
if zl.commt {
com[comi] = x
comi++
break
}
// something already escaped must be in string
if escape {
str[stri] = x
stri++
escape = false
break
}
// something escaped outside of string gets added to string
str[stri] = x
stri++
escape = true
case '"':
if commt {
if zl.commt {
com[comi] = x
comi++
break
}
if escape {
str[stri] = x
stri++
escape = false
break
}
space = false
zl.space = false
// send previous gathered text and the quote
var retL lex
if stri != 0 {
l.value = zString
l.token = string(str[:stri])
l.tokenUpper = strings.ToUpper(l.token)
l.length = stri
c <- l
stri = 0
retL = *l
}
// send quote itself as separate token
l.value = zQuote
l.token = "\""
l.tokenUpper = l.token
l.length = 1
c <- l
quote = !quote
zl.quote = !zl.quote
if retL != (lex{}) {
zl.nextL = true
return retL, true
}
return *l, true
case '(', ')':
if commt {
if zl.commt {
com[comi] = x
comi++
break
}
if escape {
if escape || zl.quote {
// Inside quotes or escaped this is legal.
str[stri] = x
stri++
escape = false
break
}
if quote {
str[stri] = x
stri++
break
}
switch x {
case ')':
brace--
if brace < 0 {
zl.brace--
if zl.brace < 0 {
l.token = "extra closing brace"
l.tokenUpper = l.token
l.err = true
c <- l
return
return *l, true
}
case '(':
brace++
zl.brace++
}
default:
escape = false
if commt {
if zl.commt {
com[comi] = x
comi++
break
}
str[stri] = x
stri++
space = false
zl.space = false
}
x, err = s.tokenText()
}
if zl.readErr != nil && zl.readErr != io.EOF {
// Don't return any tokens after a read error occurs.
return lex{value: zEOF}, false
}
var retL lex
if stri > 0 {
// Send remainder
l.token = string(str[:stri])
l.tokenUpper = strings.ToUpper(l.token)
l.length = stri
// Send remainder of str
l.value = zString
c <- l
l.token = string(str[:stri])
retL = *l
if comi <= 0 {
return retL, true
}
}
if brace != 0 {
if comi > 0 {
// Send remainder of com
l.value = zNewline
l.token = "\n"
zl.comment = string(com[:comi])
if retL != (lex{}) {
zl.nextL = true
return retL, true
}
return *l, true
}
if zl.brace != 0 {
l.token = "unbalanced brace"
l.tokenUpper = l.token
l.err = true
c <- l
return *l, true
}
return lex{value: zEOF}, false
}
func (zl *zlexer) Comment() string {
if zl.l.err {
return ""
}
return zl.comment
}
// Extract the class number from CLASSxx
......@@ -839,8 +1183,7 @@ func typeToInt(token string) (uint16, bool) {
// stringToTTL parses things like 2w, 2m, etc, and returns the time in seconds.
func stringToTTL(token string) (uint32, bool) {
s := uint32(0)
i := uint32(0)
var s, i uint32
for _, c := range token {
switch c {
case 's', 'S':
......@@ -883,11 +1226,29 @@ func stringToCm(token string) (e, m uint8, ok bool) {
if cmeters, err = strconv.Atoi(s[1]); err != nil {
return
}
// There's no point in having more than 2 digits in this part, and would rather make the implementation complicated ('123' should be treated as '12').
// So we simply reject it.
// We also make sure the first character is a digit to reject '+-' signs.
if len(s[1]) > 2 || s[1][0] < '0' || s[1][0] > '9' {
return
}
if len(s[1]) == 1 {
// 'nn.1' must be treated as 'nn-meters and 10cm, not 1cm.
cmeters *= 10
}
if s[0] == "" {
// This will allow omitting the 'meter' part, like .01 (meaning 0.01m = 1cm).
break
}
fallthrough
case 1:
if meters, err = strconv.Atoi(s[0]); err != nil {
return
}
// RFC1876 states the max value is 90000000.00. The latter two conditions enforce it.
if s[0][0] < '0' || s[0][0] > '9' || meters > 90000000 || (meters == 90000000 && cmeters != 0) {
return
}
case 0:
// huh?
return 0, 0, false
......@@ -900,13 +1261,10 @@ func stringToCm(token string) (e, m uint8, ok bool) {
e = 0
val = cmeters
}
for val > 10 {
for val >= 10 {
e++
val /= 10
}
if e > 9 {
ok = false
}
m = uint8(val)
return
}
......@@ -928,7 +1286,7 @@ func toAbsoluteName(name, origin string) (absolute string, ok bool) {
}
// check if name is already absolute
if name[len(name)-1] == '.' {
if IsFqdn(name) {
return name, true
}
......@@ -948,6 +1306,9 @@ func appendOrigin(name, origin string) string {
// LOC record helper function
func locCheckNorth(token string, latitude uint32) (uint32, bool) {
if latitude > 90*1000*60*60 {
return latitude, false
}
switch token {
case "n", "N":
return LOC_EQUATOR + latitude, true
......@@ -959,6 +1320,9 @@ func locCheckNorth(token string, latitude uint32) (uint32, bool) {
// LOC record helper function
func locCheckEast(token string, longitude uint32) (uint32, bool) {
if longitude > 180*1000*60*60 {
return longitude, false
}
switch token {
case "e", "E":
return LOC_EQUATOR + longitude, true
......@@ -968,24 +1332,21 @@ func locCheckEast(token string, longitude uint32) (uint32, bool) {
return longitude, false
}
// "Eat" the rest of the "line". Return potential comments
func slurpRemainder(c chan lex, f string) (*ParseError, string) {
l := <-c
com := ""
// "Eat" the rest of the "line"
func slurpRemainder(c *zlexer) *ParseError {
l, _ := c.Next()
switch l.value {
case zBlank:
l = <-c
com = l.comment
l, _ = c.Next()
if l.value != zNewline && l.value != zEOF {
return &ParseError{f, "garbage after rdata", l}, ""
return &ParseError{"", "garbage after rdata", l}
}
case zNewline:
com = l.comment
case zEOF:
default:
return &ParseError{f, "garbage after rdata", l}, ""
return &ParseError{"", "garbage after rdata", l}
}
return nil, com
return nil
}
// Parse a 64 bit-like ipv6 address: "0014:4fff:ff20:ee64"
......@@ -994,7 +1355,7 @@ func stringToNodeID(l lex) (uint64, *ParseError) {
if len(l.token) < 19 {
return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
}
// There must be three colons at fixes postitions, if not its a parse error
// There must be three colons at fixes positions, if not its a parse error
if l.token[4] != ':' && l.token[9] != ':' && l.token[14] != ':' {
return 0, &ParseError{l.token, "bad NID/L64 NodeID/Locator64", l}
}
......
package dns
import (
"bytes"
"encoding/base64"
"net"
"strconv"
"strings"
)
type parserFunc struct {
// Func defines the function that parses the tokens and returns the RR
// or an error. The last string contains any comments in the line as
// they returned by the lexer as well.
Func func(h RR_Header, c chan lex, origin string, file string) (RR, *ParseError, string)
// Signals if the RR ending is of variable length, like TXT or records
// that have Hexadecimal or Base64 as their last element in the Rdata. Records
// that have a fixed ending or for instance A, AAAA, SOA and etc.
Variable bool
}
// Parse the rdata of each rrtype.
// All data from the channel c is either zString or zBlank.
// After the rdata there may come a zBlank and then a zNewline
// or immediately a zNewline. If this is not the case we flag
// an *ParseError: garbage after rdata.
func setRR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
parserfunc, ok := typeToparserFunc[h.Rrtype]
if ok {
r, e, cm := parserfunc.Func(h, c, o, f)
if parserfunc.Variable {
return r, e, cm
}
if e != nil {
return nil, e, ""
}
e, cm = slurpRemainder(c, f)
if e != nil {
return nil, e, ""
}
return r, nil, cm
}
// RFC3957 RR (Unknown RR handling)
return setRFC3597(h, c, o, f)
}
// A remainder of the rdata with embedded spaces, return the parsed string (sans the spaces)
// or an error
func endingToString(c chan lex, errstr, f string) (string, *ParseError, string) {
s := ""
l := <-c // zString
func endingToString(c *zlexer, errstr string) (string, *ParseError) {
var buffer bytes.Buffer
l, _ := c.Next() // zString
for l.value != zNewline && l.value != zEOF {
if l.err {
return s, &ParseError{f, errstr, l}, ""
return buffer.String(), &ParseError{"", errstr, l}
}
switch l.value {
case zString:
s += l.token
buffer.WriteString(l.token)
case zBlank: // Ok
default:
return "", &ParseError{f, errstr, l}, ""
return "", &ParseError{"", errstr, l}
}
l = <-c
l, _ = c.Next()
}
return s, nil, l.comment
return buffer.String(), nil
}
// A remainder of the rdata with embedded spaces, split on unquoted whitespace
// and return the parsed string slice or an error
func endingToTxtSlice(c chan lex, errstr, f string) ([]string, *ParseError, string) {
func endingToTxtSlice(c *zlexer, errstr string) ([]string, *ParseError) {
// Get the remaining data until we see a zNewline
l := <-c
l, _ := c.Next()
if l.err {
return nil, &ParseError{f, errstr, l}, ""
return nil, &ParseError{"", errstr, l}
}
// Build the slice
......@@ -79,7 +45,7 @@ func endingToTxtSlice(c chan lex, errstr, f string) ([]string, *ParseError, stri
empty := false
for l.value != zNewline && l.value != zEOF {
if l.err {
return nil, &ParseError{f, errstr, l}, ""
return nil, &ParseError{"", errstr, l}
}
switch l.value {
case zString:
......@@ -106,7 +72,7 @@ func endingToTxtSlice(c chan lex, errstr, f string) ([]string, *ParseError, stri
case zBlank:
if quote {
// zBlank can only be seen in between txt parts.
return nil, &ParseError{f, errstr, l}, ""
return nil, &ParseError{"", errstr, l}
}
case zQuote:
if empty && quote {
......@@ -115,196 +81,133 @@ func endingToTxtSlice(c chan lex, errstr, f string) ([]string, *ParseError, stri
quote = !quote
empty = true
default:
return nil, &ParseError{f, errstr, l}, ""
return nil, &ParseError{"", errstr, l}
}
l = <-c
l, _ = c.Next()
}
if quote {
return nil, &ParseError{f, errstr, l}, ""
return nil, &ParseError{"", errstr, l}
}
return s, nil, l.comment
}
func setA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(A)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
return s, nil
}
func (rr *A) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
rr.A = net.ParseIP(l.token)
if rr.A == nil || l.err {
return nil, &ParseError{f, "bad A A", l}, ""
// IPv4 addresses cannot include ":".
// We do this rather than use net.IP's To4() because
// To4() treats IPv4-mapped IPv6 addresses as being
// IPv4.
isIPv4 := !strings.Contains(l.token, ":")
if rr.A == nil || !isIPv4 || l.err {
return &ParseError{"", "bad A A", l}
}
return rr, nil, ""
return slurpRemainder(c)
}
func setAAAA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(AAAA)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *AAAA) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
rr.AAAA = net.ParseIP(l.token)
if rr.AAAA == nil || l.err {
return nil, &ParseError{f, "bad AAAA AAAA", l}, ""
// IPv6 addresses must include ":", and IPv4
// addresses cannot include ":".
isIPv6 := strings.Contains(l.token, ":")
if rr.AAAA == nil || !isIPv6 || l.err {
return &ParseError{"", "bad AAAA AAAA", l}
}
return rr, nil, ""
return slurpRemainder(c)
}
func setNS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(NS)
rr.Hdr = h
l := <-c
rr.Ns = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *NS) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad NS Ns", l}, ""
return &ParseError{"", "bad NS Ns", l}
}
rr.Ns = name
return rr, nil, ""
return slurpRemainder(c)
}
func setPTR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(PTR)
rr.Hdr = h
l := <-c
rr.Ptr = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *PTR) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad PTR Ptr", l}, ""
return &ParseError{"", "bad PTR Ptr", l}
}
rr.Ptr = name
return rr, nil, ""
return slurpRemainder(c)
}
func setNSAPPTR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(NSAPPTR)
rr.Hdr = h
l := <-c
rr.Ptr = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *NSAPPTR) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad NSAP-PTR Ptr", l}, ""
return &ParseError{"", "bad NSAP-PTR Ptr", l}
}
rr.Ptr = name
return rr, nil, ""
return slurpRemainder(c)
}
func setRP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(RP)
rr.Hdr = h
l := <-c
rr.Mbox = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *RP) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
mbox, mboxOk := toAbsoluteName(l.token, o)
if l.err || !mboxOk {
return nil, &ParseError{f, "bad RP Mbox", l}, ""
return &ParseError{"", "bad RP Mbox", l}
}
rr.Mbox = mbox
<-c // zBlank
l = <-c
c.Next() // zBlank
l, _ = c.Next()
rr.Txt = l.token
txt, txtOk := toAbsoluteName(l.token, o)
if l.err || !txtOk {
return nil, &ParseError{f, "bad RP Txt", l}, ""
return &ParseError{"", "bad RP Txt", l}
}
rr.Txt = txt
return rr, nil, ""
return slurpRemainder(c)
}
func setMR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(MR)
rr.Hdr = h
l := <-c
rr.Mr = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *MR) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad MR Mr", l}, ""
return &ParseError{"", "bad MR Mr", l}
}
rr.Mr = name
return rr, nil, ""
return slurpRemainder(c)
}
func setMB(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(MB)
rr.Hdr = h
l := <-c
rr.Mb = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *MB) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad MB Mb", l}, ""
return &ParseError{"", "bad MB Mb", l}
}
rr.Mb = name
return rr, nil, ""
return slurpRemainder(c)
}
func setMG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(MG)
rr.Hdr = h
l := <-c
rr.Mg = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *MG) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad MG Mg", l}, ""
return &ParseError{"", "bad MG Mg", l}
}
rr.Mg = name
return rr, nil, ""
return slurpRemainder(c)
}
func setHINFO(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(HINFO)
rr.Hdr = h
chunks, e, c1 := endingToTxtSlice(c, "bad HINFO Fields", f)
func (rr *HINFO) parse(c *zlexer, o string) *ParseError {
chunks, e := endingToTxtSlice(c, "bad HINFO Fields")
if e != nil {
return nil, e, c1
return e
}
if ln := len(chunks); ln == 0 {
return rr, nil, ""
return nil
} else if ln == 1 {
// Can we split it?
if out := strings.Fields(chunks[0]); len(out) > 1 {
......@@ -317,281 +220,198 @@ func setHINFO(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr.Cpu = chunks[0]
rr.Os = strings.Join(chunks[1:], " ")
return rr, nil, ""
return nil
}
func setMINFO(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(MINFO)
rr.Hdr = h
l := <-c
rr.Rmail = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *MINFO) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
rmail, rmailOk := toAbsoluteName(l.token, o)
if l.err || !rmailOk {
return nil, &ParseError{f, "bad MINFO Rmail", l}, ""
return &ParseError{"", "bad MINFO Rmail", l}
}
rr.Rmail = rmail
<-c // zBlank
l = <-c
c.Next() // zBlank
l, _ = c.Next()
rr.Email = l.token
email, emailOk := toAbsoluteName(l.token, o)
if l.err || !emailOk {
return nil, &ParseError{f, "bad MINFO Email", l}, ""
return &ParseError{"", "bad MINFO Email", l}
}
rr.Email = email
return rr, nil, ""
return slurpRemainder(c)
}
func setMF(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(MF)
rr.Hdr = h
l := <-c
rr.Mf = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *MF) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad MF Mf", l}, ""
return &ParseError{"", "bad MF Mf", l}
}
rr.Mf = name
return rr, nil, ""
return slurpRemainder(c)
}
func setMD(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(MD)
rr.Hdr = h
l := <-c
rr.Md = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *MD) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad MD Md", l}, ""
return &ParseError{"", "bad MD Md", l}
}
rr.Md = name
return rr, nil, ""
return slurpRemainder(c)
}
func setMX(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(MX)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *MX) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad MX Pref", l}, ""
return &ParseError{"", "bad MX Pref", l}
}
rr.Preference = uint16(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Mx = l.token
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad MX Mx", l}, ""
return &ParseError{"", "bad MX Mx", l}
}
rr.Mx = name
return rr, nil, ""
return slurpRemainder(c)
}
func setRT(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(RT)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *RT) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil {
return nil, &ParseError{f, "bad RT Preference", l}, ""
return &ParseError{"", "bad RT Preference", l}
}
rr.Preference = uint16(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Host = l.token
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad RT Host", l}, ""
return &ParseError{"", "bad RT Host", l}
}
rr.Host = name
return rr, nil, ""
return slurpRemainder(c)
}
func setAFSDB(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(AFSDB)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *AFSDB) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad AFSDB Subtype", l}, ""
return &ParseError{"", "bad AFSDB Subtype", l}
}
rr.Subtype = uint16(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Hostname = l.token
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad AFSDB Hostname", l}, ""
return &ParseError{"", "bad AFSDB Hostname", l}
}
rr.Hostname = name
return rr, nil, ""
return slurpRemainder(c)
}
func setX25(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(X25)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *X25) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
if l.err {
return nil, &ParseError{f, "bad X25 PSDNAddress", l}, ""
return &ParseError{"", "bad X25 PSDNAddress", l}
}
rr.PSDNAddress = l.token
return rr, nil, ""
return slurpRemainder(c)
}
func setKX(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(KX)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *KX) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad KX Pref", l}, ""
return &ParseError{"", "bad KX Pref", l}
}
rr.Preference = uint16(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Exchanger = l.token
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad KX Exchanger", l}, ""
return &ParseError{"", "bad KX Exchanger", l}
}
rr.Exchanger = name
return rr, nil, ""
return slurpRemainder(c)
}
func setCNAME(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(CNAME)
rr.Hdr = h
l := <-c
rr.Target = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *CNAME) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad CNAME Target", l}, ""
return &ParseError{"", "bad CNAME Target", l}
}
rr.Target = name
return rr, nil, ""
return slurpRemainder(c)
}
func setDNAME(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(DNAME)
rr.Hdr = h
l := <-c
rr.Target = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *DNAME) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad DNAME Target", l}, ""
return &ParseError{"", "bad DNAME Target", l}
}
rr.Target = name
return rr, nil, ""
return slurpRemainder(c)
}
func setSOA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(SOA)
rr.Hdr = h
l := <-c
rr.Ns = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *SOA) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
ns, nsOk := toAbsoluteName(l.token, o)
if l.err || !nsOk {
return nil, &ParseError{f, "bad SOA Ns", l}, ""
return &ParseError{"", "bad SOA Ns", l}
}
rr.Ns = ns
<-c // zBlank
l = <-c
c.Next() // zBlank
l, _ = c.Next()
rr.Mbox = l.token
mbox, mboxOk := toAbsoluteName(l.token, o)
if l.err || !mboxOk {
return nil, &ParseError{f, "bad SOA Mbox", l}, ""
return &ParseError{"", "bad SOA Mbox", l}
}
rr.Mbox = mbox
<-c // zBlank
c.Next() // zBlank
var (
v uint32
ok bool
)
for i := 0; i < 5; i++ {
l = <-c
l, _ = c.Next()
if l.err {
return nil, &ParseError{f, "bad SOA zone parameter", l}, ""
return &ParseError{"", "bad SOA zone parameter", l}
}
if j, e := strconv.ParseUint(l.token, 10, 32); e != nil {
if j, err := strconv.ParseUint(l.token, 10, 32); err != nil {
if i == 0 {
// Serial must be a number
return nil, &ParseError{f, "bad SOA zone parameter", l}, ""
return &ParseError{"", "bad SOA zone parameter", l}
}
// We allow other fields to be unitful duration strings
if v, ok = stringToTTL(l.token); !ok {
return nil, &ParseError{f, "bad SOA zone parameter", l}, ""
return &ParseError{"", "bad SOA zone parameter", l}
}
} else {
......@@ -600,452 +420,407 @@ func setSOA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
switch i {
case 0:
rr.Serial = v
<-c // zBlank
c.Next() // zBlank
case 1:
rr.Refresh = v
<-c // zBlank
c.Next() // zBlank
case 2:
rr.Retry = v
<-c // zBlank
c.Next() // zBlank
case 3:
rr.Expire = v
<-c // zBlank
c.Next() // zBlank
case 4:
rr.Minttl = v
}
}
return rr, nil, ""
return slurpRemainder(c)
}
func setSRV(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(SRV)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *SRV) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad SRV Priority", l}, ""
return &ParseError{"", "bad SRV Priority", l}
}
rr.Priority = uint16(i)
<-c // zBlank
l = <-c // zString
i, e = strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad SRV Weight", l}, ""
c.Next() // zBlank
l, _ = c.Next() // zString
i, e1 := strconv.ParseUint(l.token, 10, 16)
if e1 != nil || l.err {
return &ParseError{"", "bad SRV Weight", l}
}
rr.Weight = uint16(i)
<-c // zBlank
l = <-c // zString
i, e = strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad SRV Port", l}, ""
c.Next() // zBlank
l, _ = c.Next() // zString
i, e2 := strconv.ParseUint(l.token, 10, 16)
if e2 != nil || l.err {
return &ParseError{"", "bad SRV Port", l}
}
rr.Port = uint16(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Target = l.token
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad SRV Target", l}, ""
return &ParseError{"", "bad SRV Target", l}
}
rr.Target = name
return rr, nil, ""
return slurpRemainder(c)
}
func setNAPTR(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(NAPTR)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *NAPTR) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad NAPTR Order", l}, ""
return &ParseError{"", "bad NAPTR Order", l}
}
rr.Order = uint16(i)
<-c // zBlank
l = <-c // zString
i, e = strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad NAPTR Preference", l}, ""
c.Next() // zBlank
l, _ = c.Next() // zString
i, e1 := strconv.ParseUint(l.token, 10, 16)
if e1 != nil || l.err {
return &ParseError{"", "bad NAPTR Preference", l}
}
rr.Preference = uint16(i)
// Flags
<-c // zBlank
l = <-c // _QUOTE
c.Next() // zBlank
l, _ = c.Next() // _QUOTE
if l.value != zQuote {
return nil, &ParseError{f, "bad NAPTR Flags", l}, ""
return &ParseError{"", "bad NAPTR Flags", l}
}
l = <-c // Either String or Quote
l, _ = c.Next() // Either String or Quote
if l.value == zString {
rr.Flags = l.token
l = <-c // _QUOTE
l, _ = c.Next() // _QUOTE
if l.value != zQuote {
return nil, &ParseError{f, "bad NAPTR Flags", l}, ""
return &ParseError{"", "bad NAPTR Flags", l}
}
} else if l.value == zQuote {
rr.Flags = ""
} else {
return nil, &ParseError{f, "bad NAPTR Flags", l}, ""
return &ParseError{"", "bad NAPTR Flags", l}
}
// Service
<-c // zBlank
l = <-c // _QUOTE
c.Next() // zBlank
l, _ = c.Next() // _QUOTE
if l.value != zQuote {
return nil, &ParseError{f, "bad NAPTR Service", l}, ""
return &ParseError{"", "bad NAPTR Service", l}
}
l = <-c // Either String or Quote
l, _ = c.Next() // Either String or Quote
if l.value == zString {
rr.Service = l.token
l = <-c // _QUOTE
l, _ = c.Next() // _QUOTE
if l.value != zQuote {
return nil, &ParseError{f, "bad NAPTR Service", l}, ""
return &ParseError{"", "bad NAPTR Service", l}
}
} else if l.value == zQuote {
rr.Service = ""
} else {
return nil, &ParseError{f, "bad NAPTR Service", l}, ""
return &ParseError{"", "bad NAPTR Service", l}
}
// Regexp
<-c // zBlank
l = <-c // _QUOTE
c.Next() // zBlank
l, _ = c.Next() // _QUOTE
if l.value != zQuote {
return nil, &ParseError{f, "bad NAPTR Regexp", l}, ""
return &ParseError{"", "bad NAPTR Regexp", l}
}
l = <-c // Either String or Quote
l, _ = c.Next() // Either String or Quote
if l.value == zString {
rr.Regexp = l.token
l = <-c // _QUOTE
l, _ = c.Next() // _QUOTE
if l.value != zQuote {
return nil, &ParseError{f, "bad NAPTR Regexp", l}, ""
return &ParseError{"", "bad NAPTR Regexp", l}
}
} else if l.value == zQuote {
rr.Regexp = ""
} else {
return nil, &ParseError{f, "bad NAPTR Regexp", l}, ""
return &ParseError{"", "bad NAPTR Regexp", l}
}
// After quote no space??
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Replacement = l.token
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad NAPTR Replacement", l}, ""
return &ParseError{"", "bad NAPTR Replacement", l}
}
rr.Replacement = name
return rr, nil, ""
return slurpRemainder(c)
}
func setTALINK(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(TALINK)
rr.Hdr = h
l := <-c
rr.PreviousName = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *TALINK) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
previousName, previousNameOk := toAbsoluteName(l.token, o)
if l.err || !previousNameOk {
return nil, &ParseError{f, "bad TALINK PreviousName", l}, ""
return &ParseError{"", "bad TALINK PreviousName", l}
}
rr.PreviousName = previousName
<-c // zBlank
l = <-c
c.Next() // zBlank
l, _ = c.Next()
rr.NextName = l.token
nextName, nextNameOk := toAbsoluteName(l.token, o)
if l.err || !nextNameOk {
return nil, &ParseError{f, "bad TALINK NextName", l}, ""
return &ParseError{"", "bad TALINK NextName", l}
}
rr.NextName = nextName
return rr, nil, ""
return slurpRemainder(c)
}
func setLOC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(LOC)
rr.Hdr = h
func (rr *LOC) parse(c *zlexer, o string) *ParseError {
// Non zero defaults for LOC record, see RFC 1876, Section 3.
rr.HorizPre = 165 // 10000
rr.VertPre = 162 // 10
rr.Size = 18 // 1
rr.Size = 0x12 // 1e2 cm (1m)
rr.HorizPre = 0x16 // 1e6 cm (10000m)
rr.VertPre = 0x13 // 1e3 cm (10m)
ok := false
// North
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 32)
if e != nil || l.err {
return nil, &ParseError{f, "bad LOC Latitude", l}, ""
if e != nil || l.err || i > 90 {
return &ParseError{"", "bad LOC Latitude", l}
}
rr.Latitude = 1000 * 60 * 60 * uint32(i)
<-c // zBlank
c.Next() // zBlank
// Either number, 'N' or 'S'
l = <-c
l, _ = c.Next()
if rr.Latitude, ok = locCheckNorth(l.token, rr.Latitude); ok {
goto East
}
i, e = strconv.ParseUint(l.token, 10, 32)
if e != nil || l.err {
return nil, &ParseError{f, "bad LOC Latitude minutes", l}, ""
if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 59 {
return &ParseError{"", "bad LOC Latitude minutes", l}
} else {
rr.Latitude += 1000 * 60 * uint32(i)
}
rr.Latitude += 1000 * 60 * uint32(i)
<-c // zBlank
l = <-c
if i, e := strconv.ParseFloat(l.token, 32); e != nil || l.err {
return nil, &ParseError{f, "bad LOC Latitude seconds", l}, ""
c.Next() // zBlank
l, _ = c.Next()
if i, err := strconv.ParseFloat(l.token, 64); err != nil || l.err || i < 0 || i >= 60 {
return &ParseError{"", "bad LOC Latitude seconds", l}
} else {
rr.Latitude += uint32(1000 * i)
}
<-c // zBlank
c.Next() // zBlank
// Either number, 'N' or 'S'
l = <-c
l, _ = c.Next()
if rr.Latitude, ok = locCheckNorth(l.token, rr.Latitude); ok {
goto East
}
// If still alive, flag an error
return nil, &ParseError{f, "bad LOC Latitude North/South", l}, ""
return &ParseError{"", "bad LOC Latitude North/South", l}
East:
// East
<-c // zBlank
l = <-c
if i, e := strconv.ParseUint(l.token, 10, 32); e != nil || l.err {
return nil, &ParseError{f, "bad LOC Longitude", l}, ""
c.Next() // zBlank
l, _ = c.Next()
if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 180 {
return &ParseError{"", "bad LOC Longitude", l}
} else {
rr.Longitude = 1000 * 60 * 60 * uint32(i)
}
<-c // zBlank
c.Next() // zBlank
// Either number, 'E' or 'W'
l = <-c
l, _ = c.Next()
if rr.Longitude, ok = locCheckEast(l.token, rr.Longitude); ok {
goto Altitude
}
if i, e := strconv.ParseUint(l.token, 10, 32); e != nil || l.err {
return nil, &ParseError{f, "bad LOC Longitude minutes", l}, ""
if i, err := strconv.ParseUint(l.token, 10, 32); err != nil || l.err || i > 59 {
return &ParseError{"", "bad LOC Longitude minutes", l}
} else {
rr.Longitude += 1000 * 60 * uint32(i)
}
<-c // zBlank
l = <-c
if i, e := strconv.ParseFloat(l.token, 32); e != nil || l.err {
return nil, &ParseError{f, "bad LOC Longitude seconds", l}, ""
c.Next() // zBlank
l, _ = c.Next()
if i, err := strconv.ParseFloat(l.token, 64); err != nil || l.err || i < 0 || i >= 60 {
return &ParseError{"", "bad LOC Longitude seconds", l}
} else {
rr.Longitude += uint32(1000 * i)
}
<-c // zBlank
c.Next() // zBlank
// Either number, 'E' or 'W'
l = <-c
l, _ = c.Next()
if rr.Longitude, ok = locCheckEast(l.token, rr.Longitude); ok {
goto Altitude
}
// If still alive, flag an error
return nil, &ParseError{f, "bad LOC Longitude East/West", l}, ""
return &ParseError{"", "bad LOC Longitude East/West", l}
Altitude:
<-c // zBlank
l = <-c
if l.length == 0 || l.err {
return nil, &ParseError{f, "bad LOC Altitude", l}, ""
c.Next() // zBlank
l, _ = c.Next()
if l.token == "" || l.err {
return &ParseError{"", "bad LOC Altitude", l}
}
if l.token[len(l.token)-1] == 'M' || l.token[len(l.token)-1] == 'm' {
l.token = l.token[0 : len(l.token)-1]
}
if i, e := strconv.ParseFloat(l.token, 32); e != nil {
return nil, &ParseError{f, "bad LOC Altitude", l}, ""
if i, err := strconv.ParseFloat(l.token, 64); err != nil {
return &ParseError{"", "bad LOC Altitude", l}
} else {
rr.Altitude = uint32(i*100.0 + 10000000.0 + 0.5)
}
// And now optionally the other values
l = <-c
l, _ = c.Next()
count := 0
for l.value != zNewline && l.value != zEOF {
switch l.value {
case zString:
switch count {
case 0: // Size
e, m, ok := stringToCm(l.token)
exp, m, ok := stringToCm(l.token)
if !ok {
return nil, &ParseError{f, "bad LOC Size", l}, ""
return &ParseError{"", "bad LOC Size", l}
}
rr.Size = e&0x0f | m<<4&0xf0
rr.Size = exp&0x0f | m<<4&0xf0
case 1: // HorizPre
e, m, ok := stringToCm(l.token)
exp, m, ok := stringToCm(l.token)
if !ok {
return nil, &ParseError{f, "bad LOC HorizPre", l}, ""
return &ParseError{"", "bad LOC HorizPre", l}
}
rr.HorizPre = e&0x0f | m<<4&0xf0
rr.HorizPre = exp&0x0f | m<<4&0xf0
case 2: // VertPre
e, m, ok := stringToCm(l.token)
exp, m, ok := stringToCm(l.token)
if !ok {
return nil, &ParseError{f, "bad LOC VertPre", l}, ""
return &ParseError{"", "bad LOC VertPre", l}
}
rr.VertPre = e&0x0f | m<<4&0xf0
rr.VertPre = exp&0x0f | m<<4&0xf0
}
count++
case zBlank:
// Ok
default:
return nil, &ParseError{f, "bad LOC Size, HorizPre or VertPre", l}, ""
return &ParseError{"", "bad LOC Size, HorizPre or VertPre", l}
}
l = <-c
l, _ = c.Next()
}
return rr, nil, ""
return nil
}
func setHIP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(HIP)
rr.Hdr = h
func (rr *HIP) parse(c *zlexer, o string) *ParseError {
// HitLength is not represented
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad HIP PublicKeyAlgorithm", l}, ""
return &ParseError{"", "bad HIP PublicKeyAlgorithm", l}
}
rr.PublicKeyAlgorithm = uint8(i)
<-c // zBlank
l = <-c // zString
if l.length == 0 || l.err {
return nil, &ParseError{f, "bad HIP Hit", l}, ""
c.Next() // zBlank
l, _ = c.Next() // zString
if l.token == "" || l.err {
return &ParseError{"", "bad HIP Hit", l}
}
rr.Hit = l.token // This can not contain spaces, see RFC 5205 Section 6.
rr.HitLength = uint8(len(rr.Hit)) / 2
<-c // zBlank
l = <-c // zString
if l.length == 0 || l.err {
return nil, &ParseError{f, "bad HIP PublicKey", l}, ""
c.Next() // zBlank
l, _ = c.Next() // zString
if l.token == "" || l.err {
return &ParseError{"", "bad HIP PublicKey", l}
}
rr.PublicKey = l.token // This cannot contain spaces
rr.PublicKeyLength = uint16(base64.StdEncoding.DecodedLen(len(rr.PublicKey)))
decodedPK, decodedPKerr := base64.StdEncoding.DecodeString(rr.PublicKey)
if decodedPKerr != nil {
return &ParseError{"", "bad HIP PublicKey", l}
}
rr.PublicKeyLength = uint16(len(decodedPK))
// RendezvousServers (if any)
l = <-c
l, _ = c.Next()
var xs []string
for l.value != zNewline && l.value != zEOF {
switch l.value {
case zString:
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad HIP RendezvousServers", l}, ""
return &ParseError{"", "bad HIP RendezvousServers", l}
}
xs = append(xs, name)
case zBlank:
// Ok
default:
return nil, &ParseError{f, "bad HIP RendezvousServers", l}, ""
return &ParseError{"", "bad HIP RendezvousServers", l}
}
l = <-c
l, _ = c.Next()
}
rr.RendezvousServers = xs
return rr, nil, l.comment
return nil
}
func setCERT(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(CERT)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *CERT) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
if v, ok := StringToCertType[l.token]; ok {
rr.Type = v
} else if i, e := strconv.ParseUint(l.token, 10, 16); e != nil {
return nil, &ParseError{f, "bad CERT Type", l}, ""
} else if i, err := strconv.ParseUint(l.token, 10, 16); err != nil {
return &ParseError{"", "bad CERT Type", l}
} else {
rr.Type = uint16(i)
}
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad CERT KeyTag", l}, ""
return &ParseError{"", "bad CERT KeyTag", l}
}
rr.KeyTag = uint16(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
if v, ok := StringToAlgorithm[l.token]; ok {
rr.Algorithm = v
} else if i, e := strconv.ParseUint(l.token, 10, 8); e != nil {
return nil, &ParseError{f, "bad CERT Algorithm", l}, ""
} else if i, err := strconv.ParseUint(l.token, 10, 8); err != nil {
return &ParseError{"", "bad CERT Algorithm", l}
} else {
rr.Algorithm = uint8(i)
}
s, e1, c1 := endingToString(c, "bad CERT Certificate", f)
s, e1 := endingToString(c, "bad CERT Certificate")
if e1 != nil {
return nil, e1, c1
return e1
}
rr.Certificate = s
return rr, nil, c1
return nil
}
func setOPENPGPKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(OPENPGPKEY)
rr.Hdr = h
s, e, c1 := endingToString(c, "bad OPENPGPKEY PublicKey", f)
func (rr *OPENPGPKEY) parse(c *zlexer, o string) *ParseError {
s, e := endingToString(c, "bad OPENPGPKEY PublicKey")
if e != nil {
return nil, e, c1
return e
}
rr.PublicKey = s
return rr, nil, c1
return nil
}
func setCSYNC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(CSYNC)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *CSYNC) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
j, e := strconv.ParseUint(l.token, 10, 32)
if e != nil {
// Serial must be a number
return nil, &ParseError{f, "bad CSYNC serial", l}, ""
return &ParseError{"", "bad CSYNC serial", l}
}
rr.Serial = uint32(j)
<-c // zBlank
c.Next() // zBlank
l = <-c
j, e = strconv.ParseUint(l.token, 10, 16)
if e != nil {
l, _ = c.Next()
j, e1 := strconv.ParseUint(l.token, 10, 16)
if e1 != nil {
// Serial must be a number
return nil, &ParseError{f, "bad CSYNC flags", l}, ""
return &ParseError{"", "bad CSYNC flags", l}
}
rr.Flags = uint16(j)
......@@ -1054,146 +829,158 @@ func setCSYNC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
k uint16
ok bool
)
l = <-c
l, _ = c.Next()
for l.value != zNewline && l.value != zEOF {
switch l.value {
case zBlank:
// Ok
case zString:
if k, ok = StringToType[l.tokenUpper]; !ok {
if k, ok = typeToInt(l.tokenUpper); !ok {
return nil, &ParseError{f, "bad CSYNC TypeBitMap", l}, ""
tokenUpper := strings.ToUpper(l.token)
if k, ok = StringToType[tokenUpper]; !ok {
if k, ok = typeToInt(l.token); !ok {
return &ParseError{"", "bad CSYNC TypeBitMap", l}
}
}
rr.TypeBitMap = append(rr.TypeBitMap, k)
default:
return nil, &ParseError{f, "bad CSYNC TypeBitMap", l}, ""
return &ParseError{"", "bad CSYNC TypeBitMap", l}
}
l = <-c
l, _ = c.Next()
}
return rr, nil, l.comment
return nil
}
func setSIG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setRRSIG(h, c, o, f)
if r != nil {
return &SIG{*r.(*RRSIG)}, e, s
func (rr *ZONEMD) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 32)
if e != nil || l.err {
return &ParseError{"", "bad ZONEMD Serial", l}
}
return nil, e, s
}
rr.Serial = uint32(i)
func setRRSIG(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(RRSIG)
rr.Hdr = h
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad ZONEMD Scheme", l}
}
rr.Scheme = uint8(i)
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
c.Next() // zBlank
l, _ = c.Next()
i, err := strconv.ParseUint(l.token, 10, 8)
if err != nil || l.err {
return &ParseError{"", "bad ZONEMD Hash Algorithm", l}
}
rr.Hash = uint8(i)
s, e2 := endingToString(c, "bad ZONEMD Digest")
if e2 != nil {
return e2
}
rr.Digest = s
return nil
}
func (rr *SIG) parse(c *zlexer, o string) *ParseError { return rr.RRSIG.parse(c, o) }
if t, ok := StringToType[l.tokenUpper]; !ok {
if strings.HasPrefix(l.tokenUpper, "TYPE") {
t, ok = typeToInt(l.tokenUpper)
func (rr *RRSIG) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
tokenUpper := strings.ToUpper(l.token)
if t, ok := StringToType[tokenUpper]; !ok {
if strings.HasPrefix(tokenUpper, "TYPE") {
t, ok = typeToInt(l.token)
if !ok {
return nil, &ParseError{f, "bad RRSIG Typecovered", l}, ""
return &ParseError{"", "bad RRSIG Typecovered", l}
}
rr.TypeCovered = t
} else {
return nil, &ParseError{f, "bad RRSIG Typecovered", l}, ""
return &ParseError{"", "bad RRSIG Typecovered", l}
}
} else {
rr.TypeCovered = t
}
<-c // zBlank
l = <-c
i, err := strconv.ParseUint(l.token, 10, 8)
if err != nil || l.err {
return nil, &ParseError{f, "bad RRSIG Algorithm", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e := strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return &ParseError{"", "bad RRSIG Algorithm", l}
}
rr.Algorithm = uint8(i)
<-c // zBlank
l = <-c
i, err = strconv.ParseUint(l.token, 10, 8)
if err != nil || l.err {
return nil, &ParseError{f, "bad RRSIG Labels", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad RRSIG Labels", l}
}
rr.Labels = uint8(i)
<-c // zBlank
l = <-c
i, err = strconv.ParseUint(l.token, 10, 32)
if err != nil || l.err {
return nil, &ParseError{f, "bad RRSIG OrigTtl", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e2 := strconv.ParseUint(l.token, 10, 32)
if e2 != nil || l.err {
return &ParseError{"", "bad RRSIG OrigTtl", l}
}
rr.OrigTtl = uint32(i)
<-c // zBlank
l = <-c
c.Next() // zBlank
l, _ = c.Next()
if i, err := StringToTime(l.token); err != nil {
// Try to see if all numeric and use it as epoch
if i, err := strconv.ParseInt(l.token, 10, 64); err == nil {
// TODO(miek): error out on > MAX_UINT32, same below
if i, err := strconv.ParseUint(l.token, 10, 32); err == nil {
rr.Expiration = uint32(i)
} else {
return nil, &ParseError{f, "bad RRSIG Expiration", l}, ""
return &ParseError{"", "bad RRSIG Expiration", l}
}
} else {
rr.Expiration = i
}
<-c // zBlank
l = <-c
c.Next() // zBlank
l, _ = c.Next()
if i, err := StringToTime(l.token); err != nil {
if i, err := strconv.ParseInt(l.token, 10, 64); err == nil {
if i, err := strconv.ParseUint(l.token, 10, 32); err == nil {
rr.Inception = uint32(i)
} else {
return nil, &ParseError{f, "bad RRSIG Inception", l}, ""
return &ParseError{"", "bad RRSIG Inception", l}
}
} else {
rr.Inception = i
}
<-c // zBlank
l = <-c
i, err = strconv.ParseUint(l.token, 10, 16)
if err != nil || l.err {
return nil, &ParseError{f, "bad RRSIG KeyTag", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e3 := strconv.ParseUint(l.token, 10, 16)
if e3 != nil || l.err {
return &ParseError{"", "bad RRSIG KeyTag", l}
}
rr.KeyTag = uint16(i)
<-c // zBlank
l = <-c
c.Next() // zBlank
l, _ = c.Next()
rr.SignerName = l.token
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad RRSIG SignerName", l}, ""
return &ParseError{"", "bad RRSIG SignerName", l}
}
rr.SignerName = name
s, e, c1 := endingToString(c, "bad RRSIG Signature", f)
if e != nil {
return nil, e, c1
s, e4 := endingToString(c, "bad RRSIG Signature")
if e4 != nil {
return e4
}
rr.Signature = s
return rr, nil, c1
return nil
}
func setNSEC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(NSEC)
rr.Hdr = h
l := <-c
rr.NextDomain = l.token
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *NSEC) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad NSEC NextDomain", l}, ""
return &ParseError{"", "bad NSEC NextDomain", l}
}
rr.NextDomain = name
......@@ -1202,68 +989,62 @@ func setNSEC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
k uint16
ok bool
)
l = <-c
l, _ = c.Next()
for l.value != zNewline && l.value != zEOF {
switch l.value {
case zBlank:
// Ok
case zString:
if k, ok = StringToType[l.tokenUpper]; !ok {
if k, ok = typeToInt(l.tokenUpper); !ok {
return nil, &ParseError{f, "bad NSEC TypeBitMap", l}, ""
tokenUpper := strings.ToUpper(l.token)
if k, ok = StringToType[tokenUpper]; !ok {
if k, ok = typeToInt(l.token); !ok {
return &ParseError{"", "bad NSEC TypeBitMap", l}
}
}
rr.TypeBitMap = append(rr.TypeBitMap, k)
default:
return nil, &ParseError{f, "bad NSEC TypeBitMap", l}, ""
return &ParseError{"", "bad NSEC TypeBitMap", l}
}
l = <-c
l, _ = c.Next()
}
return rr, nil, l.comment
return nil
}
func setNSEC3(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(NSEC3)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *NSEC3) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad NSEC3 Hash", l}, ""
return &ParseError{"", "bad NSEC3 Hash", l}
}
rr.Hash = uint8(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad NSEC3 Flags", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad NSEC3 Flags", l}
}
rr.Flags = uint8(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad NSEC3 Iterations", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e2 := strconv.ParseUint(l.token, 10, 16)
if e2 != nil || l.err {
return &ParseError{"", "bad NSEC3 Iterations", l}
}
rr.Iterations = uint16(i)
<-c
l = <-c
if len(l.token) == 0 || l.err {
return nil, &ParseError{f, "bad NSEC3 Salt", l}, ""
c.Next()
l, _ = c.Next()
if l.token == "" || l.err {
return &ParseError{"", "bad NSEC3 Salt", l}
}
if l.token != "-" {
rr.SaltLength = uint8(len(l.token)) / 2
rr.Salt = l.token
}
<-c
l = <-c
if len(l.token) == 0 || l.err {
return nil, &ParseError{f, "bad NSEC3 NextDomain", l}, ""
c.Next()
l, _ = c.Next()
if l.token == "" || l.err {
return &ParseError{"", "bad NSEC3 NextDomain", l}
}
rr.HashLength = 20 // Fix for NSEC3 (sha1 160 bits)
rr.NextDomain = l.token
......@@ -1273,74 +1054,61 @@ func setNSEC3(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
k uint16
ok bool
)
l = <-c
l, _ = c.Next()
for l.value != zNewline && l.value != zEOF {
switch l.value {
case zBlank:
// Ok
case zString:
if k, ok = StringToType[l.tokenUpper]; !ok {
if k, ok = typeToInt(l.tokenUpper); !ok {
return nil, &ParseError{f, "bad NSEC3 TypeBitMap", l}, ""
tokenUpper := strings.ToUpper(l.token)
if k, ok = StringToType[tokenUpper]; !ok {
if k, ok = typeToInt(l.token); !ok {
return &ParseError{"", "bad NSEC3 TypeBitMap", l}
}
}
rr.TypeBitMap = append(rr.TypeBitMap, k)
default:
return nil, &ParseError{f, "bad NSEC3 TypeBitMap", l}, ""
return &ParseError{"", "bad NSEC3 TypeBitMap", l}
}
l = <-c
l, _ = c.Next()
}
return rr, nil, l.comment
return nil
}
func setNSEC3PARAM(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(NSEC3PARAM)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *NSEC3PARAM) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad NSEC3PARAM Hash", l}, ""
return &ParseError{"", "bad NSEC3PARAM Hash", l}
}
rr.Hash = uint8(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad NSEC3PARAM Flags", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad NSEC3PARAM Flags", l}
}
rr.Flags = uint8(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad NSEC3PARAM Iterations", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e2 := strconv.ParseUint(l.token, 10, 16)
if e2 != nil || l.err {
return &ParseError{"", "bad NSEC3PARAM Iterations", l}
}
rr.Iterations = uint16(i)
<-c
l = <-c
c.Next()
l, _ = c.Next()
if l.token != "-" {
rr.SaltLength = uint8(len(l.token))
rr.SaltLength = uint8(len(l.token) / 2)
rr.Salt = l.token
}
return rr, nil, ""
return slurpRemainder(c)
}
func setEUI48(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(EUI48)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
if l.length != 17 || l.err {
return nil, &ParseError{f, "bad EUI48 Address", l}, ""
func (rr *EUI48) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
if len(l.token) != 17 || l.err {
return &ParseError{"", "bad EUI48 Address", l}
}
addr := make([]byte, 12)
dash := 0
......@@ -1349,7 +1117,7 @@ func setEUI48(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
addr[i+1] = l.token[i+1+dash]
dash++
if l.token[i+1+dash] != '-' {
return nil, &ParseError{f, "bad EUI48 Address", l}, ""
return &ParseError{"", "bad EUI48 Address", l}
}
}
addr[10] = l.token[15]
......@@ -1357,23 +1125,16 @@ func setEUI48(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
i, e := strconv.ParseUint(string(addr), 16, 48)
if e != nil {
return nil, &ParseError{f, "bad EUI48 Address", l}, ""
return &ParseError{"", "bad EUI48 Address", l}
}
rr.Address = i
return rr, nil, ""
return slurpRemainder(c)
}
func setEUI64(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(EUI64)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
if l.length != 23 || l.err {
return nil, &ParseError{f, "bad EUI64 Address", l}, ""
func (rr *EUI64) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
if len(l.token) != 23 || l.err {
return &ParseError{"", "bad EUI64 Address", l}
}
addr := make([]byte, 16)
dash := 0
......@@ -1382,7 +1143,7 @@ func setEUI64(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
addr[i+1] = l.token[i+1+dash]
dash++
if l.token[i+1+dash] != '-' {
return nil, &ParseError{f, "bad EUI64 Address", l}, ""
return &ParseError{"", "bad EUI64 Address", l}
}
}
addr[14] = l.token[21]
......@@ -1390,814 +1151,628 @@ func setEUI64(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
i, e := strconv.ParseUint(string(addr), 16, 64)
if e != nil {
return nil, &ParseError{f, "bad EUI68 Address", l}, ""
return &ParseError{"", "bad EUI68 Address", l}
}
rr.Address = uint64(i)
return rr, nil, ""
rr.Address = i
return slurpRemainder(c)
}
func setSSHFP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(SSHFP)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *SSHFP) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad SSHFP Algorithm", l}, ""
return &ParseError{"", "bad SSHFP Algorithm", l}
}
rr.Algorithm = uint8(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad SSHFP Type", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad SSHFP Type", l}
}
rr.Type = uint8(i)
<-c // zBlank
s, e1, c1 := endingToString(c, "bad SSHFP Fingerprint", f)
if e1 != nil {
return nil, e1, c1
c.Next() // zBlank
s, e2 := endingToString(c, "bad SSHFP Fingerprint")
if e2 != nil {
return e2
}
rr.FingerPrint = s
return rr, nil, ""
return nil
}
func setDNSKEYs(h RR_Header, c chan lex, o, f, typ string) (RR, *ParseError, string) {
rr := new(DNSKEY)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *DNSKEY) parseDNSKEY(c *zlexer, o, typ string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad " + typ + " Flags", l}, ""
return &ParseError{"", "bad " + typ + " Flags", l}
}
rr.Flags = uint16(i)
<-c // zBlank
l = <-c // zString
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad " + typ + " Protocol", l}, ""
c.Next() // zBlank
l, _ = c.Next() // zString
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad " + typ + " Protocol", l}
}
rr.Protocol = uint8(i)
<-c // zBlank
l = <-c // zString
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad " + typ + " Algorithm", l}, ""
c.Next() // zBlank
l, _ = c.Next() // zString
i, e2 := strconv.ParseUint(l.token, 10, 8)
if e2 != nil || l.err {
return &ParseError{"", "bad " + typ + " Algorithm", l}
}
rr.Algorithm = uint8(i)
s, e1, c1 := endingToString(c, "bad "+typ+" PublicKey", f)
if e1 != nil {
return nil, e1, c1
s, e3 := endingToString(c, "bad "+typ+" PublicKey")
if e3 != nil {
return e3
}
rr.PublicKey = s
return rr, nil, c1
}
func setKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDNSKEYs(h, c, o, f, "KEY")
if r != nil {
return &KEY{*r.(*DNSKEY)}, e, s
}
return nil, e, s
return nil
}
func setDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDNSKEYs(h, c, o, f, "DNSKEY")
return r, e, s
}
func setCDNSKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDNSKEYs(h, c, o, f, "CDNSKEY")
if r != nil {
return &CDNSKEY{*r.(*DNSKEY)}, e, s
}
return nil, e, s
}
func setRKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(RKEY)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *DNSKEY) parse(c *zlexer, o string) *ParseError { return rr.parseDNSKEY(c, o, "DNSKEY") }
func (rr *KEY) parse(c *zlexer, o string) *ParseError { return rr.parseDNSKEY(c, o, "KEY") }
func (rr *CDNSKEY) parse(c *zlexer, o string) *ParseError { return rr.parseDNSKEY(c, o, "CDNSKEY") }
func (rr *DS) parse(c *zlexer, o string) *ParseError { return rr.parseDS(c, o, "DS") }
func (rr *DLV) parse(c *zlexer, o string) *ParseError { return rr.parseDS(c, o, "DLV") }
func (rr *CDS) parse(c *zlexer, o string) *ParseError { return rr.parseDS(c, o, "CDS") }
func (rr *RKEY) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad RKEY Flags", l}, ""
return &ParseError{"", "bad RKEY Flags", l}
}
rr.Flags = uint16(i)
<-c // zBlank
l = <-c // zString
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad RKEY Protocol", l}, ""
c.Next() // zBlank
l, _ = c.Next() // zString
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad RKEY Protocol", l}
}
rr.Protocol = uint8(i)
<-c // zBlank
l = <-c // zString
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad RKEY Algorithm", l}, ""
c.Next() // zBlank
l, _ = c.Next() // zString
i, e2 := strconv.ParseUint(l.token, 10, 8)
if e2 != nil || l.err {
return &ParseError{"", "bad RKEY Algorithm", l}
}
rr.Algorithm = uint8(i)
s, e1, c1 := endingToString(c, "bad RKEY PublicKey", f)
if e1 != nil {
return nil, e1, c1
s, e3 := endingToString(c, "bad RKEY PublicKey")
if e3 != nil {
return e3
}
rr.PublicKey = s
return rr, nil, c1
return nil
}
func setEID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(EID)
rr.Hdr = h
s, e, c1 := endingToString(c, "bad EID Endpoint", f)
func (rr *EID) parse(c *zlexer, o string) *ParseError {
s, e := endingToString(c, "bad EID Endpoint")
if e != nil {
return nil, e, c1
return e
}
rr.Endpoint = s
return rr, nil, c1
return nil
}
func setNIMLOC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(NIMLOC)
rr.Hdr = h
s, e, c1 := endingToString(c, "bad NIMLOC Locator", f)
func (rr *NIMLOC) parse(c *zlexer, o string) *ParseError {
s, e := endingToString(c, "bad NIMLOC Locator")
if e != nil {
return nil, e, c1
return e
}
rr.Locator = s
return rr, nil, c1
return nil
}
func setGPOS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(GPOS)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *GPOS) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
_, e := strconv.ParseFloat(l.token, 64)
if e != nil || l.err {
return nil, &ParseError{f, "bad GPOS Longitude", l}, ""
return &ParseError{"", "bad GPOS Longitude", l}
}
rr.Longitude = l.token
<-c // zBlank
l = <-c
_, e = strconv.ParseFloat(l.token, 64)
if e != nil || l.err {
return nil, &ParseError{f, "bad GPOS Latitude", l}, ""
c.Next() // zBlank
l, _ = c.Next()
_, e1 := strconv.ParseFloat(l.token, 64)
if e1 != nil || l.err {
return &ParseError{"", "bad GPOS Latitude", l}
}
rr.Latitude = l.token
<-c // zBlank
l = <-c
_, e = strconv.ParseFloat(l.token, 64)
if e != nil || l.err {
return nil, &ParseError{f, "bad GPOS Altitude", l}, ""
c.Next() // zBlank
l, _ = c.Next()
_, e2 := strconv.ParseFloat(l.token, 64)
if e2 != nil || l.err {
return &ParseError{"", "bad GPOS Altitude", l}
}
rr.Altitude = l.token
return rr, nil, ""
return slurpRemainder(c)
}
func setDSs(h RR_Header, c chan lex, o, f, typ string) (RR, *ParseError, string) {
rr := new(DS)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *DS) parseDS(c *zlexer, o, typ string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad " + typ + " KeyTag", l}, ""
return &ParseError{"", "bad " + typ + " KeyTag", l}
}
rr.KeyTag = uint16(i)
<-c // zBlank
l = <-c
if i, e = strconv.ParseUint(l.token, 10, 8); e != nil {
i, ok := StringToAlgorithm[l.tokenUpper]
c.Next() // zBlank
l, _ = c.Next()
if i, err := strconv.ParseUint(l.token, 10, 8); err != nil {
tokenUpper := strings.ToUpper(l.token)
i, ok := StringToAlgorithm[tokenUpper]
if !ok || l.err {
return nil, &ParseError{f, "bad " + typ + " Algorithm", l}, ""
return &ParseError{"", "bad " + typ + " Algorithm", l}
}
rr.Algorithm = i
} else {
rr.Algorithm = uint8(i)
}
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad " + typ + " DigestType", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad " + typ + " DigestType", l}
}
rr.DigestType = uint8(i)
s, e1, c1 := endingToString(c, "bad "+typ+" Digest", f)
if e1 != nil {
return nil, e1, c1
s, e2 := endingToString(c, "bad "+typ+" Digest")
if e2 != nil {
return e2
}
rr.Digest = s
return rr, nil, c1
}
func setDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDSs(h, c, o, f, "DS")
return r, e, s
}
func setDLV(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDSs(h, c, o, f, "DLV")
if r != nil {
return &DLV{*r.(*DS)}, e, s
}
return nil, e, s
}
func setCDS(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
r, e, s := setDSs(h, c, o, f, "CDS")
if r != nil {
return &CDS{*r.(*DS)}, e, s
}
return nil, e, s
return nil
}
func setTA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(TA)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *TA) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad TA KeyTag", l}, ""
return &ParseError{"", "bad TA KeyTag", l}
}
rr.KeyTag = uint16(i)
<-c // zBlank
l = <-c
if i, e := strconv.ParseUint(l.token, 10, 8); e != nil {
i, ok := StringToAlgorithm[l.tokenUpper]
c.Next() // zBlank
l, _ = c.Next()
if i, err := strconv.ParseUint(l.token, 10, 8); err != nil {
tokenUpper := strings.ToUpper(l.token)
i, ok := StringToAlgorithm[tokenUpper]
if !ok || l.err {
return nil, &ParseError{f, "bad TA Algorithm", l}, ""
return &ParseError{"", "bad TA Algorithm", l}
}
rr.Algorithm = i
} else {
rr.Algorithm = uint8(i)
}
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad TA DigestType", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad TA DigestType", l}
}
rr.DigestType = uint8(i)
s, e, c1 := endingToString(c, "bad TA Digest", f)
if e != nil {
return nil, e.(*ParseError), c1
s, e2 := endingToString(c, "bad TA Digest")
if e2 != nil {
return e2
}
rr.Digest = s
return rr, nil, c1
return nil
}
func setTLSA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(TLSA)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *TLSA) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad TLSA Usage", l}, ""
return &ParseError{"", "bad TLSA Usage", l}
}
rr.Usage = uint8(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad TLSA Selector", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad TLSA Selector", l}
}
rr.Selector = uint8(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad TLSA MatchingType", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e2 := strconv.ParseUint(l.token, 10, 8)
if e2 != nil || l.err {
return &ParseError{"", "bad TLSA MatchingType", l}
}
rr.MatchingType = uint8(i)
// So this needs be e2 (i.e. different than e), because...??t
s, e2, c1 := endingToString(c, "bad TLSA Certificate", f)
if e2 != nil {
return nil, e2, c1
s, e3 := endingToString(c, "bad TLSA Certificate")
if e3 != nil {
return e3
}
rr.Certificate = s
return rr, nil, c1
return nil
}
func setSMIMEA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(SMIMEA)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
func (rr *SMIMEA) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad SMIMEA Usage", l}, ""
return &ParseError{"", "bad SMIMEA Usage", l}
}
rr.Usage = uint8(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad SMIMEA Selector", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad SMIMEA Selector", l}
}
rr.Selector = uint8(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return nil, &ParseError{f, "bad SMIMEA MatchingType", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e2 := strconv.ParseUint(l.token, 10, 8)
if e2 != nil || l.err {
return &ParseError{"", "bad SMIMEA MatchingType", l}
}
rr.MatchingType = uint8(i)
// So this needs be e2 (i.e. different than e), because...??t
s, e2, c1 := endingToString(c, "bad SMIMEA Certificate", f)
if e2 != nil {
return nil, e2, c1
s, e3 := endingToString(c, "bad SMIMEA Certificate")
if e3 != nil {
return e3
}
rr.Certificate = s
return rr, nil, c1
return nil
}
func setRFC3597(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(RFC3597)
rr.Hdr = h
l := <-c
func (rr *RFC3597) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
if l.token != "\\#" {
return nil, &ParseError{f, "bad RFC3597 Rdata", l}, ""
return &ParseError{"", "bad RFC3597 Rdata", l}
}
<-c // zBlank
l = <-c
rdlength, e := strconv.Atoi(l.token)
c.Next() // zBlank
l, _ = c.Next()
rdlength, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad RFC3597 Rdata ", l}, ""
return &ParseError{"", "bad RFC3597 Rdata ", l}
}
s, e1, c1 := endingToString(c, "bad RFC3597 Rdata", f)
s, e1 := endingToString(c, "bad RFC3597 Rdata")
if e1 != nil {
return nil, e1, c1
return e1
}
if rdlength*2 != len(s) {
return nil, &ParseError{f, "bad RFC3597 Rdata", l}, ""
if int(rdlength)*2 != len(s) {
return &ParseError{"", "bad RFC3597 Rdata", l}
}
rr.Rdata = s
return rr, nil, c1
return nil
}
func setSPF(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(SPF)
rr.Hdr = h
s, e, c1 := endingToTxtSlice(c, "bad SPF Txt", f)
func (rr *SPF) parse(c *zlexer, o string) *ParseError {
s, e := endingToTxtSlice(c, "bad SPF Txt")
if e != nil {
return nil, e, ""
return e
}
rr.Txt = s
return rr, nil, c1
return nil
}
func setAVC(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(AVC)
rr.Hdr = h
s, e, c1 := endingToTxtSlice(c, "bad AVC Txt", f)
func (rr *AVC) parse(c *zlexer, o string) *ParseError {
s, e := endingToTxtSlice(c, "bad AVC Txt")
if e != nil {
return nil, e, ""
return e
}
rr.Txt = s
return rr, nil, c1
return nil
}
func setTXT(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(TXT)
rr.Hdr = h
func (rr *TXT) parse(c *zlexer, o string) *ParseError {
// no zBlank reading here, because all this rdata is TXT
s, e, c1 := endingToTxtSlice(c, "bad TXT Txt", f)
s, e := endingToTxtSlice(c, "bad TXT Txt")
if e != nil {
return nil, e, ""
return e
}
rr.Txt = s
return rr, nil, c1
return nil
}
// identical to setTXT
func setNINFO(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(NINFO)
rr.Hdr = h
s, e, c1 := endingToTxtSlice(c, "bad NINFO ZSData", f)
func (rr *NINFO) parse(c *zlexer, o string) *ParseError {
s, e := endingToTxtSlice(c, "bad NINFO ZSData")
if e != nil {
return nil, e, ""
return e
}
rr.ZSData = s
return rr, nil, c1
return nil
}
func setURI(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(URI)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *URI) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad URI Priority", l}, ""
return &ParseError{"", "bad URI Priority", l}
}
rr.Priority = uint16(i)
<-c // zBlank
l = <-c
i, e = strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad URI Weight", l}, ""
c.Next() // zBlank
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 16)
if e1 != nil || l.err {
return &ParseError{"", "bad URI Weight", l}
}
rr.Weight = uint16(i)
<-c // zBlank
s, err, c1 := endingToTxtSlice(c, "bad URI Target", f)
if err != nil {
return nil, err, ""
c.Next() // zBlank
s, e2 := endingToTxtSlice(c, "bad URI Target")
if e2 != nil {
return e2
}
if len(s) != 1 {
return nil, &ParseError{f, "bad URI Target", l}, ""
return &ParseError{"", "bad URI Target", l}
}
rr.Target = s[0]
return rr, nil, c1
return nil
}
func setDHCID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
func (rr *DHCID) parse(c *zlexer, o string) *ParseError {
// awesome record to parse!
rr := new(DHCID)
rr.Hdr = h
s, e, c1 := endingToString(c, "bad DHCID Digest", f)
s, e := endingToString(c, "bad DHCID Digest")
if e != nil {
return nil, e, c1
return e
}
rr.Digest = s
return rr, nil, c1
return nil
}
func setNID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(NID)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *NID) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad NID Preference", l}, ""
return &ParseError{"", "bad NID Preference", l}
}
rr.Preference = uint16(i)
<-c // zBlank
l = <-c // zString
u, err := stringToNodeID(l)
if err != nil || l.err {
return nil, err, ""
c.Next() // zBlank
l, _ = c.Next() // zString
u, e1 := stringToNodeID(l)
if e1 != nil || l.err {
return e1
}
rr.NodeID = u
return rr, nil, ""
return slurpRemainder(c)
}
func setL32(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(L32)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *L32) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad L32 Preference", l}, ""
return &ParseError{"", "bad L32 Preference", l}
}
rr.Preference = uint16(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Locator32 = net.ParseIP(l.token)
if rr.Locator32 == nil || l.err {
return nil, &ParseError{f, "bad L32 Locator", l}, ""
return &ParseError{"", "bad L32 Locator", l}
}
return rr, nil, ""
return slurpRemainder(c)
}
func setLP(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(LP)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *LP) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad LP Preference", l}, ""
return &ParseError{"", "bad LP Preference", l}
}
rr.Preference = uint16(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Fqdn = l.token
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return nil, &ParseError{f, "bad LP Fqdn", l}, ""
return &ParseError{"", "bad LP Fqdn", l}
}
rr.Fqdn = name
return rr, nil, ""
return slurpRemainder(c)
}
func setL64(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(L64)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *L64) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad L64 Preference", l}, ""
return &ParseError{"", "bad L64 Preference", l}
}
rr.Preference = uint16(i)
<-c // zBlank
l = <-c // zString
u, err := stringToNodeID(l)
if err != nil || l.err {
return nil, err, ""
c.Next() // zBlank
l, _ = c.Next() // zString
u, e1 := stringToNodeID(l)
if e1 != nil || l.err {
return e1
}
rr.Locator64 = u
return rr, nil, ""
return slurpRemainder(c)
}
func setUID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(UID)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *UID) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 32)
if e != nil || l.err {
return nil, &ParseError{f, "bad UID Uid", l}, ""
return &ParseError{"", "bad UID Uid", l}
}
rr.Uid = uint32(i)
return rr, nil, ""
return slurpRemainder(c)
}
func setGID(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(GID)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *GID) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 32)
if e != nil || l.err {
return nil, &ParseError{f, "bad GID Gid", l}, ""
return &ParseError{"", "bad GID Gid", l}
}
rr.Gid = uint32(i)
return rr, nil, ""
return slurpRemainder(c)
}
func setUINFO(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(UINFO)
rr.Hdr = h
s, e, c1 := endingToTxtSlice(c, "bad UINFO Uinfo", f)
func (rr *UINFO) parse(c *zlexer, o string) *ParseError {
s, e := endingToTxtSlice(c, "bad UINFO Uinfo")
if e != nil {
return nil, e, c1
return e
}
if ln := len(s); ln == 0 {
return rr, nil, c1
return nil
}
rr.Uinfo = s[0] // silently discard anything after the first character-string
return rr, nil, c1
return nil
}
func setPX(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(PX)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, ""
}
func (rr *PX) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return nil, &ParseError{f, "bad PX Preference", l}, ""
return &ParseError{"", "bad PX Preference", l}
}
rr.Preference = uint16(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Map822 = l.token
map822, map822Ok := toAbsoluteName(l.token, o)
if l.err || !map822Ok {
return nil, &ParseError{f, "bad PX Map822", l}, ""
return &ParseError{"", "bad PX Map822", l}
}
rr.Map822 = map822
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Mapx400 = l.token
mapx400, mapx400Ok := toAbsoluteName(l.token, o)
if l.err || !mapx400Ok {
return nil, &ParseError{f, "bad PX Mapx400", l}, ""
return &ParseError{"", "bad PX Mapx400", l}
}
rr.Mapx400 = mapx400
return rr, nil, ""
return slurpRemainder(c)
}
func setCAA(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(CAA)
rr.Hdr = h
l := <-c
if l.length == 0 { // dynamic update rr.
return rr, nil, l.comment
}
i, err := strconv.ParseUint(l.token, 10, 8)
if err != nil || l.err {
return nil, &ParseError{f, "bad CAA Flag", l}, ""
func (rr *CAA) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return &ParseError{"", "bad CAA Flag", l}
}
rr.Flag = uint8(i)
<-c // zBlank
l = <-c // zString
c.Next() // zBlank
l, _ = c.Next() // zString
if l.value != zString {
return nil, &ParseError{f, "bad CAA Tag", l}, ""
return &ParseError{"", "bad CAA Tag", l}
}
rr.Tag = l.token
<-c // zBlank
s, e, c1 := endingToTxtSlice(c, "bad CAA Value", f)
if e != nil {
return nil, e, ""
c.Next() // zBlank
s, e1 := endingToTxtSlice(c, "bad CAA Value")
if e1 != nil {
return e1
}
if len(s) != 1 {
return nil, &ParseError{f, "bad CAA Value", l}, ""
return &ParseError{"", "bad CAA Value", l}
}
rr.Value = s[0]
return rr, nil, c1
return nil
}
func setTKEY(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) {
rr := new(TKEY)
rr.Hdr = h
l := <-c
func (rr *TKEY) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
// Algorithm
if l.value != zString {
return nil, &ParseError{f, "bad TKEY algorithm", l}, ""
return &ParseError{"", "bad TKEY algorithm", l}
}
rr.Algorithm = l.token
<-c // zBlank
c.Next() // zBlank
// Get the key length and key values
l = <-c
i, err := strconv.ParseUint(l.token, 10, 8)
if err != nil || l.err {
return nil, &ParseError{f, "bad TKEY key length", l}, ""
l, _ = c.Next()
i, e := strconv.ParseUint(l.token, 10, 8)
if e != nil || l.err {
return &ParseError{"", "bad TKEY key length", l}
}
rr.KeySize = uint16(i)
<-c // zBlank
l = <-c
c.Next() // zBlank
l, _ = c.Next()
if l.value != zString {
return nil, &ParseError{f, "bad TKEY key", l}, ""
return &ParseError{"", "bad TKEY key", l}
}
rr.Key = l.token
<-c // zBlank
c.Next() // zBlank
// Get the otherdata length and string data
l = <-c
i, err = strconv.ParseUint(l.token, 10, 8)
if err != nil || l.err {
return nil, &ParseError{f, "bad TKEY otherdata length", l}, ""
l, _ = c.Next()
i, e1 := strconv.ParseUint(l.token, 10, 8)
if e1 != nil || l.err {
return &ParseError{"", "bad TKEY otherdata length", l}
}
rr.OtherLen = uint16(i)
<-c // zBlank
l = <-c
c.Next() // zBlank
l, _ = c.Next()
if l.value != zString {
return nil, &ParseError{f, "bad TKEY otherday", l}, ""
return &ParseError{"", "bad TKEY otherday", l}
}
rr.OtherData = l.token
return nil
}
func (rr *APL) parse(c *zlexer, o string) *ParseError {
var prefixes []APLPrefix
for {
l, _ := c.Next()
if l.value == zNewline || l.value == zEOF {
break
}
if l.value == zBlank && prefixes != nil {
continue
}
if l.value != zString {
return &ParseError{"", "unexpected APL field", l}
}
// Expected format: [!]afi:address/prefix
colon := strings.IndexByte(l.token, ':')
if colon == -1 {
return &ParseError{"", "missing colon in APL field", l}
}
family, cidr := l.token[:colon], l.token[colon+1:]
var negation bool
if family != "" && family[0] == '!' {
negation = true
family = family[1:]
}
afi, e := strconv.ParseUint(family, 10, 16)
if e != nil {
return &ParseError{"", "failed to parse APL family: " + e.Error(), l}
}
var addrLen int
switch afi {
case 1:
addrLen = net.IPv4len
case 2:
addrLen = net.IPv6len
default:
return &ParseError{"", "unrecognized APL family", l}
}
ip, subnet, e1 := net.ParseCIDR(cidr)
if e1 != nil {
return &ParseError{"", "failed to parse APL address: " + e1.Error(), l}
}
if !ip.Equal(subnet.IP) {
return &ParseError{"", "extra bits in APL address", l}
}
if len(subnet.IP) != addrLen {
return &ParseError{"", "address mismatch with the APL family", l}
}
prefixes = append(prefixes, APLPrefix{
Negation: negation,
Network: *subnet,
})
}
return rr, nil, ""
}
var typeToparserFunc = map[uint16]parserFunc{
TypeAAAA: {setAAAA, false},
TypeAFSDB: {setAFSDB, false},
TypeA: {setA, false},
TypeCAA: {setCAA, true},
TypeCDS: {setCDS, true},
TypeCDNSKEY: {setCDNSKEY, true},
TypeCERT: {setCERT, true},
TypeCNAME: {setCNAME, false},
TypeCSYNC: {setCSYNC, true},
TypeDHCID: {setDHCID, true},
TypeDLV: {setDLV, true},
TypeDNAME: {setDNAME, false},
TypeKEY: {setKEY, true},
TypeDNSKEY: {setDNSKEY, true},
TypeDS: {setDS, true},
TypeEID: {setEID, true},
TypeEUI48: {setEUI48, false},
TypeEUI64: {setEUI64, false},
TypeGID: {setGID, false},
TypeGPOS: {setGPOS, false},
TypeHINFO: {setHINFO, true},
TypeHIP: {setHIP, true},
TypeKX: {setKX, false},
TypeL32: {setL32, false},
TypeL64: {setL64, false},
TypeLOC: {setLOC, true},
TypeLP: {setLP, false},
TypeMB: {setMB, false},
TypeMD: {setMD, false},
TypeMF: {setMF, false},
TypeMG: {setMG, false},
TypeMINFO: {setMINFO, false},
TypeMR: {setMR, false},
TypeMX: {setMX, false},
TypeNAPTR: {setNAPTR, false},
TypeNID: {setNID, false},
TypeNIMLOC: {setNIMLOC, true},
TypeNINFO: {setNINFO, true},
TypeNSAPPTR: {setNSAPPTR, false},
TypeNSEC3PARAM: {setNSEC3PARAM, false},
TypeNSEC3: {setNSEC3, true},
TypeNSEC: {setNSEC, true},
TypeNS: {setNS, false},
TypeOPENPGPKEY: {setOPENPGPKEY, true},
TypePTR: {setPTR, false},
TypePX: {setPX, false},
TypeSIG: {setSIG, true},
TypeRKEY: {setRKEY, true},
TypeRP: {setRP, false},
TypeRRSIG: {setRRSIG, true},
TypeRT: {setRT, false},
TypeSMIMEA: {setSMIMEA, true},
TypeSOA: {setSOA, false},
TypeSPF: {setSPF, true},
TypeAVC: {setAVC, true},
TypeSRV: {setSRV, false},
TypeSSHFP: {setSSHFP, true},
TypeTALINK: {setTALINK, false},
TypeTA: {setTA, true},
TypeTLSA: {setTLSA, true},
TypeTXT: {setTXT, true},
TypeUID: {setUID, false},
TypeUINFO: {setUINFO, true},
TypeURI: {setURI, true},
TypeX25: {setX25, false},
TypeTKEY: {setTKEY, true},
rr.Prefixes = prefixes
return nil
}
package dns
// Implement a simple scanner, return a byte stream from an io reader.
import (
"bufio"
"context"
"io"
"text/scanner"
)
type scan struct {
src *bufio.Reader
position scanner.Position
eof bool // Have we just seen a eof
ctx context.Context
}
func scanInit(r io.Reader) (*scan, context.CancelFunc) {
s := new(scan)
s.src = bufio.NewReader(r)
s.position.Line = 1
ctx, cancel := context.WithCancel(context.Background())
s.ctx = ctx
return s, cancel
}
// tokenText returns the next byte from the input
func (s *scan) tokenText() (byte, error) {
c, err := s.src.ReadByte()
if err != nil {
return c, err
}
select {
case <-s.ctx.Done():
return c, context.Canceled
default:
break
}
// delay the newline handling until the next token is delivered,
// fixes off-by-one errors when reporting a parse error.
if s.eof == true {
s.position.Line++
s.position.Column = 0
s.eof = false
}
if c == '\n' {
s.eof = true
return c, nil
}
s.position.Column++
return c, nil
}
package dns
import (
"sync"
)
// ServeMux is an DNS request multiplexer. It matches the zone name of
// each incoming request against a list of registered patterns add calls
// the handler for the pattern that most closely matches the zone name.
//
// ServeMux is DNSSEC aware, meaning that queries for the DS record are
// redirected to the parent zone (if that is also registered), otherwise
// the child gets the query.
//
// ServeMux is also safe for concurrent access from multiple goroutines.
//
// The zero ServeMux is empty and ready for use.
type ServeMux struct {
z map[string]Handler
m sync.RWMutex
}
// NewServeMux allocates and returns a new ServeMux.
func NewServeMux() *ServeMux {
return new(ServeMux)
}
// DefaultServeMux is the default ServeMux used by Serve.
var DefaultServeMux = NewServeMux()
func (mux *ServeMux) match(q string, t uint16) Handler {
mux.m.RLock()
defer mux.m.RUnlock()
if mux.z == nil {
return nil
}
q = CanonicalName(q)
var handler Handler
for off, end := 0, false; !end; off, end = NextLabel(q, off) {
if h, ok := mux.z[q[off:]]; ok {
if t != TypeDS {
return h
}
// Continue for DS to see if we have a parent too, if so delegate to the parent
handler = h
}
}
// Wildcard match, if we have found nothing try the root zone as a last resort.
if h, ok := mux.z["."]; ok {
return h
}
return handler
}
// Handle adds a handler to the ServeMux for pattern.
func (mux *ServeMux) Handle(pattern string, handler Handler) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
mux.m.Lock()
if mux.z == nil {
mux.z = make(map[string]Handler)
}
mux.z[CanonicalName(pattern)] = handler
mux.m.Unlock()
}
// HandleFunc adds a handler function to the ServeMux for pattern.
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
mux.Handle(pattern, HandlerFunc(handler))
}
// HandleRemove deregisters the handler specific for pattern from the ServeMux.
func (mux *ServeMux) HandleRemove(pattern string) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
mux.m.Lock()
delete(mux.z, CanonicalName(pattern))
mux.m.Unlock()
}
// ServeDNS dispatches the request to the handler whose pattern most
// closely matches the request message.
//
// ServeDNS is DNSSEC aware, meaning that queries for the DS record
// are redirected to the parent zone (if that is also registered),
// otherwise the child gets the query.
//
// If no handler is found, or there is no question, a standard REFUSED
// message is returned
func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) {
var h Handler
if len(req.Question) >= 1 { // allow more than one question
h = mux.match(req.Question[0].Name, req.Question[0].Qtype)
}
if h != nil {
h.ServeDNS(w, req)
} else {
handleRefused(w, req)
}
}
// Handle registers the handler with the given pattern
// in the DefaultServeMux. The documentation for
// ServeMux explains how patterns are matched.
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
// HandleRemove deregisters the handle with the given pattern
// in the DefaultServeMux.
func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
// HandleFunc registers the handler function with the given pattern
// in the DefaultServeMux.
func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
DefaultServeMux.HandleFunc(pattern, handler)
}
......@@ -3,37 +3,40 @@
package dns
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"errors"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
)
// Default maximum number of TCP queries before we close the socket.
const maxTCPQueries = 128
// The maximum number of idle workers.
//
// This controls the maximum number of workers that are allowed to stay
// idle waiting for incoming requests before being torn down.
//
// If this limit is reached, the server will just keep spawning new
// workers (goroutines) for each incoming request. In this case, each
// worker will only be used for a single request.
const maxIdleWorkersCount = 10000
// The maximum length of time a worker may idle for before being destroyed.
const idleWorkerTimeout = 10 * time.Second
// aLongTimeAgo is a non-zero time, far in the past, used for
// immediate cancelation of network operations.
var aLongTimeAgo = time.Unix(1, 0)
// Handler is implemented by any value that implements ServeDNS.
type Handler interface {
ServeDNS(w ResponseWriter, r *Msg)
}
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as DNS handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler object that calls f.
type HandlerFunc func(ResponseWriter, *Msg)
// ServeDNS calls f(w, r).
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
f(w, r)
}
// A ResponseWriter interface is used by an DNS handler to
// construct an DNS response.
type ResponseWriter interface {
......@@ -56,49 +59,35 @@ type ResponseWriter interface {
Hijack()
}
// A ConnectionStater interface is used by a DNS Handler to access TLS connection state
// when available.
type ConnectionStater interface {
ConnectionState() *tls.ConnectionState
}
type response struct {
msg []byte
closed bool // connection has been closed
hijacked bool // connection has been hijacked by handler
tsigStatus error
tsigTimersOnly bool
tsigStatus error
tsigRequestMAC string
tsigSecret map[string]string // the tsig secrets
udp *net.UDPConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
writer Writer // writer to output the raw DNS bits
}
// ServeMux is an DNS request multiplexer. It matches the
// zone name of each incoming request against a list of
// registered patterns add calls the handler for the pattern
// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
// that queries for the DS record are redirected to the parent zone (if that
// is also registered), otherwise the child gets the query.
// ServeMux is also safe for concurrent access from multiple goroutines.
type ServeMux struct {
z map[string]Handler
m *sync.RWMutex
tsigProvider TsigProvider
udp net.PacketConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
pcSession net.Addr // address to use when writing to a generic net.PacketConn
writer Writer // writer to output the raw DNS bits
}
// NewServeMux allocates and returns a new ServeMux.
func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
// DefaultServeMux is the default ServeMux used by Serve.
var DefaultServeMux = NewServeMux()
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as DNS handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler object that calls f.
type HandlerFunc func(ResponseWriter, *Msg)
// ServeDNS calls f(w, r).
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
f(w, r)
// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
func handleRefused(w ResponseWriter, r *Msg) {
m := new(Msg)
m.SetRcode(r, RcodeRefused)
w.WriteMsg(m)
}
// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
// Deprecated: This function is going away.
func HandleFailed(w ResponseWriter, r *Msg) {
m := new(Msg)
m.SetRcode(r, RcodeServerFailure)
......@@ -106,8 +95,6 @@ func HandleFailed(w ResponseWriter, r *Msg) {
w.WriteMsg(m)
}
func failedHandler() Handler { return HandlerFunc(HandleFailed) }
// ListenAndServe Starts a server on address and network specified Invoke handler
// for incoming queries.
func ListenAndServe(addr string, network string, handler Handler) error {
......@@ -146,99 +133,6 @@ func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
return server.ActivateAndServe()
}
func (mux *ServeMux) match(q string, t uint16) Handler {
mux.m.RLock()
defer mux.m.RUnlock()
var handler Handler
b := make([]byte, len(q)) // worst case, one label of length q
off := 0
end := false
for {
l := len(q[off:])
for i := 0; i < l; i++ {
b[i] = q[off+i]
if b[i] >= 'A' && b[i] <= 'Z' {
b[i] |= 'a' - 'A'
}
}
if h, ok := mux.z[string(b[:l])]; ok { // causes garbage, might want to change the map key
if t != TypeDS {
return h
}
// Continue for DS to see if we have a parent too, if so delegeate to the parent
handler = h
}
off, end = NextLabel(q, off)
if end {
break
}
}
// Wildcard match, if we have found nothing try the root zone as a last resort.
if h, ok := mux.z["."]; ok {
return h
}
return handler
}
// Handle adds a handler to the ServeMux for pattern.
func (mux *ServeMux) Handle(pattern string, handler Handler) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
mux.m.Lock()
mux.z[Fqdn(pattern)] = handler
mux.m.Unlock()
}
// HandleFunc adds a handler function to the ServeMux for pattern.
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
mux.Handle(pattern, HandlerFunc(handler))
}
// HandleRemove deregistrars the handler specific for pattern from the ServeMux.
func (mux *ServeMux) HandleRemove(pattern string) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
mux.m.Lock()
delete(mux.z, Fqdn(pattern))
mux.m.Unlock()
}
// ServeDNS dispatches the request to the handler whose
// pattern most closely matches the request message. If DefaultServeMux
// is used the correct thing for DS queries is done: a possible parent
// is sought.
// If no handler is found a standard SERVFAIL message is returned
// If the request message does not have exactly one question in the
// question section a SERVFAIL is returned, unlesss Unsafe is true.
func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
var h Handler
if len(request.Question) < 1 { // allow more than one question
h = failedHandler()
} else {
if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
h = failedHandler()
}
}
h.ServeDNS(w, request)
}
// Handle registers the handler with the given pattern
// in the DefaultServeMux. The documentation for
// ServeMux explains how patterns are matched.
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
// HandleRemove deregisters the handle with the given pattern
// in the DefaultServeMux.
func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
// HandleFunc registers the handler function with the given pattern
// in the DefaultServeMux.
func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
DefaultServeMux.HandleFunc(pattern, handler)
}
// Writer writes raw DNS messages; each call to Write should send an entire message.
type Writer interface {
io.Writer
......@@ -254,22 +148,40 @@ type Reader interface {
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}
// defaultReader is an adapter for the Server struct that implements the Reader interface
// using the readTCP and readUDP func of the embedded Server.
// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
type PacketConnReader interface {
Reader
// ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
// alter connection properties, for example the read-deadline.
ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
}
// defaultReader is an adapter for the Server struct that implements the Reader and
// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
// of the embedded Server.
type defaultReader struct {
*Server
}
func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
var _ PacketConnReader = defaultReader{}
func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
return dr.readTCP(conn, timeout)
}
func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
return dr.readUDP(conn, timeout)
}
func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
return dr.readPacketConn(conn, timeout)
}
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
// Implementations should never return a nil Reader.
// Readers should also implement the optional PacketConnReader interface.
// PacketConnReader is required to use a generic net.PacketConn.
type DecorateReader func(Reader) Reader
// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
......@@ -299,11 +211,10 @@ type Server struct {
WriteTimeout time.Duration
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
IdleTimeout func() time.Duration
// An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
TsigProvider TsigProvider
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
TsigSecret map[string]string
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
// the handler. It will specifically not check if the query has the QR bit not set.
Unsafe bool
// If NotifyStartedFunc is set it is called once the server has started listening.
NotifyStartedFunc func()
// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
......@@ -312,14 +223,31 @@ type Server struct {
DecorateWriter DecorateWriter
// Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1).
MaxTCPQueries int
// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
// It is only supported on go1.11+ and when using ListenAndServe.
ReusePort bool
// AcceptMsgFunc will check the incoming message and will reject it early in the process.
// By default DefaultMsgAcceptFunc will be used.
MsgAcceptFunc MsgAcceptFunc
// UDP packet or TCP connection queue
queue chan *response
// Workers count
workersCount int32
// Shutdown handling
lock sync.RWMutex
started bool
lock sync.RWMutex
started bool
shutdown chan struct{}
conns map[net.Conn]struct{}
// A pool for UDP message buffers.
udpPool sync.Pool
}
func (srv *Server) tsigProvider() TsigProvider {
if srv.TsigProvider != nil {
return srv.TsigProvider
}
if srv.TsigSecret != nil {
return tsigSecretProvider(srv.TsigSecret)
}
return nil
}
func (srv *Server) isStarted() bool {
......@@ -329,49 +257,27 @@ func (srv *Server) isStarted() bool {
return started
}
func (srv *Server) worker(w *response) {
srv.serve(w)
for {
count := atomic.LoadInt32(&srv.workersCount)
if count > maxIdleWorkersCount {
return
}
if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) {
break
}
func makeUDPBuffer(size int) func() interface{} {
return func() interface{} {
return make([]byte, size)
}
}
defer atomic.AddInt32(&srv.workersCount, -1)
func (srv *Server) init() {
srv.shutdown = make(chan struct{})
srv.conns = make(map[net.Conn]struct{})
inUse := false
timeout := time.NewTimer(idleWorkerTimeout)
defer timeout.Stop()
LOOP:
for {
select {
case w, ok := <-srv.queue:
if !ok {
break LOOP
}
inUse = true
srv.serve(w)
case <-timeout.C:
if !inUse {
break LOOP
}
inUse = false
timeout.Reset(idleWorkerTimeout)
}
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
}
func (srv *Server) spawnWorker(w *response) {
select {
case srv.queue <- w:
default:
go srv.worker(w)
if srv.MsgAcceptFunc == nil {
srv.MsgAcceptFunc = DefaultMsgAcceptFunc
}
if srv.Handler == nil {
srv.Handler = DefaultServeMux
}
srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
}
func unlockOnce(l sync.Locker) func() {
......@@ -393,18 +299,12 @@ func (srv *Server) ListenAndServe() error {
if addr == "" {
addr = ":domain"
}
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
srv.queue = make(chan *response)
defer close(srv.queue)
srv.init()
switch srv.Net {
case "tcp", "tcp4", "tcp6":
a, err := net.ResolveTCPAddr(srv.Net, addr)
if err != nil {
return err
}
l, err := net.ListenTCP(srv.Net, a)
l, err := listenTCP(srv.Net, addr, srv.ReusePort)
if err != nil {
return err
}
......@@ -413,37 +313,33 @@ func (srv *Server) ListenAndServe() error {
unlock()
return srv.serveTCP(l)
case "tcp-tls", "tcp4-tls", "tcp6-tls":
network := "tcp"
if srv.Net == "tcp4-tls" {
network = "tcp4"
} else if srv.Net == "tcp6-tls" {
network = "tcp6"
if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
return errors.New("dns: neither Certificates nor GetCertificate set in Config")
}
l, err := tls.Listen(network, addr, srv.TLSConfig)
network := strings.TrimSuffix(srv.Net, "-tls")
l, err := listenTCP(network, addr, srv.ReusePort)
if err != nil {
return err
}
l = tls.NewListener(l, srv.TLSConfig)
srv.Listener = l
srv.started = true
unlock()
return srv.serveTCP(l)
case "udp", "udp4", "udp6":
a, err := net.ResolveUDPAddr(srv.Net, addr)
l, err := listenUDP(srv.Net, addr, srv.ReusePort)
if err != nil {
return err
}
l, err := net.ListenUDP(srv.Net, a)
if err != nil {
return err
}
if e := setUDPSocketOptions(l); e != nil {
u := l.(*net.UDPConn)
if e := setUDPSocketOptions(u); e != nil {
u.Close()
return e
}
srv.PacketConn = l
srv.started = true
unlock()
return srv.serveUDP(l)
return srv.serveUDP(u)
}
return &Error{err: "bad network"}
}
......@@ -459,29 +355,24 @@ func (srv *Server) ActivateAndServe() error {
return &Error{err: "server already started"}
}
pConn := srv.PacketConn
l := srv.Listener
srv.queue = make(chan *response)
defer close(srv.queue)
if pConn != nil {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
srv.init()
if srv.PacketConn != nil {
// Check PacketConn interface's type is valid and value
// is not nil
if t, ok := pConn.(*net.UDPConn); ok && t != nil {
if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
if e := setUDPSocketOptions(t); e != nil {
return e
}
srv.started = true
unlock()
return srv.serveUDP(t)
}
srv.started = true
unlock()
return srv.serveUDP(srv.PacketConn)
}
if l != nil {
if srv.Listener != nil {
srv.started = true
unlock()
return srv.serveTCP(l)
return srv.serveTCP(srv.Listener)
}
return &Error{err: "bad listeners"}
}
......@@ -489,30 +380,63 @@ func (srv *Server) ActivateAndServe() error {
// Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and
// ActivateAndServe will return.
func (srv *Server) Shutdown() error {
return srv.ShutdownContext(context.Background())
}
// ShutdownContext shuts down a server. After a call to ShutdownContext,
// ListenAndServe and ActivateAndServe will return.
//
// A context.Context may be passed to limit how long to wait for connections
// to terminate.
func (srv *Server) ShutdownContext(ctx context.Context) error {
srv.lock.Lock()
if !srv.started {
srv.lock.Unlock()
return &Error{err: "server not started"}
}
srv.started = false
srv.lock.Unlock()
if srv.PacketConn != nil {
srv.PacketConn.Close()
srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
}
if srv.Listener != nil {
srv.Listener.Close()
}
return nil
for rw := range srv.conns {
rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
}
srv.lock.Unlock()
if testShutdownNotify != nil {
testShutdownNotify.Broadcast()
}
var ctxErr error
select {
case <-srv.shutdown:
case <-ctx.Done():
ctxErr = ctx.Err()
}
if srv.PacketConn != nil {
srv.PacketConn.Close()
}
return ctxErr
}
var testShutdownNotify *sync.Cond
// getReadTimeout is a helper func to use system timeout if server did not intend to change it.
func (srv *Server) getReadTimeout() time.Duration {
rtimeout := dnsTimeout
if srv.ReadTimeout != 0 {
rtimeout = srv.ReadTimeout
return srv.ReadTimeout
}
return rtimeout
return dnsTimeout
}
// serveTCP starts a TCP listener for the server.
......@@ -523,78 +447,109 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.NotifyStartedFunc()
}
for {
var wg sync.WaitGroup
defer func() {
wg.Wait()
close(srv.shutdown)
}()
for srv.isStarted() {
rw, err := l.Accept()
if !srv.isStarted() {
return nil
}
if err != nil {
if !srv.isStarted() {
return nil
}
if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
continue
}
return err
}
srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw})
srv.lock.Lock()
// Track the connection to allow unblocking reads on shutdown.
srv.conns[rw] = struct{}{}
srv.lock.Unlock()
wg.Add(1)
go srv.serveTCPConn(&wg, rw)
}
return nil
}
// serveUDP starts a UDP listener for the server.
func (srv *Server) serveUDP(l *net.UDPConn) error {
func (srv *Server) serveUDP(l net.PacketConn) error {
defer l.Close()
reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
lUDP, isUDP := l.(*net.UDPConn)
readerPC, canPacketConn := reader.(PacketConnReader)
if !isUDP && !canPacketConn {
return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
}
if srv.NotifyStartedFunc != nil {
srv.NotifyStartedFunc()
}
reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
var wg sync.WaitGroup
defer func() {
wg.Wait()
close(srv.shutdown)
}()
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
m, s, err := reader.ReadUDP(l, rtimeout)
if !srv.isStarted() {
return nil
for srv.isStarted() {
var (
m []byte
sPC net.Addr
sUDP *SessionUDP
err error
)
if isUDP {
m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
} else {
m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
}
if err != nil {
if !srv.isStarted() {
return nil
}
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
continue
}
return err
}
if len(m) < headerSize {
if cap(m) == srv.UDPSize {
srv.udpPool.Put(m[:srv.UDPSize])
}
continue
}
srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s})
wg.Add(1)
go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
}
return nil
}
func (srv *Server) serve(w *response) {
// Serve a new TCP connection.
func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
w := &response{tsigProvider: srv.tsigProvider(), tcp: rw}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}
if w.udp != nil {
// serve UDP
srv.serveDNS(w)
return
}
reader := Reader(&defaultReader{srv})
reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
defer func() {
if !w.hijacked {
w.Close()
}
}()
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
......@@ -607,15 +562,14 @@ func (srv *Server) serve(w *response) {
limit = maxTCPQueries
}
for q := 0; q < limit || limit == -1; q++ {
var err error
w.msg, err = reader.ReadTCP(w.tcp, timeout)
for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
m, err := reader.ReadTCP(w.tcp, timeout)
if err != nil {
// TODO(tmthrgd): handle error
break
}
srv.serveDNS(w)
if w.tcp == nil {
srv.serveDNS(m, w)
if w.closed {
break // Close() was called
}
if w.hijacked {
......@@ -625,94 +579,156 @@ func (srv *Server) serve(w *response) {
// idle timeout.
timeout = idleTimeout
}
if !w.hijacked {
w.Close()
}
srv.lock.Lock()
delete(srv.conns, w.tcp)
srv.lock.Unlock()
wg.Done()
}
func (srv *Server) serveDNS(w *response) {
req := new(Msg)
err := req.Unpack(w.msg)
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
w.WriteMsg(x)
// Serve a new UDP request.
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}
srv.serveDNS(m, w)
wg.Done()
}
func (srv *Server) serveDNS(m []byte, w *response) {
dh, off, err := unpackMsgHdr(m, 0)
if err != nil {
// Let client hang, they are sending crap; any reply can be used to amplify.
return
}
if !srv.Unsafe && req.Response {
req := new(Msg)
req.setHdr(dh)
switch action := srv.MsgAcceptFunc(dh); action {
case MsgAccept:
if req.unpack(dh, m, off) == nil {
break
}
fallthrough
case MsgReject, MsgRejectNotImplemented:
opcode := req.Opcode
req.SetRcodeFormatError(req)
req.Zero = false
if action == MsgRejectNotImplemented {
req.Opcode = opcode
req.Rcode = RcodeNotImplemented
}
// Are we allowed to delete any OPT records here?
req.Ns, req.Answer, req.Extra = nil, nil, nil
w.WriteMsg(req)
fallthrough
case MsgIgnore:
if w.udp != nil && cap(m) == srv.UDPSize {
srv.udpPool.Put(m[:srv.UDPSize])
}
return
}
w.tsigStatus = nil
if w.tsigSecret != nil {
if w.tsigProvider != nil {
if t := req.IsTsig(); t != nil {
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
w.tsigStatus = TsigVerify(w.msg, secret, "", false)
} else {
w.tsigStatus = ErrSecret
}
w.tsigStatus = TsigVerifyWithProvider(m, w.tsigProvider, "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
w.tsigRequestMAC = t.MAC
}
}
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
if w.udp != nil && cap(m) == srv.UDPSize {
srv.udpPool.Put(m[:srv.UDPSize])
}
handler.ServeDNS(w, req) // Writes back to the client
srv.Handler.ServeDNS(w, req) // Writes back to the client
}
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
l := make([]byte, 2)
n, err := conn.Read(l)
if err != nil || n != 2 {
if err != nil {
return nil, err
}
return nil, ErrShortRead
}
length := binary.BigEndian.Uint16(l)
if length == 0 {
return nil, ErrShortRead
// If we race with ShutdownContext, the read deadline may
// have been set in the distant past to unblock the read
// below. We must not override it, otherwise we may block
// ShutdownContext.
srv.lock.RLock()
if srv.started {
conn.SetReadDeadline(time.Now().Add(timeout))
}
m := make([]byte, int(length))
n, err = conn.Read(m[:int(length)])
if err != nil || n == 0 {
if err != nil {
return nil, err
}
return nil, ErrShortRead
srv.lock.RUnlock()
var length uint16
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
return nil, err
}
i := n
for i < int(length) {
j, err := conn.Read(m[i:int(length)])
if err != nil {
return nil, err
}
i += j
m := make([]byte, length)
if _, err := io.ReadFull(conn, m); err != nil {
return nil, err
}
n = i
m = m[:n]
return m, nil
}
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
m := make([]byte, srv.UDPSize)
srv.lock.RLock()
if srv.started {
// See the comment in readTCP above.
conn.SetReadDeadline(time.Now().Add(timeout))
}
srv.lock.RUnlock()
m := srv.udpPool.Get().([]byte)
n, s, err := ReadFromSessionUDP(conn, m)
if err != nil {
srv.udpPool.Put(m)
return nil, nil, err
}
m = m[:n]
return m, s, nil
}
func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
srv.lock.RLock()
if srv.started {
// See the comment in readTCP above.
conn.SetReadDeadline(time.Now().Add(timeout))
}
srv.lock.RUnlock()
m := srv.udpPool.Get().([]byte)
n, addr, err := conn.ReadFrom(m)
if err != nil {
srv.udpPool.Put(m)
return nil, nil, err
}
m = m[:n]
return m, addr, nil
}
// WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) {
if w.closed {
return &Error{err: "WriteMsg called after Close"}
}
var data []byte
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check)
if t := m.IsTsig(); t != nil {
data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
data, w.tsigRequestMAC, err = TsigGenerateWithProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
if err != nil {
return err
}
......@@ -730,42 +746,54 @@ func (w *response) WriteMsg(m *Msg) (err error) {
// Write implements the ResponseWriter.Write method.
func (w *response) Write(m []byte) (int, error) {
if w.closed {
return 0, &Error{err: "Write called after Close"}
}
switch {
case w.udp != nil:
n, err := WriteToSessionUDP(w.udp, m, w.udpSession)
return n, err
case w.tcp != nil:
lm := len(m)
if lm < 2 {
return 0, io.ErrShortBuffer
if u, ok := w.udp.(*net.UDPConn); ok {
return WriteToSessionUDP(u, m, w.udpSession)
}
if lm > MaxMsgSize {
return w.udp.WriteTo(m, w.pcSession)
case w.tcp != nil:
if len(m) > MaxMsgSize {
return 0, &Error{err: "message too large"}
}
l := make([]byte, 2, 2+lm)
binary.BigEndian.PutUint16(l, uint16(lm))
m = append(l, m...)
n, err := io.Copy(w.tcp, bytes.NewReader(m))
return int(n), err
msg := make([]byte, 2+len(m))
binary.BigEndian.PutUint16(msg, uint16(len(m)))
copy(msg[2:], m)
return w.tcp.Write(msg)
default:
panic("dns: internal error: udp and tcp both nil")
}
panic("not reached")
}
// LocalAddr implements the ResponseWriter.LocalAddr method.
func (w *response) LocalAddr() net.Addr {
if w.tcp != nil {
switch {
case w.udp != nil:
return w.udp.LocalAddr()
case w.tcp != nil:
return w.tcp.LocalAddr()
default:
panic("dns: internal error: udp and tcp both nil")
}
return w.udp.LocalAddr()
}
// RemoteAddr implements the ResponseWriter.RemoteAddr method.
func (w *response) RemoteAddr() net.Addr {
if w.tcp != nil {
switch {
case w.udpSession != nil:
return w.udpSession.RemoteAddr()
case w.pcSession != nil:
return w.pcSession
case w.tcp != nil:
return w.tcp.RemoteAddr()
default:
panic("dns: internal error: udpSession, pcSession and tcp are all nil")
}
return w.udpSession.RemoteAddr()
}
// TsigStatus implements the ResponseWriter.TsigStatus method.
......@@ -779,11 +807,30 @@ func (w *response) Hijack() { w.hijacked = true }
// Close implements the ResponseWriter.Close method
func (w *response) Close() error {
// Can't close the udp conn, as that is actually the listener.
if w.tcp != nil {
e := w.tcp.Close()
w.tcp = nil
return e
if w.closed {
return &Error{err: "connection already closed"}
}
w.closed = true
switch {
case w.udp != nil:
// Can't close the udp conn, as that is actually the listener.
return nil
case w.tcp != nil:
return w.tcp.Close()
default:
panic("dns: internal error: udp and tcp both nil")
}
}
// ConnectionState() implements the ConnectionStater.ConnectionState() interface.
func (w *response) ConnectionState() *tls.ConnectionState {
type tlsConnectionStater interface {
ConnectionState() tls.ConnectionState
}
if v, ok := w.tcp.(tlsConnectionStater); ok {
t := v.ConnectionState()
return &t
}
return nil
}
......@@ -2,8 +2,8 @@ package dns
import (
"crypto"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"encoding/binary"
"math/big"
......@@ -18,18 +18,14 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) {
if k == nil {
return nil, ErrPrivKey
}
if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
if rr.KeyTag == 0 || rr.SignerName == "" || rr.Algorithm == 0 {
return nil, ErrKey
}
rr.Header().Rrtype = TypeSIG
rr.Header().Class = ClassANY
rr.Header().Ttl = 0
rr.Header().Name = "."
rr.OrigTtl = 0
rr.TypeCovered = 0
rr.Labels = 0
buf := make([]byte, m.Len()+rr.len())
rr.Hdr = RR_Header{Name: ".", Rrtype: TypeSIG, Class: ClassANY, Ttl: 0}
rr.OrigTtl, rr.TypeCovered, rr.Labels = 0, 0, 0
buf := make([]byte, m.Len()+Len(rr))
mbuf, err := m.PackBuffer(buf)
if err != nil {
return nil, err
......@@ -43,18 +39,17 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) {
}
buf = buf[:off:cap(buf)]
hash, ok := AlgorithmToHash[rr.Algorithm]
if !ok {
return nil, ErrAlg
h, cryptohash, err := hashFromAlgorithm(rr.Algorithm)
if err != nil {
return nil, err
}
hasher := hash.New()
// Write SIG rdata
hasher.Write(buf[len(mbuf)+1+2+2+4+2:])
h.Write(buf[len(mbuf)+1+2+2+4+2:])
// Write message
hasher.Write(buf[:len(mbuf)])
h.Write(buf[:len(mbuf)])
signature, err := sign(k, hasher.Sum(nil), hash, rr.Algorithm)
signature, err := sign(k, h.Sum(nil), cryptohash, rr.Algorithm)
if err != nil {
return nil, err
}
......@@ -83,32 +78,21 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
if k == nil {
return ErrKey
}
if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
if rr.KeyTag == 0 || rr.SignerName == "" || rr.Algorithm == 0 {
return ErrKey
}
var hash crypto.Hash
switch rr.Algorithm {
case DSA, RSASHA1:
hash = crypto.SHA1
case RSASHA256, ECDSAP256SHA256:
hash = crypto.SHA256
case ECDSAP384SHA384:
hash = crypto.SHA384
case RSASHA512:
hash = crypto.SHA512
default:
return ErrAlg
}
hasher := hash.New()
h, cryptohash, err := hashFromAlgorithm(rr.Algorithm)
if err != nil {
return err
}
buflen := len(buf)
qdc := binary.BigEndian.Uint16(buf[4:])
anc := binary.BigEndian.Uint16(buf[6:])
auc := binary.BigEndian.Uint16(buf[8:])
adc := binary.BigEndian.Uint16(buf[10:])
offset := 12
var err error
offset := headerSize
for i := uint16(0); i < qdc && offset < buflen; i++ {
_, offset, err = UnpackDomainName(buf, offset)
if err != nil {
......@@ -127,8 +111,7 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
if offset+1 >= buflen {
continue
}
var rdlen uint16
rdlen = binary.BigEndian.Uint16(buf[offset:])
rdlen := binary.BigEndian.Uint16(buf[offset:])
offset += 2
offset += int(rdlen)
}
......@@ -168,51 +151,44 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
}
// If key has come from the DNS name compression might
// have mangled the case of the name
if strings.ToLower(signername) != strings.ToLower(k.Header().Name) {
if !strings.EqualFold(signername, k.Header().Name) {
return &Error{err: "signer name doesn't match key name"}
}
sigend := offset
hasher.Write(buf[sigstart:sigend])
hasher.Write(buf[:10])
hasher.Write([]byte{
h.Write(buf[sigstart:sigend])
h.Write(buf[:10])
h.Write([]byte{
byte((adc - 1) << 8),
byte(adc - 1),
})
hasher.Write(buf[12:bodyend])
h.Write(buf[12:bodyend])
hashed := hasher.Sum(nil)
hashed := h.Sum(nil)
sig := buf[sigend:]
switch k.Algorithm {
case DSA:
pk := k.publicKeyDSA()
sig = sig[1:]
r := big.NewInt(0)
r.SetBytes(sig[:len(sig)/2])
s := big.NewInt(0)
s.SetBytes(sig[len(sig)/2:])
if pk != nil {
if dsa.Verify(pk, hashed, r, s) {
return nil
}
return ErrSig
}
case RSASHA1, RSASHA256, RSASHA512:
pk := k.publicKeyRSA()
if pk != nil {
return rsa.VerifyPKCS1v15(pk, hash, hashed, sig)
return rsa.VerifyPKCS1v15(pk, cryptohash, hashed, sig)
}
case ECDSAP256SHA256, ECDSAP384SHA384:
pk := k.publicKeyECDSA()
r := big.NewInt(0)
r.SetBytes(sig[:len(sig)/2])
s := big.NewInt(0)
s.SetBytes(sig[len(sig)/2:])
r := new(big.Int).SetBytes(sig[:len(sig)/2])
s := new(big.Int).SetBytes(sig[len(sig)/2:])
if pk != nil {
if ecdsa.Verify(pk, hashed, r, s) {
return nil
}
return ErrSig
}
case ED25519:
pk := k.publicKeyED25519()
if pk != nil {
if ed25519.Verify(pk, hashed, sig) {
return nil
}
return ErrSig
}
}
return ErrKeyAlg
}
......@@ -23,6 +23,8 @@ type call struct {
type singleflight struct {
sync.Mutex // protects m
m map[string]*call // lazily initialized
dontDeleteForTesting bool // this is only to be used by TestConcurrentExchanges
}
// Do executes and returns the results of the given function, making
......@@ -49,9 +51,11 @@ func (g *singleflight) Do(key string, fn func() (*Msg, time.Duration, error)) (v
c.val, c.rtt, c.err = fn()
c.wg.Done()
g.Lock()
delete(g.m, key)
g.Unlock()
if !g.dontDeleteForTesting {
g.Lock()
delete(g.m, key)
g.Unlock()
}
return c.val, c.rtt, c.err, c.dups > 0
}
......@@ -14,10 +14,7 @@ func (r *SMIMEA) Sign(usage, selector, matchingType int, cert *x509.Certificate)
r.MatchingType = uint8(matchingType)
r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert)
if err != nil {
return err
}
return nil
return err
}
// Verify verifies a SMIMEA record against an SSL certificate. If it is OK
......
package dns
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"sort"
"strconv"
"strings"
)
// SVCBKey is the type of the keys used in the SVCB RR.
type SVCBKey uint16
// Keys defined in draft-ietf-dnsop-svcb-https-08 Section 14.3.2.
const (
SVCB_MANDATORY SVCBKey = iota
SVCB_ALPN
SVCB_NO_DEFAULT_ALPN
SVCB_PORT
SVCB_IPV4HINT
SVCB_ECHCONFIG
SVCB_IPV6HINT
SVCB_DOHPATH // draft-ietf-add-svcb-dns-02 Section 9
svcb_RESERVED SVCBKey = 65535
)
var svcbKeyToStringMap = map[SVCBKey]string{
SVCB_MANDATORY: "mandatory",
SVCB_ALPN: "alpn",
SVCB_NO_DEFAULT_ALPN: "no-default-alpn",
SVCB_PORT: "port",
SVCB_IPV4HINT: "ipv4hint",
SVCB_ECHCONFIG: "ech",
SVCB_IPV6HINT: "ipv6hint",
SVCB_DOHPATH: "dohpath",
}
var svcbStringToKeyMap = reverseSVCBKeyMap(svcbKeyToStringMap)
func reverseSVCBKeyMap(m map[SVCBKey]string) map[string]SVCBKey {
n := make(map[string]SVCBKey, len(m))
for u, s := range m {
n[s] = u
}
return n
}
// String takes the numerical code of an SVCB key and returns its name.
// Returns an empty string for reserved keys.
// Accepts unassigned keys as well as experimental/private keys.
func (key SVCBKey) String() string {
if x := svcbKeyToStringMap[key]; x != "" {
return x
}
if key == svcb_RESERVED {
return ""
}
return "key" + strconv.FormatUint(uint64(key), 10)
}
// svcbStringToKey returns the numerical code of an SVCB key.
// Returns svcb_RESERVED for reserved/invalid keys.
// Accepts unassigned keys as well as experimental/private keys.
func svcbStringToKey(s string) SVCBKey {
if strings.HasPrefix(s, "key") {
a, err := strconv.ParseUint(s[3:], 10, 16)
// no leading zeros
// key shouldn't be registered
if err != nil || a == 65535 || s[3] == '0' || svcbKeyToStringMap[SVCBKey(a)] != "" {
return svcb_RESERVED
}
return SVCBKey(a)
}
if key, ok := svcbStringToKeyMap[s]; ok {
return key
}
return svcb_RESERVED
}
func (rr *SVCB) parse(c *zlexer, o string) *ParseError {
l, _ := c.Next()
i, e := strconv.ParseUint(l.token, 10, 16)
if e != nil || l.err {
return &ParseError{l.token, "bad SVCB priority", l}
}
rr.Priority = uint16(i)
c.Next() // zBlank
l, _ = c.Next() // zString
rr.Target = l.token
name, nameOk := toAbsoluteName(l.token, o)
if l.err || !nameOk {
return &ParseError{l.token, "bad SVCB Target", l}
}
rr.Target = name
// Values (if any)
l, _ = c.Next()
var xs []SVCBKeyValue
// Helps require whitespace between pairs.
// Prevents key1000="a"key1001=...
canHaveNextKey := true
for l.value != zNewline && l.value != zEOF {
switch l.value {
case zString:
if !canHaveNextKey {
// The key we can now read was probably meant to be
// a part of the last value.
return &ParseError{l.token, "bad SVCB value quotation", l}
}
// In key=value pairs, value does not have to be quoted unless value
// contains whitespace. And keys don't need to have values.
// Similarly, keys with an equality signs after them don't need values.
// l.token includes at least up to the first equality sign.
idx := strings.IndexByte(l.token, '=')
var key, value string
if idx < 0 {
// Key with no value and no equality sign
key = l.token
} else if idx == 0 {
return &ParseError{l.token, "bad SVCB key", l}
} else {
key, value = l.token[:idx], l.token[idx+1:]
if value == "" {
// We have a key and an equality sign. Maybe we have nothing
// after "=" or we have a double quote.
l, _ = c.Next()
if l.value == zQuote {
// Only needed when value ends with double quotes.
// Any value starting with zQuote ends with it.
canHaveNextKey = false
l, _ = c.Next()
switch l.value {
case zString:
// We have a value in double quotes.
value = l.token
l, _ = c.Next()
if l.value != zQuote {
return &ParseError{l.token, "SVCB unterminated value", l}
}
case zQuote:
// There's nothing in double quotes.
default:
return &ParseError{l.token, "bad SVCB value", l}
}
}
}
}
kv := makeSVCBKeyValue(svcbStringToKey(key))
if kv == nil {
return &ParseError{l.token, "bad SVCB key", l}
}
if err := kv.parse(value); err != nil {
return &ParseError{l.token, err.Error(), l}
}
xs = append(xs, kv)
case zQuote:
return &ParseError{l.token, "SVCB key can't contain double quotes", l}
case zBlank:
canHaveNextKey = true
default:
return &ParseError{l.token, "bad SVCB values", l}
}
l, _ = c.Next()
}
// "In AliasMode, records SHOULD NOT include any SvcParams, and recipients MUST
// ignore any SvcParams that are present."
// However, we don't check rr.Priority == 0 && len(xs) > 0 here
// It is the responsibility of the user of the library to check this.
// This is to encourage the fixing of the source of this error.
rr.Value = xs
return nil
}
// makeSVCBKeyValue returns an SVCBKeyValue struct with the key or nil for reserved keys.
func makeSVCBKeyValue(key SVCBKey) SVCBKeyValue {
switch key {
case SVCB_MANDATORY:
return new(SVCBMandatory)
case SVCB_ALPN:
return new(SVCBAlpn)
case SVCB_NO_DEFAULT_ALPN:
return new(SVCBNoDefaultAlpn)
case SVCB_PORT:
return new(SVCBPort)
case SVCB_IPV4HINT:
return new(SVCBIPv4Hint)
case SVCB_ECHCONFIG:
return new(SVCBECHConfig)
case SVCB_IPV6HINT:
return new(SVCBIPv6Hint)
case SVCB_DOHPATH:
return new(SVCBDoHPath)
case svcb_RESERVED:
return nil
default:
e := new(SVCBLocal)
e.KeyCode = key
return e
}
}
// SVCB RR. See RFC xxxx (https://tools.ietf.org/html/draft-ietf-dnsop-svcb-https-08).
//
// NOTE: The HTTPS/SVCB RFCs are in the draft stage.
// The API, including constants and types related to SVCBKeyValues, may
// change in future versions in accordance with the latest drafts.
type SVCB struct {
Hdr RR_Header
Priority uint16 // If zero, Value must be empty or discarded by the user of this library
Target string `dns:"domain-name"`
Value []SVCBKeyValue `dns:"pairs"`
}
// HTTPS RR. Everything valid for SVCB applies to HTTPS as well.
// Except that the HTTPS record is intended for use with the HTTP and HTTPS protocols.
//
// NOTE: The HTTPS/SVCB RFCs are in the draft stage.
// The API, including constants and types related to SVCBKeyValues, may
// change in future versions in accordance with the latest drafts.
type HTTPS struct {
SVCB
}
func (rr *HTTPS) String() string {
return rr.SVCB.String()
}
func (rr *HTTPS) parse(c *zlexer, o string) *ParseError {
return rr.SVCB.parse(c, o)
}
// SVCBKeyValue defines a key=value pair for the SVCB RR type.
// An SVCB RR can have multiple SVCBKeyValues appended to it.
type SVCBKeyValue interface {
Key() SVCBKey // Key returns the numerical key code.
pack() ([]byte, error) // pack returns the encoded value.
unpack([]byte) error // unpack sets the value.
String() string // String returns the string representation of the value.
parse(string) error // parse sets the value to the given string representation of the value.
copy() SVCBKeyValue // copy returns a deep-copy of the pair.
len() int // len returns the length of value in the wire format.
}
// SVCBMandatory pair adds to required keys that must be interpreted for the RR
// to be functional. If ignored, the whole RRSet must be ignored.
// "port" and "no-default-alpn" are mandatory by default if present,
// so they shouldn't be included here.
//
// It is incumbent upon the user of this library to reject the RRSet if
// or avoid constructing such an RRSet that:
// - "mandatory" is included as one of the keys of mandatory
// - no key is listed multiple times in mandatory
// - all keys listed in mandatory are present
// - escape sequences are not used in mandatory
// - mandatory, when present, lists at least one key
//
// Basic use pattern for creating a mandatory option:
//
// s := &dns.SVCB{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}}
// e := new(dns.SVCBMandatory)
// e.Code = []uint16{dns.SVCB_ALPN}
// s.Value = append(s.Value, e)
// t := new(dns.SVCBAlpn)
// t.Alpn = []string{"xmpp-client"}
// s.Value = append(s.Value, t)
type SVCBMandatory struct {
Code []SVCBKey
}
func (*SVCBMandatory) Key() SVCBKey { return SVCB_MANDATORY }
func (s *SVCBMandatory) String() string {
str := make([]string, len(s.Code))
for i, e := range s.Code {
str[i] = e.String()
}
return strings.Join(str, ",")
}
func (s *SVCBMandatory) pack() ([]byte, error) {
codes := append([]SVCBKey(nil), s.Code...)
sort.Slice(codes, func(i, j int) bool {
return codes[i] < codes[j]
})
b := make([]byte, 2*len(codes))
for i, e := range codes {
binary.BigEndian.PutUint16(b[2*i:], uint16(e))
}
return b, nil
}
func (s *SVCBMandatory) unpack(b []byte) error {
if len(b)%2 != 0 {
return errors.New("dns: svcbmandatory: value length is not a multiple of 2")
}
codes := make([]SVCBKey, 0, len(b)/2)
for i := 0; i < len(b); i += 2 {
// We assume strictly increasing order.
codes = append(codes, SVCBKey(binary.BigEndian.Uint16(b[i:])))
}
s.Code = codes
return nil
}
func (s *SVCBMandatory) parse(b string) error {
str := strings.Split(b, ",")
codes := make([]SVCBKey, 0, len(str))
for _, e := range str {
codes = append(codes, svcbStringToKey(e))
}
s.Code = codes
return nil
}
func (s *SVCBMandatory) len() int {
return 2 * len(s.Code)
}
func (s *SVCBMandatory) copy() SVCBKeyValue {
return &SVCBMandatory{
append([]SVCBKey(nil), s.Code...),
}
}
// SVCBAlpn pair is used to list supported connection protocols.
// The user of this library must ensure that at least one protocol is listed when alpn is present.
// Protocol IDs can be found at:
// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids
// Basic use pattern for creating an alpn option:
//
// h := new(dns.HTTPS)
// h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
// e := new(dns.SVCBAlpn)
// e.Alpn = []string{"h2", "http/1.1"}
// h.Value = append(h.Value, e)
type SVCBAlpn struct {
Alpn []string
}
func (*SVCBAlpn) Key() SVCBKey { return SVCB_ALPN }
func (s *SVCBAlpn) String() string {
// An ALPN value is a comma-separated list of values, each of which can be
// an arbitrary binary value. In order to allow parsing, the comma and
// backslash characters are themselves excaped.
//
// However, this escaping is done in addition to the normal escaping which
// happens in zone files, meaning that these values must be
// double-escaped. This looks terrible, so if you see a never-ending
// sequence of backslash in a zone file this may be why.
//
// https://datatracker.ietf.org/doc/html/draft-ietf-dnsop-svcb-https-08#appendix-A.1
var str strings.Builder
for i, alpn := range s.Alpn {
// 4*len(alpn) is the worst case where we escape every character in the alpn as \123, plus 1 byte for the ',' separating the alpn from others
str.Grow(4*len(alpn) + 1)
if i > 0 {
str.WriteByte(',')
}
for j := 0; j < len(alpn); j++ {
e := alpn[j]
if ' ' > e || e > '~' {
str.WriteString(escapeByte(e))
continue
}
switch e {
// We escape a few characters which may confuse humans or parsers.
case '"', ';', ' ':
str.WriteByte('\\')
str.WriteByte(e)
// The comma and backslash characters themselves must be
// doubly-escaped. We use `\\` for the first backslash and
// the escaped numeric value for the other value. We especially
// don't want a comma in the output.
case ',':
str.WriteString(`\\\044`)
case '\\':
str.WriteString(`\\\092`)
default:
str.WriteByte(e)
}
}
}
return str.String()
}
func (s *SVCBAlpn) pack() ([]byte, error) {
// Liberally estimate the size of an alpn as 10 octets
b := make([]byte, 0, 10*len(s.Alpn))
for _, e := range s.Alpn {
if e == "" {
return nil, errors.New("dns: svcbalpn: empty alpn-id")
}
if len(e) > 255 {
return nil, errors.New("dns: svcbalpn: alpn-id too long")
}
b = append(b, byte(len(e)))
b = append(b, e...)
}
return b, nil
}
func (s *SVCBAlpn) unpack(b []byte) error {
// Estimate the size of the smallest alpn as 4 bytes
alpn := make([]string, 0, len(b)/4)
for i := 0; i < len(b); {
length := int(b[i])
i++
if i+length > len(b) {
return errors.New("dns: svcbalpn: alpn array overflowing")
}
alpn = append(alpn, string(b[i:i+length]))
i += length
}
s.Alpn = alpn
return nil
}
func (s *SVCBAlpn) parse(b string) error {
if len(b) == 0 {
s.Alpn = []string{}
return nil
}
alpn := []string{}
a := []byte{}
for p := 0; p < len(b); {
c, q := nextByte(b, p)
if q == 0 {
return errors.New("dns: svcbalpn: unterminated escape")
}
p += q
// If we find a comma, we have finished reading an alpn.
if c == ',' {
if len(a) == 0 {
return errors.New("dns: svcbalpn: empty protocol identifier")
}
alpn = append(alpn, string(a))
a = []byte{}
continue
}
// If it's a backslash, we need to handle a comma-separated list.
if c == '\\' {
dc, dq := nextByte(b, p)
if dq == 0 {
return errors.New("dns: svcbalpn: unterminated escape decoding comma-separated list")
}
if dc != '\\' && dc != ',' {
return errors.New("dns: svcbalpn: bad escaped character decoding comma-separated list")
}
p += dq
c = dc
}
a = append(a, c)
}
// Add the final alpn.
if len(a) == 0 {
return errors.New("dns: svcbalpn: last protocol identifier empty")
}
s.Alpn = append(alpn, string(a))
return nil
}
func (s *SVCBAlpn) len() int {
var l int
for _, e := range s.Alpn {
l += 1 + len(e)
}
return l
}
func (s *SVCBAlpn) copy() SVCBKeyValue {
return &SVCBAlpn{
append([]string(nil), s.Alpn...),
}
}
// SVCBNoDefaultAlpn pair signifies no support for default connection protocols.
// Should be used in conjunction with alpn.
// Basic use pattern for creating a no-default-alpn option:
//
// s := &dns.SVCB{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}}
// t := new(dns.SVCBAlpn)
// t.Alpn = []string{"xmpp-client"}
// s.Value = append(s.Value, t)
// e := new(dns.SVCBNoDefaultAlpn)
// s.Value = append(s.Value, e)
type SVCBNoDefaultAlpn struct{}
func (*SVCBNoDefaultAlpn) Key() SVCBKey { return SVCB_NO_DEFAULT_ALPN }
func (*SVCBNoDefaultAlpn) copy() SVCBKeyValue { return &SVCBNoDefaultAlpn{} }
func (*SVCBNoDefaultAlpn) pack() ([]byte, error) { return []byte{}, nil }
func (*SVCBNoDefaultAlpn) String() string { return "" }
func (*SVCBNoDefaultAlpn) len() int { return 0 }
func (*SVCBNoDefaultAlpn) unpack(b []byte) error {
if len(b) != 0 {
return errors.New("dns: svcbnodefaultalpn: no-default-alpn must have no value")
}
return nil
}
func (*SVCBNoDefaultAlpn) parse(b string) error {
if b != "" {
return errors.New("dns: svcbnodefaultalpn: no-default-alpn must have no value")
}
return nil
}
// SVCBPort pair defines the port for connection.
// Basic use pattern for creating a port option:
//
// s := &dns.SVCB{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}}
// e := new(dns.SVCBPort)
// e.Port = 80
// s.Value = append(s.Value, e)
type SVCBPort struct {
Port uint16
}
func (*SVCBPort) Key() SVCBKey { return SVCB_PORT }
func (*SVCBPort) len() int { return 2 }
func (s *SVCBPort) String() string { return strconv.FormatUint(uint64(s.Port), 10) }
func (s *SVCBPort) copy() SVCBKeyValue { return &SVCBPort{s.Port} }
func (s *SVCBPort) unpack(b []byte) error {
if len(b) != 2 {
return errors.New("dns: svcbport: port length is not exactly 2 octets")
}
s.Port = binary.BigEndian.Uint16(b)
return nil
}
func (s *SVCBPort) pack() ([]byte, error) {
b := make([]byte, 2)
binary.BigEndian.PutUint16(b, s.Port)
return b, nil
}
func (s *SVCBPort) parse(b string) error {
port, err := strconv.ParseUint(b, 10, 16)
if err != nil {
return errors.New("dns: svcbport: port out of range")
}
s.Port = uint16(port)
return nil
}
// SVCBIPv4Hint pair suggests an IPv4 address which may be used to open connections
// if A and AAAA record responses for SVCB's Target domain haven't been received.
// In that case, optionally, A and AAAA requests can be made, after which the connection
// to the hinted IP address may be terminated and a new connection may be opened.
// Basic use pattern for creating an ipv4hint option:
//
// h := new(dns.HTTPS)
// h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
// e := new(dns.SVCBIPv4Hint)
// e.Hint = []net.IP{net.IPv4(1,1,1,1).To4()}
//
// Or
//
// e.Hint = []net.IP{net.ParseIP("1.1.1.1").To4()}
// h.Value = append(h.Value, e)
type SVCBIPv4Hint struct {
Hint []net.IP
}
func (*SVCBIPv4Hint) Key() SVCBKey { return SVCB_IPV4HINT }
func (s *SVCBIPv4Hint) len() int { return 4 * len(s.Hint) }
func (s *SVCBIPv4Hint) pack() ([]byte, error) {
b := make([]byte, 0, 4*len(s.Hint))
for _, e := range s.Hint {
x := e.To4()
if x == nil {
return nil, errors.New("dns: svcbipv4hint: expected ipv4, hint is ipv6")
}
b = append(b, x...)
}
return b, nil
}
func (s *SVCBIPv4Hint) unpack(b []byte) error {
if len(b) == 0 || len(b)%4 != 0 {
return errors.New("dns: svcbipv4hint: ipv4 address byte array length is not a multiple of 4")
}
x := make([]net.IP, 0, len(b)/4)
for i := 0; i < len(b); i += 4 {
x = append(x, net.IP(b[i:i+4]))
}
s.Hint = x
return nil
}
func (s *SVCBIPv4Hint) String() string {
str := make([]string, len(s.Hint))
for i, e := range s.Hint {
x := e.To4()
if x == nil {
return "<nil>"
}
str[i] = x.String()
}
return strings.Join(str, ",")
}
func (s *SVCBIPv4Hint) parse(b string) error {
if strings.Contains(b, ":") {
return errors.New("dns: svcbipv4hint: expected ipv4, got ipv6")
}
str := strings.Split(b, ",")
dst := make([]net.IP, len(str))
for i, e := range str {
ip := net.ParseIP(e).To4()
if ip == nil {
return errors.New("dns: svcbipv4hint: bad ip")
}
dst[i] = ip
}
s.Hint = dst
return nil
}
func (s *SVCBIPv4Hint) copy() SVCBKeyValue {
hint := make([]net.IP, len(s.Hint))
for i, ip := range s.Hint {
hint[i] = copyIP(ip)
}
return &SVCBIPv4Hint{
Hint: hint,
}
}
// SVCBECHConfig pair contains the ECHConfig structure defined in draft-ietf-tls-esni [RFC xxxx].
// Basic use pattern for creating an ech option:
//
// h := new(dns.HTTPS)
// h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
// e := new(dns.SVCBECHConfig)
// e.ECH = []byte{0xfe, 0x08, ...}
// h.Value = append(h.Value, e)
type SVCBECHConfig struct {
ECH []byte // Specifically ECHConfigList including the redundant length prefix
}
func (*SVCBECHConfig) Key() SVCBKey { return SVCB_ECHCONFIG }
func (s *SVCBECHConfig) String() string { return toBase64(s.ECH) }
func (s *SVCBECHConfig) len() int { return len(s.ECH) }
func (s *SVCBECHConfig) pack() ([]byte, error) {
return append([]byte(nil), s.ECH...), nil
}
func (s *SVCBECHConfig) copy() SVCBKeyValue {
return &SVCBECHConfig{
append([]byte(nil), s.ECH...),
}
}
func (s *SVCBECHConfig) unpack(b []byte) error {
s.ECH = append([]byte(nil), b...)
return nil
}
func (s *SVCBECHConfig) parse(b string) error {
x, err := fromBase64([]byte(b))
if err != nil {
return errors.New("dns: svcbech: bad base64 ech")
}
s.ECH = x
return nil
}
// SVCBIPv6Hint pair suggests an IPv6 address which may be used to open connections
// if A and AAAA record responses for SVCB's Target domain haven't been received.
// In that case, optionally, A and AAAA requests can be made, after which the
// connection to the hinted IP address may be terminated and a new connection may be opened.
// Basic use pattern for creating an ipv6hint option:
//
// h := new(dns.HTTPS)
// h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
// e := new(dns.SVCBIPv6Hint)
// e.Hint = []net.IP{net.ParseIP("2001:db8::1")}
// h.Value = append(h.Value, e)
type SVCBIPv6Hint struct {
Hint []net.IP
}
func (*SVCBIPv6Hint) Key() SVCBKey { return SVCB_IPV6HINT }
func (s *SVCBIPv6Hint) len() int { return 16 * len(s.Hint) }
func (s *SVCBIPv6Hint) pack() ([]byte, error) {
b := make([]byte, 0, 16*len(s.Hint))
for _, e := range s.Hint {
if len(e) != net.IPv6len || e.To4() != nil {
return nil, errors.New("dns: svcbipv6hint: expected ipv6, hint is ipv4")
}
b = append(b, e...)
}
return b, nil
}
func (s *SVCBIPv6Hint) unpack(b []byte) error {
if len(b) == 0 || len(b)%16 != 0 {
return errors.New("dns: svcbipv6hint: ipv6 address byte array length not a multiple of 16")
}
x := make([]net.IP, 0, len(b)/16)
for i := 0; i < len(b); i += 16 {
ip := net.IP(b[i : i+16])
if ip.To4() != nil {
return errors.New("dns: svcbipv6hint: expected ipv6, got ipv4")
}
x = append(x, ip)
}
s.Hint = x
return nil
}
func (s *SVCBIPv6Hint) String() string {
str := make([]string, len(s.Hint))
for i, e := range s.Hint {
if x := e.To4(); x != nil {
return "<nil>"
}
str[i] = e.String()
}
return strings.Join(str, ",")
}
func (s *SVCBIPv6Hint) parse(b string) error {
str := strings.Split(b, ",")
dst := make([]net.IP, len(str))
for i, e := range str {
ip := net.ParseIP(e)
if ip == nil {
return errors.New("dns: svcbipv6hint: bad ip")
}
if ip.To4() != nil {
return errors.New("dns: svcbipv6hint: expected ipv6, got ipv4-mapped-ipv6")
}
dst[i] = ip
}
s.Hint = dst
return nil
}
func (s *SVCBIPv6Hint) copy() SVCBKeyValue {
hint := make([]net.IP, len(s.Hint))
for i, ip := range s.Hint {
hint[i] = copyIP(ip)
}
return &SVCBIPv6Hint{
Hint: hint,
}
}
// SVCBDoHPath pair is used to indicate the URI template that the
// clients may use to construct a DNS over HTTPS URI.
//
// See RFC xxxx (https://datatracker.ietf.org/doc/html/draft-ietf-add-svcb-dns-02)
// and RFC yyyy (https://datatracker.ietf.org/doc/html/draft-ietf-add-ddr-06).
//
// A basic example of using the dohpath option together with the alpn
// option to indicate support for DNS over HTTPS on a certain path:
//
// s := new(dns.SVCB)
// s.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}
// e := new(dns.SVCBAlpn)
// e.Alpn = []string{"h2", "h3"}
// p := new(dns.SVCBDoHPath)
// p.Template = "/dns-query{?dns}"
// s.Value = append(s.Value, e, p)
//
// The parsing currently doesn't validate that Template is a valid
// RFC 6570 URI template.
type SVCBDoHPath struct {
Template string
}
func (*SVCBDoHPath) Key() SVCBKey { return SVCB_DOHPATH }
func (s *SVCBDoHPath) String() string { return svcbParamToStr([]byte(s.Template)) }
func (s *SVCBDoHPath) len() int { return len(s.Template) }
func (s *SVCBDoHPath) pack() ([]byte, error) { return []byte(s.Template), nil }
func (s *SVCBDoHPath) unpack(b []byte) error {
s.Template = string(b)
return nil
}
func (s *SVCBDoHPath) parse(b string) error {
template, err := svcbParseParam(b)
if err != nil {
return fmt.Errorf("dns: svcbdohpath: %w", err)
}
s.Template = string(template)
return nil
}
func (s *SVCBDoHPath) copy() SVCBKeyValue {
return &SVCBDoHPath{
Template: s.Template,
}
}
// SVCBLocal pair is intended for experimental/private use. The key is recommended
// to be in the range [SVCB_PRIVATE_LOWER, SVCB_PRIVATE_UPPER].
// Basic use pattern for creating a keyNNNNN option:
//
// h := new(dns.HTTPS)
// h.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeHTTPS, Class: dns.ClassINET}
// e := new(dns.SVCBLocal)
// e.KeyCode = 65400
// e.Data = []byte("abc")
// h.Value = append(h.Value, e)
type SVCBLocal struct {
KeyCode SVCBKey // Never 65535 or any assigned keys.
Data []byte // All byte sequences are allowed.
}
func (s *SVCBLocal) Key() SVCBKey { return s.KeyCode }
func (s *SVCBLocal) String() string { return svcbParamToStr(s.Data) }
func (s *SVCBLocal) pack() ([]byte, error) { return append([]byte(nil), s.Data...), nil }
func (s *SVCBLocal) len() int { return len(s.Data) }
func (s *SVCBLocal) unpack(b []byte) error {
s.Data = append([]byte(nil), b...)
return nil
}
func (s *SVCBLocal) parse(b string) error {
data, err := svcbParseParam(b)
if err != nil {
return fmt.Errorf("dns: svcblocal: svcb private/experimental key %w", err)
}
s.Data = data
return nil
}
func (s *SVCBLocal) copy() SVCBKeyValue {
return &SVCBLocal{s.KeyCode,
append([]byte(nil), s.Data...),
}
}
func (rr *SVCB) String() string {
s := rr.Hdr.String() +
strconv.Itoa(int(rr.Priority)) + " " +
sprintName(rr.Target)
for _, e := range rr.Value {
s += " " + e.Key().String() + "=\"" + e.String() + "\""
}
return s
}
// areSVCBPairArraysEqual checks if SVCBKeyValue arrays are equal after sorting their
// copies. arrA and arrB have equal lengths, otherwise zduplicate.go wouldn't call this function.
func areSVCBPairArraysEqual(a []SVCBKeyValue, b []SVCBKeyValue) bool {
a = append([]SVCBKeyValue(nil), a...)
b = append([]SVCBKeyValue(nil), b...)
sort.Slice(a, func(i, j int) bool { return a[i].Key() < a[j].Key() })
sort.Slice(b, func(i, j int) bool { return b[i].Key() < b[j].Key() })
for i, e := range a {
if e.Key() != b[i].Key() {
return false
}
b1, err1 := e.pack()
b2, err2 := b[i].pack()
if err1 != nil || err2 != nil || !bytes.Equal(b1, b2) {
return false
}
}
return true
}
// svcbParamStr converts the value of an SVCB parameter into a DNS presentation-format string.
func svcbParamToStr(s []byte) string {
var str strings.Builder
str.Grow(4 * len(s))
for _, e := range s {
if ' ' <= e && e <= '~' {
switch e {
case '"', ';', ' ', '\\':
str.WriteByte('\\')
str.WriteByte(e)
default:
str.WriteByte(e)
}
} else {
str.WriteString(escapeByte(e))
}
}
return str.String()
}
// svcbParseParam parses a DNS presentation-format string into an SVCB parameter value.
func svcbParseParam(b string) ([]byte, error) {
data := make([]byte, 0, len(b))
for i := 0; i < len(b); {
if b[i] != '\\' {
data = append(data, b[i])
i++
continue
}
if i+1 == len(b) {
return nil, errors.New("escape unterminated")
}
if isDigit(b[i+1]) {
if i+3 < len(b) && isDigit(b[i+2]) && isDigit(b[i+3]) {
a, err := strconv.ParseUint(b[i+1:i+4], 10, 8)
if err == nil {
i += 4
data = append(data, byte(a))
continue
}
}
return nil, errors.New("bad escaped octet")
} else {
data = append(data, b[i+1])
i += 2
}
}
return data, nil
}