package ldaputil

import (
	"context"
	"time"

	"github.com/cenkalti/backoff"
	"gopkg.in/ldap.v2"

	"git.autistici.org/ai3/go-common/clientutil"
)

// isTemporaryLDAPError reports whether err is worth retrying.
//
// Any error that is not an *ldap.Error (for example a failure surfaced
// by the connection pool or the transport) is assumed to be a
// potential network-level issue and is therefore treated as temporary.
// For *ldap.Error values, only the ldap.ErrorNetwork result code counts
// as temporary; every other result code is treated as a definitive
// protocol-level outcome and is not retried.
func isTemporaryLDAPError(err error) bool {
	if ldapErr, isLDAP := err.(*ldap.Error); isLDAP {
		return ldapErr.ResultCode == ldap.ErrorNetwork
	}
	return true
}

// Search performs the given search request. It will retry the request
// on temporary errors (as classified by isTemporaryLDAPError), using an
// exponential backoff that stops once ctx is done.
//
// If ctx carries a deadline, the remaining time is applied as the
// request timeout on the connection before searching. If that deadline
// has already passed, no search is sent and context.DeadlineExceeded is
// returned.
//
// The returned result is whatever the last attempt's conn.Search
// produced, and it is returned together with err unchanged.
func (p *ConnectionPool) Search(ctx context.Context, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
	var result *ldap.SearchResult
	err := clientutil.Retry(func() error {
		conn, err := p.Get(ctx)
		if err != nil {
			if isTemporaryLDAPError(err) {
				return clientutil.TempError(err)
			}
			return err
		}

		if deadline, ok := ctx.Deadline(); ok {
			remaining := time.Until(deadline)
			if remaining <= 0 {
				// Nothing has been sent on conn, so it is returned to the
				// pool as healthy; a non-temporary error ends the retries.
				p.Release(conn, nil)
				return context.DeadlineExceeded
			}
			// NOTE(review): nothing here resets this timeout before
			// Release, so a later user of the same pooled connection may
			// inherit it — confirm whether the pool resets it.
			conn.SetTimeout(remaining)
		}

		result, err = conn.Search(searchRequest)
		if err != nil && isTemporaryLDAPError(err) {
			// A network-level failure may have left the connection
			// unusable: hand the error to the pool so it can avoid giving
			// the same connection to the retry (presumably Release
			// discards connections released with a non-nil error — confirm).
			p.Release(conn, err)
			return clientutil.TempError(err)
		}
		// Success, or a protocol-level LDAP error answered by the server:
		// the connection itself is still usable.
		p.Release(conn, nil)
		return err
	}, backoff.WithContext(clientutil.NewExponentialBackOff(), ctx))
	return result, err
}