package util

import (
	"flag"
	"fmt"
	"net"
	"strings"
)

// ipList is a flag.Value that accumulates IP addresses. Every occurrence of
// the flag on the command line appends to the list, so the flag may be given
// multiple times.
type ipList []net.IP

// String returns the collected addresses joined by commas.
//
// The flag package documents that String may be called on a zero-valued
// receiver such as a nil pointer, so a nil *ipList is reported as empty
// instead of panicking on the dereference.
func (l *ipList) String() string {
	if l == nil {
		return ""
	}
	parts := make([]string, 0, len(*l))
	for _, ip := range *l {
		parts = append(parts, ip.String())
	}
	return strings.Join(parts, ",")
}

// Set appends value to the list. A literal IPv4/IPv6 address is appended
// as-is; anything else is treated as a host name and every address returned
// by net.LookupIP is appended. A lookup failure is wrapped with %w so callers
// can inspect the underlying resolver error with errors.Is / errors.As.
func (l *ipList) Set(value string) error {
	if ip := net.ParseIP(value); ip != nil {
		*l = append(*l, ip)
		return nil
	}

	// Value is not an IP address literal; try to resolve it as a host name.
	ips, err := net.LookupIP(value)
	if err != nil {
		return fmt.Errorf("unable to parse IP address \"%s\": %w", value, err)
	}
	*l = append(*l, ips...)
	return nil
}

// IPListFlag defines a repeatable flag with the given name and usage text on
// the default flag.CommandLine set. Each value is either an IP literal or a
// host name whose resolved addresses are all appended. The returned pointer
// refers to the slice that is populated when the flags are parsed.
func IPListFlag(name, help string) *[]net.IP {
	addrs := new(ipList)
	flag.Var(addrs, name, help)
	return (*[]net.IP)(addrs)
}

// ipFlag is a flag.Value holding a single IP address. Only IPv4/IPv6 literals
// are accepted; host names are not resolved.
type ipFlag net.IP

// String returns the textual form of the stored address, or "<nil>" when no
// address has been set (net.IP.String's result for an empty IP).
//
// The flag package documents that String may be called on a zero-valued
// receiver such as a nil pointer, so a nil *ipFlag is treated as unset
// instead of panicking on the dereference.
func (f *ipFlag) String() string {
	var ip net.IP
	if f != nil {
		ip = net.IP(*f)
	}
	return ip.String()
}

// Set parses value as an IPv4 or IPv6 literal and stores it. If parsing
// fails, an error is returned and the previously stored address is kept.
func (f *ipFlag) Set(value string) error {
	ip := net.ParseIP(value)
	if ip == nil {
		return fmt.Errorf("unable to parse IP address \"%s\"", value)
	}
	*f = ipFlag(ip)
	return nil
}

// IPFlag defines a flag with the given name and usage text on the default
// flag.CommandLine set that accepts a single IPv4 or IPv6 address literal.
// The returned pointer refers to the parsed address; it remains a nil IP
// until the flag is set.
func IPFlag(name, help string) *net.IP {
	addr := new(ipFlag)
	flag.Var(addr, name, help)
	return (*net.IP)(addr)
}