package main import ( "bufio" "errors" "flag" "fmt" "io" "io/ioutil" "log" "net" "os" "os/exec" "path" "strconv" "strings" "time" dockerClient "github.com/fsouza/go-dockerclient" ) var ( SYSFS string = "/sys/fs/cgroup" PROCS string = "cgroup.procs" CGROUP_PROC string = "/proc/%d/cgroup" INTERVAL time.Duration = 1000 ) type Context struct { Args []string Cgroups []string AllCgroups bool Logs bool Notify bool Name string Env bool Rm bool Id string NotifySocket string Cmd *exec.Cmd Pid int PidFile string Client *dockerClient.Client } type strlistFlag []string func (f strlistFlag) String() string { return strings.Join(f, ",") } func (f *strlistFlag) Set(value string) error { *f = append(*f, value) return nil } func setupEnvironment(c *Context) { newArgs := []string{} if c.Notify && len(c.NotifySocket) > 0 { newArgs = append(newArgs, "-e", fmt.Sprintf("NOTIFY_SOCKET=%s", c.NotifySocket)) newArgs = append(newArgs, "-v", fmt.Sprintf("%s:%s", c.NotifySocket, c.NotifySocket)) } else { c.Notify = false } if c.Env { for _, val := range os.Environ() { if !strings.HasPrefix(val, "HOME=") && !strings.HasPrefix(val, "PATH=") { newArgs = append(newArgs, "-e", val) } } } if len(newArgs) > 0 { c.Args = append(newArgs, c.Args...) } } func parseContext(args []string) (*Context, error) { c := &Context{ Logs: true, AllCgroups: false, } flags := flag.NewFlagSet("systemd-docker", flag.ContinueOnError) var flCgroups strlistFlag flags.StringVar(&c.PidFile, "pid-file", "", "pipe file") flags.BoolVar(&c.Logs, "logs", true, "pipe logs") flags.BoolVar(&c.Notify, "notify", false, "setup systemd notify for container") flags.BoolVar(&c.Env, "env", false, "inherit environment variable") flags.Var(&flCgroups, "cgroups", "cgroups to take ownership of or 'all' for all cgroups available") err := flags.Parse(args) if err != nil { return nil, err } foundD := false var name string runArgs := flags.Args() if len(runArgs) == 0 || runArgs[0] != "run" { log.Println("Args:", runArgs) return nil, errors.New("run not found in arguments") } runArgs = runArgs[1:] newArgs := make([]string, 0, len(runArgs)) for i, arg := range runArgs { /* This is tedious, but flag can't ignore unknown flags and I don't want to define them all */ add := true switch { case arg == "-rm" || arg == "--rm": c.Rm = true add = false case arg == "-d" || arg == "-detach" || arg == "--detach": foundD = true case strings.HasPrefix(arg, "-name") || strings.HasPrefix(arg, "--name"): if strings.Contains(arg, "=") { name = strings.SplitN(arg, "=", 2)[1] } else if len(runArgs) > i+1 { name = runArgs[i+1] } } if add { newArgs = append(newArgs, arg) } } if !foundD { newArgs = append([]string{"-d"}, newArgs...) } c.Name = name c.NotifySocket = os.Getenv("NOTIFY_SOCKET") c.Args = newArgs c.Cgroups = []string(flCgroups) for _, val := range c.Cgroups { if val == "all" { c.Cgroups = nil c.AllCgroups = true break } } setupEnvironment(c) return c, nil } func lookupNamedContainer(c *Context) error { client, err := getClient(c) if err != nil { return err } container, err := client.InspectContainerWithOptions( dockerClient.InspectContainerOptions{ID: c.Name}) if _, ok := err.(*dockerClient.NoSuchContainer); ok { return nil } if err != nil || container == nil { return err } if container.State.Running { c.Id = container.ID c.Pid = container.State.Pid return nil } else if c.Rm { return client.RemoveContainer(dockerClient.RemoveContainerOptions{ ID: container.ID, Force: true, }) } else { client, err := getClient(c) if err != nil { return err } err = client.StartContainer(container.ID, container.HostConfig) if err != nil { return err } container, err = client.InspectContainerWithOptions( dockerClient.InspectContainerOptions{ID: c.Name}) if err != nil { return err } c.Id = container.ID c.Pid = container.State.Pid return nil } } func launchContainer(c *Context) error { args := append([]string{"run"}, c.Args...) c.Cmd = exec.Command("docker", args...) errorPipe, err := c.Cmd.StderrPipe() if err != nil { return err } outputPipe, err := c.Cmd.StdoutPipe() if err != nil { return err } err = c.Cmd.Start() if err != nil { return err } go io.Copy(os.Stderr, errorPipe) bytes, err := ioutil.ReadAll(outputPipe) if err != nil { return err } c.Id = strings.TrimSpace(string(bytes)) err = c.Cmd.Wait() if err != nil { return err } if !c.Cmd.ProcessState.Success() { return err } c.Pid, err = getContainerPid(c) return err } func runContainer(c *Context) error { if len(c.Name) > 0 { err := lookupNamedContainer(c) if err != nil { return err } } if len(c.Id) == 0 { err := launchContainer(c) if err != nil { return err } } if c.Pid == 0 { return errors.New("Failed to launch container, pid is 0") } return nil } func getClient(c *Context) (*dockerClient.Client, error) { if c.Client != nil { return c.Client, nil } endpoint := os.Getenv("DOCKER_HOST") if len(endpoint) == 0 { endpoint = "unix:///var/run/docker.sock" } return dockerClient.NewVersionedClient(endpoint, "1.12") } func getContainerPid(c *Context) (int, error) { client, err := getClient(c) if err != nil { return 0, err } container, err := client.InspectContainer(c.Id) if err != nil { return 0, err } if container == nil { return 0, fmt.Errorf("Failed to find container %s", c.Id) } // if container.State.Pid <= 0 { // return 0, fmt.Errorf("Pid is %d for container %s: %+v", container.State.Pid, c.Id, container.State) // } return container.State.Pid, nil } func getCgroupsForPid(pid int) (map[string]string, error) { file, err := os.Open(fmt.Sprintf(CGROUP_PROC, pid)) if err != nil { return nil, err } ret := map[string]string{} scanner := bufio.NewScanner(file) for scanner.Scan() { line := strings.SplitN(scanner.Text(), ":", 3) if len(line) != 3 { continue } ret[line[1]] = line[2] } if err := scanner.Err(); err != nil { return nil, err } return ret, nil } func constructCgroupPath(cgroupName string, cgroupPath string) string { if cgroupName == "" { cgroupName = "unified" } return path.Join(SYSFS, strings.TrimPrefix(cgroupName, "name="), cgroupPath, PROCS) } func getCgroupPids(cgroupName string, cgroupPath string) ([]string, error) { ret := []string{} file, err := os.Open(constructCgroupPath(cgroupName, cgroupPath)) if err != nil { return nil, err } scanner := bufio.NewScanner(file) for scanner.Scan() { ret = append(ret, strings.TrimSpace(scanner.Text())) } if err = scanner.Err(); err != nil { return nil, err } return ret, nil } func writePid(pid string, path string) error { return ioutil.WriteFile(path, []byte(pid), 0644) } func moveCgroups(c *Context) (bool, error) { moved := false currentCgroups, err := getCgroupsForPid(os.Getpid()) if err != nil { return false, err } containerCgroups, err := getCgroupsForPid(c.Pid) if err != nil { return false, err } var ns []string if c.AllCgroups || c.Cgroups == nil || len(c.Cgroups) == 0 { ns = make([]string, 0, len(containerCgroups)) for value := range containerCgroups { ns = append(ns, value) } } else { ns = c.Cgroups } for _, nsName := range ns { currentPath, ok := currentCgroups[nsName] if !ok { continue } containerPath, ok := containerCgroups[nsName] if !ok { continue } if currentPath == containerPath || containerPath == "/" { continue } pids, err := getCgroupPids(nsName, containerPath) if err != nil { return false, err } for _, pid := range pids { pidInt, err := strconv.Atoi(pid) if err != nil { continue } if pidDied(pidInt) { continue } currentFullPath := constructCgroupPath(nsName, currentPath) log.Printf("Moving pid %s to %s\n", pid, currentFullPath) err = writePid(pid, currentFullPath) if err != nil { return false, err } moved = true } } return moved, nil } func pidDied(pid int) bool { _, err := os.Stat(fmt.Sprintf("/proc/%d", pid)) return os.IsNotExist(err) } func notify(c *Context) error { if pidDied(c.Pid) { return errors.New("Container exited before we could notify systemd") } if len(c.NotifySocket) == 0 { return nil } conn, err := net.Dial("unixgram", c.NotifySocket) if err != nil { return err } defer conn.Close() _, err = conn.Write([]byte(fmt.Sprintf("MAINPID=%d", c.Pid))) if err != nil { return err } if pidDied(c.Pid) { conn.Write([]byte(fmt.Sprintf("MAINPID=%d", os.Getpid()))) return errors.New("Container exited before we could notify systemd") } if !c.Notify { _, err = conn.Write([]byte("READY=1")) if err != nil { return err } } return nil } func pidFile(c *Context) error { if len(c.PidFile) == 0 || c.Pid <= 0 { return nil } err := ioutil.WriteFile(c.PidFile, []byte(strconv.Itoa(c.Pid)), 0644) if err != nil { return err } return nil } func pipeLogs(c *Context) error { if !c.Logs { return nil } client, err := getClient(c) if err != nil { return err } err = client.Logs(dockerClient.LogsOptions{ Container: c.Id, Follow: true, Stdout: true, Stderr: true, OutputStream: os.Stdout, ErrorStream: os.Stderr, }) return err } func keepAlive(c *Context) error { if c.Logs || c.Rm { client, err := getClient(c) if err != nil { return err } /* Good old polling... */ for { container, err := client.InspectContainer(c.Id) if err != nil { return err } if container.State.Running { client.WaitContainer(c.Id) } else { return nil } } } return nil } func rmContainer(c *Context) error { if !c.Rm { return nil } client, err := getClient(c) if err != nil { return err } return client.RemoveContainer(dockerClient.RemoveContainerOptions{ ID: c.Id, Force: true, }) } func mainWithArgs(args []string) (*Context, error) { c, err := parseContext(args) if err != nil { return c, err } err = runContainer(c) if err != nil { return c, fmt.Errorf("runContainer: %v", err) } _, err = moveCgroups(c) if err != nil { return c, err } err = notify(c) if err != nil { return c, err } err = pidFile(c) if err != nil { return c, err } go pipeLogs(c) err = keepAlive(c) if err != nil { return c, err } err = rmContainer(c) if err != nil { return c, err } return c, nil } func main() { log.SetFlags(0) _, err := mainWithArgs(os.Args[1:]) if err != nil { log.Fatal(err) } }