package util

import (
	"context"
	"sync"

	"google.golang.org/grpc"
)

// ConnCache caches GRPC connections for a set of (identical)
// backends, all sharing a common set of options.
//
// Connections are keyed by the address passed to Get and stay cached until
// Drop removes them. Map access in Get and Drop is guarded by mx, so a
// ConnCache must not be copied after first use. The zero value is not usable
// (conns is nil, so inserting a new connection would panic); create one with
// NewConnCache.
type ConnCache struct {
	mx    sync.Mutex                  // guards conns
	conns map[string]*grpc.ClientConn // cached connections, keyed by address
	opts  []grpc.DialOption           // dial options applied to every new connection
}

// NewConnCache builds an empty connection cache. Every connection it later
// dials is created with opts.
func NewConnCache(opts ...grpc.DialOption) *ConnCache {
	cache := new(ConnCache)
	cache.conns = map[string]*grpc.ClientConn{}
	cache.opts = opts
	return cache
}

// Get returns a (potentially cached) GRPC connection for addr. If none is
// cached, a new one is dialed with the cache's options and stored, so later
// calls for the same addr share it. Dial errors are returned unchanged and
// nothing is cached in that case.
//
// The lock is not held while dialing. With a blocking dial option (for
// example grpc.WithBlock), holding it would stall every other Get until the
// dial finished, including cache hits for unrelated addresses. If two callers
// race to dial the same addr, the first one to store its connection wins. The
// other caller closes its redundant connection and returns the winner's.
func (c *ConnCache) Get(ctx context.Context, addr string) (*grpc.ClientConn, error) {
	c.mx.Lock()
	conn, ok := c.conns[addr]
	c.mx.Unlock()
	if ok {
		return conn, nil
	}

	// opts is only assigned in NewConnCache, so reading it unlocked is safe.
	conn, err := grpc.DialContext(ctx, addr, c.opts...)
	if err != nil {
		return nil, err
	}

	c.mx.Lock()
	if existing, ok := c.conns[addr]; ok {
		c.mx.Unlock()
		// Another caller cached a connection while we were dialing. Keep
		// theirs so each addr has a single shared connection. Closing ours is
		// best effort: it was never handed out, so nothing else uses it.
		_ = conn.Close()
		return existing, nil
	}
	c.conns[addr] = conn
	c.mx.Unlock()
	return conn, nil
}

// Drop notifies the cache of an error with the connection, which will
// be (potentially) closed and dropped from the cache.
//
// conn is removed and closed only if it is still the connection cached for
// addr. A stale conn (already dropped, or replaced by a newer connection) is
// left untouched, so a connection that other callers are still using is
// never closed.
func (c *ConnCache) Drop(addr string, conn *grpc.ClientConn) {
	c.mx.Lock()
	curConn, ok := c.conns[addr]
	// ok is required: a missing key yields nil, which must not match a nil conn.
	evict := ok && curConn == conn
	if evict {
		delete(c.conns, addr)
	}
	c.mx.Unlock()

	if !evict {
		return
	}
	// Close outside the lock. conn is no longer reachable through the cache,
	// so tearing it down need not block concurrent Get/Drop calls. The error
	// is deliberately ignored: the connection is already considered broken,
	// and Drop has no way to report the error to its caller.
	_ = conn.Close()
}