Commit 922c5b9a authored by ale's avatar ale
Browse files

initial commit

parents
++++++++
tlswatch
++++++++
tlswatch is a tool to enforce a STARTTLS policy for your mail server,
combining opportunistic TLS discovery and a "sticky"
trust-on-first-use policy. It currently only supports Postfix.
tlswatch will scan your email logs looking for remote mail servers
that support TLS connections, and it will attempt to establish a trust
path to them. If it succeeds, it will generate a new Postfix TLS
policy to lock the trust path for the destination domain.
TLS policy
----------
There are two available enforcement policies:
*strict*
This policy will match the exact fingerprints of the mail servers'
X509 certificates. While more precise, it will generate lots of
conflicts on every key rotation, so it is advised to use it only if
an independent validation mechanism is available.
*ca-pinning*
This policy will match certificates against the top-level signing
CA. For this to work, the remote mail server must send the full
certificate chain on every TLS connection (in case the CA is not
part of the set of well-known authorities distributed with the
system).
It is also possible to restrict policies to a set of manually
specified domains, or to provide a blacklist of domains to exclude.
Key rotation and error reconciliation
-------------------------------------
When a TLS connection error occurs, a decision must be made as to the
whether change is legitimate or not. Since email to the destination
domain will be queued until the condition is resolved, we'd like to
attempt to validate the change automatically. A few mechanisms come to
mind:
* DANE/TLSA to exploit the independent DNSSEC-based channel
* PGP web-of-trust
Manual resolution is always available as a last resort.
package tlswatch
import (
"bytes"
"crypto/sha1"
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"log"
"os"
"net"
"net/textproto"
"sync"
"time"
"github.com/golang/groupcache/singleflight"
)
var (
// Command-line flags.
defaultLocalName string
localName = flag.String("ehlo", "", "EHLO host (defaults to local hostname)")
certInfoTTL = flag.Duration("cert-ttl", 7 * 24 * time.Hour, "TTL for certificate cache")
certInfoErrTTL = flag.Duration("cert-err-ttl", 12 * time.Hour, "TTL for certificate errors cache")
// SMTP connection timeout.
connectTimeout = 30 * time.Second
// Database table for the certificate cache.
certsTable = "certs"
)
func init() {
defaultLocalName, _ = os.Hostname()
}
// Certificate information used by the certificate cache.
type CertInfo struct {
Addr string
Certs []*x509.Certificate
Err error
Expires time.Time
}
func (c CertInfo) Expired() bool {
return time.Now().After(c.Expires)
}
func (c CertInfo) Equal(c2 CertInfo) bool {
if c.Addr != c2.Addr || c.Err != c2.Err {
return false
}
if len(c.Certs) != len(c2.Certs) {
return false
}
for idx, crt := range c.Certs {
if !crt.Equal(c2.Certs[idx]) {
return false
}
}
return true
}
// Send a SMTP command and return the response.
func smtpCmd(t *textproto.Conn, expectCode int, format string, args ...interface{}) (int, string, error) {
id, err := t.Cmd(format, args...)
if err != nil {
return 0, "", err
}
t.StartResponse(id)
defer t.EndResponse(id)
code, msg, err := t.ReadResponse(expectCode)
return code, msg, err
}
// GrabCerts returns the X509 certificates served by the remote
// address, using the STARTTLS SMTP extension.
func GrabCerts(addr string) ([]*x509.Certificate, error) {
conn, err := net.DialTimeout("tcp", addr, connectTimeout)
if err != nil {
log.Printf("%s: connect error: %s", addr, err)
return nil, err
}
defer conn.Close()
tconn := textproto.NewConn(conn)
defer tconn.Close()
if _, _, err := tconn.ReadResponse(220); err != nil {
log.Printf("%s: bad salutation: %s", addr, err)
return nil, err
}
// Send EHLO, ignore reply.
ehlo := *localName
if ehlo == "" {
ehlo = defaultLocalName
}
if _, _, err := smtpCmd(tconn, 250, "EHLO %s", ehlo); err != nil {
log.Printf("%s: EHLO error: %s", addr, err)
return nil, err
}
// Send STARTTLS.
if _, _, err := smtpCmd(tconn, 220, "STARTTLS"); err != nil {
log.Printf("%s: STARTTLS error: %s", addr, err)
return nil, err
}
// Establish the TLS connection, and run the initial handshake.
tlsConn := tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
defer tlsConn.Close()
if err := tlsConn.Handshake(); err != nil {
log.Printf("%s: TLS handshake error: %v", addr, err)
return nil, err
}
return tlsConn.ConnectionState().PeerCertificates, nil
}
type grabCertResponse struct {
certs []*x509.Certificate
err error
}
type grabCertRequest struct {
addr string
ch chan grabCertResponse
}
// CertGrabber will retrieve X509 certificates from remote mail
// servers (keeping a cache in the local database), with a limit on
// the number of concurrent requests.
type CertGrabber struct {
db Database
sg singleflight.Group
wg sync.WaitGroup
ch chan grabCertRequest
}
func NewCertGrabber(db Database, nworkers int) *CertGrabber {
g := &CertGrabber{
db: db,
ch: make(chan grabCertRequest, 100),
}
for i := 0; i < nworkers; i++ {
g.wg.Add(1)
go g.grabber()
}
return g
}
func (g *CertGrabber) Close() {
close(g.ch)
g.wg.Wait()
}
func (g *CertGrabber) grabber() {
for req := range g.ch {
addr := net.JoinHostPort(req.addr, "25")
log.Printf("connecting to %s", addr)
certs, err := GrabCerts(addr)
req.ch <- grabCertResponse{certs, err}
}
g.wg.Done()
}
func (g *CertGrabber) grabCert(addr string) ([]*x509.Certificate, error) {
rch := make(chan grabCertResponse, 1)
defer close(rch)
g.ch <- grabCertRequest{addr, rch}
resp := <-rch
return resp.certs, resp.err
}
// GrabCert attempts to retrieve the X509 certificates served by a
// mail server. Results are cached in the database.
func (g *CertGrabber) GrabCert(addr string) ([]*x509.Certificate, error) {
session := g.db.Session()
defer session.Close()
var certinfo CertInfo
certinfoOk := session.Get(certsTable, addr, &certinfo)
if certinfoOk && !certinfo.Expired() {
return certinfo.Certs, certinfo.Err
}
// Wrap the connection with a singleflight.Group to avoid
// multiple updates when a popular CertInfo expires.
certs, err := g.sg.Do(addr, func() (interface{}, error) {
certs, err := g.grabCert(addr)
ttl := *certInfoTTL
if err != nil {
ttl = *certInfoErrTTL
}
newcertinfo := CertInfo{
Addr: addr,
Certs: certs,
Err: err,
Expires: time.Now().Add(ttl),
}
session.Set(certsTable, addr, &newcertinfo)
// Here we can detect certificate changes!
if certinfoOk && !newcertinfo.Equal(certinfo) {
log.Printf("detected certificate change on %s: %+v -> %+v", addr, certinfo, newcertinfo)
}
return certs, err
})
return certs.([]*x509.Certificate), err
}
// Postfix uses SHA1 fingerprints of the public key.
func fpEncode(data []byte) string {
var buf bytes.Buffer
for i, b := range data {
if i > 0 {
buf.WriteString(":")
}
fmt.Fprintf(&buf, "%02X", b)
}
return buf.String()
}
// GetCertFingerprint returns the SHA1 fingerprints of the given
// certificates, encoded as colon-separated hex strings.
func GetCertFingerprint(cert *x509.Certificate) string {
h := sha1.New()
h.Write(cert.RawSubjectPublicKeyInfo)
return fpEncode(h.Sum(nil))
}
// Get the CA certificates in the list.
func GetCACertificates(certs []*x509.Certificate) []*x509.Certificate {
cas := make([]*x509.Certificate, 0, len(certs))
for _, cert := range certs {
if cert.IsCA {
cas = append(cas, cert)
}
}
return cas
}
package main
import (
"flag"
"git.autistici.org/ale/postfix-tlswatch"
"log"
"os"
"time"
)
var (
dbPath = flag.String("db", "/var/lib/postfix-tlswatch/db", "Database directory")
tlsPolicy = flag.String("policy", "ca-pinning", "TLS policy (strict | ca-pinning)")
tlsPolicyMapFile = flag.String("tls-policy-map", "/etc/postfix/maps/tls_policy", "Location of the Postfix tls_policy_map file")
updatePeriod = flag.Duration("update-period", 900 * time.Second, "Update period")
whitelistFile = flag.String("whitelist", "", "Domain whitelist file")
blacklistFile = flag.String("blacklist", "", "Domain blacklist file")
)
func main() {
flag.Parse()
ch := make(chan string, 100)
errCh := make(chan string, 100)
// Setup the interface to Postfix.
pmap := tlswatch.NewPostfixPolicyMap(*tlsPolicyMapFile)
policy, err := tlswatch.GetTlsPolicy(*tlsPolicy)
if err != nil {
log.Fatal(err)
}
// Initialize the database.
db := tlswatch.NewLevelDbDatabase(*dbPath)
pw := tlswatch.NewPolicyWatcher(
policy,
pmap,
db,
tlswatch.Batch(ch, *updatePeriod),
tlswatch.Batch(errCh, *updatePeriod))
var domainWl tlswatch.RegexpList
if *whitelistFile != "" {
domainWl = tlswatch.ParseWildcardsFromFile(*whitelistFile)
}
var domainBl tlswatch.RegexpList
if *blacklistFile != "" {
domainBl = tlswatch.ParseWildcardsFromFile(*blacklistFile)
}
tlswatch.NewScanner(db, ch, errCh, domainWl, domainBl).Scan(os.Stdin)
close(ch)
close(errCh)
pw.Wait()
pmap.Save()
//log.Printf("%+v\n", pw)
db.Close()
}
package tlswatch
// The Database holds information about known domain -> mx mappings
// and the related mx connection data. Every association has a
// freshness metric, to detect configuration changes by allowing old
// relations to expire.
type Database interface {
Session() Session
Close()
}
// Iterator is a database iterator to scan a result range.
type Iterator interface {
Next() bool
Key() string
Value(interface{})
Close()
}
// Session provides a consistent view of the database. Note that
// writes might not be immediately visible within the same Session.
type Session interface {
Get(table, key string, obj interface{}) bool
Set(table, key string, obj interface{})
Del(table, key string)
Scan(table, startKey, endKey string) (Iterator, error)
Close()
}
package tlswatch
import (
"bytes"
"encoding/gob"
"errors"
"sync"
)
// InMemoryDatabase, as one would expect, only stores its data in
// memory. Useful for testing.
type InMemoryDatabase struct {
lock sync.Mutex
tables map[string]map[string][]byte
}
func NewInMemoryDatabase() *InMemoryDatabase {
return &InMemoryDatabase{
tables: make(map[string]map[string][]byte),
}
}
func (db *InMemoryDatabase) Close() {
}
func (db *InMemoryDatabase) Session() Session {
return &inmemorySession{db}
}
type inmemoryKV struct {
key string
value []byte
}
type inmemoryIterator struct {
values []inmemoryKV
pos int
}
func (i *inmemoryIterator) Next() bool {
i.pos++
return i.pos < len(i.values)
}
func (i *inmemoryIterator) Key() string {
return i.values[i.pos].key
}
func (i *inmemoryIterator) Value(obj interface{}) {
gob.NewDecoder(bytes.NewReader(i.values[i.pos].value)).Decode(obj)
}
func (i *inmemoryIterator) Close() {
}
type inmemorySession struct {
db *InMemoryDatabase
}
func (s *inmemorySession) Get(table, key string, obj interface{}) bool {
s.db.lock.Lock()
defer s.db.lock.Unlock()
if t, ok := s.db.tables[table]; ok {
if value, ok := t[key]; ok {
if obj != nil {
gob.NewDecoder(bytes.NewReader(value)).Decode(obj)
}
return true
}
}
return false
}
func (s *inmemorySession) Set(table, key string, obj interface{}) {
s.db.lock.Lock()
defer s.db.lock.Unlock()
var buf bytes.Buffer
gob.NewEncoder(&buf).Encode(obj)
t, ok := s.db.tables[table]
if !ok {
t = make(map[string][]byte)
s.db.tables[table] = t
}
t[key] = buf.Bytes()
}
func (s *inmemorySession) Del(table, key string) {
s.db.lock.Lock()
defer s.db.lock.Unlock()
if t, ok := s.db.tables[table]; ok {
delete(t, key)
}
}
func (s *inmemorySession) Scan(table, startKey, endKey string) (Iterator, error) {
t, ok := s.db.tables[table]
if !ok {
return nil, errors.New("No such table")
}
values := make([]inmemoryKV, 0)
for key, value := range t {
if key >= startKey && key < endKey {
values = append(values, inmemoryKV{key, value})
}
}
return &inmemoryIterator{values, -1}, nil
}
func (s *inmemorySession) Close() {
}
package tlswatch
import (
"bytes"
"encoding/gob"
"fmt"
"log"
"github.com/jmhodges/levigo"
"strings"
)
var (
LruCacheSize = 2 << 20
BloomFilterSize = 10
)
type LevelDbDatabase struct {
db *levigo.DB
cache *levigo.Cache
filter *levigo.FilterPolicy
}
func NewLevelDbDatabase(path string) *LevelDbDatabase {
opts := levigo.NewOptions()
cache := levigo.NewLRUCache(LruCacheSize)
opts.SetCache(cache)
filter := levigo.NewBloomFilter(BloomFilterSize)
opts.SetFilterPolicy(filter)
opts.SetCreateIfMissing(true)
db, err := levigo.Open(path, opts)
if err != nil {
log.Fatal(err)
}
return &LevelDbDatabase{db, cache, filter}
}
func (db *LevelDbDatabase) Close() {
db.db.Close()
db.cache.Close()
db.filter.Close()
}
func (db *LevelDbDatabase) Session() Session {
snap := db.db.NewSnapshot()
ro := levigo.NewReadOptions()
ro.SetSnapshot(snap)
return &levelDbSession{
db: db.db,
snap: snap,
readOpts: ro,
}
}
type levelDbSession struct {
db *levigo.DB
readOpts *levigo.ReadOptions
snap *levigo.Snapshot
wb *levigo.WriteBatch
}
func (s *levelDbSession) makeKey(table, key string) []byte {
var buf bytes.Buffer
fmt.Fprintf(&buf, "%s:%s", table, key)
return buf.Bytes()
}
func (s *levelDbSession) Get(table, key string, obj interface{}) bool {
data, err := s.db.Get(s.readOpts, s.makeKey(table, key))
if err != nil || data == nil {
return false
}
if obj != nil {
gob.NewDecoder(bytes.NewReader(data)).Decode(obj)
}
return true
}
func (s *levelDbSession) Set(table, key string, obj interface{}) {
if s.wb == nil {
s.wb = levigo.NewWriteBatch()
}
var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(obj); err != nil {
return
}
s.wb.Put(s.makeKey(table, key), buf.Bytes())
}
func (s *levelDbSession) Del(table, key string) {
if s.wb == nil {
s.wb = levigo.NewWriteBatch()
}
s.wb.Delete(s.makeKey(table, key))
}
func (s *levelDbSession) Scan(table, startKey, endKey string) (Iterator, error) {
iter := s.db.NewIterator(s.readOpts)
iter.Seek(s.makeKey(table, startKey))
return &levelDbIterator{
iter: iter,
endKey: s.makeKey(table, endKey),
}, nil
}
func (s *levelDbSession) Close() {
if s.wb != nil {
if err := s.db.Write(levigo.NewWriteOptions(), s.wb); err != nil {
log.Printf("LevelDB write error: %s", err)
}
s.wb.Close()
}
s.db.ReleaseSnapshot(s.snap)
}
type levelDbIterator struct {
iter *levigo.Iterator
endKey, curValue []byte
curKey string