Skip to content
Snippets Groups Projects
auth_client.c 6.78 KiB
#include "config.h"
#include <stdio.h>
#include <string.h>
#ifdef HAVE_STDLIB_H
#include <stdlib.h>
#endif
#ifdef HAVE_MEMORY_H
#include <memory.h>
#endif
#include <sys/stat.h>
#include <curl/curl.h>
#include "auth_client.h"
#include "cbuf.h"

/*
 * Evaluate a libcurl call (x) and, if it does not return CURLE_OK, return
 * the corresponding AC_* error code from the enclosing function.
 * The do { } while (0) wrapper makes the macro a single statement, so it is
 * safe inside an unbraced if/else and requires the trailing semicolon.
 */
#define CURL_CHECK(x) do { \
    CURLcode curl_check_err_ = (x); \
    if (curl_check_err_ != CURLE_OK) { \
      return auth_client_err_from_curl(curl_check_err_); \
    } \
  } while (0)

// Path of the authentication endpoint, appended to each server's host[:port].
// Declared as an array so the value itself (not just the characters) is const.
static const char kAuthApiPath[] = "/api/1/auth";

/*
 * A parsed, comma-separated list of server addresses.
 * Every entry of `servers` points into the single `buf` allocation.
 */
typedef struct server_list {
  int n;          // number of entries in servers
  char **servers; // array of n pointers into buf
  char *buf;      // private, tokenized copy of the original string
} *server_list_t;

/*
 * Split the comma-separated string s into sl.
 *
 * Empty entries ("a,,b", a trailing comma) are skipped, as strtok_r does.
 * sl->servers points into sl->buf, so both are released together by
 * server_list_free(). If s is NULL or an allocation fails, sl is left as an
 * empty list (n == 0) that is still safe to pass to server_list_free().
 */
static void server_list_parse(server_list_t sl, const char *s) {
  char *saveptr = NULL, *sp, *tok;
  size_t max_tokens = 1;
  const char *p;

  sl->n = 0;
  sl->servers = NULL;
  sl->buf = NULL;
  if (s == NULL) {
    return;
  }

  // There can be at most one more token than there are separators, so a
  // single allocation suffices (no realloc per token).
  for (p = s; *p != '\0'; p++) {
    if (*p == ',') {
      max_tokens++;
    }
  }

  sl->buf = strdup(s);
  sl->servers = malloc(max_tokens * sizeof *sl->servers);
  if (sl->buf == NULL || sl->servers == NULL) {
    free(sl->buf);
    free(sl->servers);
    sl->buf = NULL;
    sl->servers = NULL;
    return;
  }

  for (sp = sl->buf; (tok = strtok_r(sp, ",", &saveptr)) != NULL; sp = NULL) {
    sl->servers[sl->n++] = tok;
  }
}

/*
 * Release the memory owned by sl (the tokenized buffer and the pointer
 * array) and reset it to an empty list, so that a second call, or an
 * accidental read of sl->servers afterwards, does not touch freed memory.
 */
static void server_list_free(server_list_t sl) {
  free(sl->buf);
  free(sl->servers);
  sl->buf = NULL;
  sl->servers = NULL;
  sl->n = 0;
}

/*
 * State for one auth-server client; created by auth_client_new() and
 * released by auth_client_free().
 */
struct auth_client {
  // libcurl easy handle, reused for every request made by this client.
  CURL *c;
  // Service name sent with each request. Borrowed: auth_client_new() stores
  // the caller's pointer without copying, so it must outlive the client.
  const char *service;
  // Nonzero once auth_client_set_certificate() has configured TLS; selects
  // "https" instead of "http" when building request URLs.
  int https;
  // Servers tried in order by auth_client_authenticate().
  struct server_list server_list;
};

/*
 * Point ac's curl handle at the auth endpoint on `server` (host[:port]),
 * using https when ac->https is set.
 *
 * The URL is built in an exactly-sized heap buffer rather than a stack VLA
 * sized by the (externally configured) server string with a fixed slack.
 * Returns AC_OK, or an AC_* error mapped from a libcurl code on failure.
 */
static int auth_client_set_url(auth_client_t ac, const char *server) {
  const char *tls = ac->https ? "s" : "";
  char *url;
  int len;
  CURLcode rc;

  len = snprintf(NULL, 0, "http%s://%s%s", tls, server, kAuthApiPath);
  if (len < 0) {
    return auth_client_err_from_curl(CURLE_URL_MALFORMAT);
  }
  url = malloc((size_t)len + 1);
  if (url == NULL) {
    return auth_client_err_from_curl(CURLE_OUT_OF_MEMORY);
  }
  snprintf(url, (size_t)len + 1, "http%s://%s%s", tls, server, kAuthApiPath);

  // libcurl copies string options such as CURLOPT_URL, so the buffer can be
  // released immediately (and on every path, unlike an early CURL_CHECK return).
  rc = curl_easy_setopt(ac->c, CURLOPT_URL, url);
  free(url);
  if (rc != CURLE_OK) {
    return auth_client_err_from_curl(rc);
  }
  return AC_OK;
}

static int curl_initialized = 0;

auth_client_t auth_client_new(const char *service, const char *servers) {
  auth_client_t ac = (auth_client_t)malloc(sizeof(struct auth_client));

  if (!curl_initialized) {
    curl_global_init(CURL_GLOBAL_DEFAULT);
    curl_initialized = 1;
  }

  server_list_parse(&ac->server_list, servers);
  ac->service = service;
  ac->c = curl_easy_init();
  ac->https = 0;
  curl_easy_setopt(ac->c, CURLOPT_NOSIGNAL, 1);
  curl_easy_setopt(ac->c, CURLOPT_TIMEOUT, 60);
  return ac;
}

/*
 * Return 1 if stat() succeeds on `path`, 0 otherwise. Note that a file that
 * exists but sits under an unsearchable directory also yields 0.
 */
static int file_exists(const char *path) {
  struct stat info;
  return stat(path, &info) >= 0;
}

/*
 * Enable (verbose != 0) or disable libcurl's verbose transfer logging for ac.
 */
void auth_client_set_verbose(auth_client_t ac, int verbose) {
  // CURLOPT_VERBOSE expects a long; pass one explicitly rather than an int
  // through curl_easy_setopt's varargs.
  curl_easy_setopt(ac->c, CURLOPT_VERBOSE, verbose ? 1L : 0L);
}

/*
 * Switch ac to https with client-certificate authentication.
 *
 * ca_file:  CA certificate file used to verify the server (CURLOPT_CAINFO).
 * crt_file: client certificate, PEM format.
 * key_file: private key for crt_file, PEM format.
 *
 * Returns AC_OK on success; AC_ERR_FILE_NOT_FOUND if any path is NULL or
 * cannot be stat()ed; otherwise the AC_* code mapped from the first libcurl
 * option that fails. ac->https is only set once every option was applied.
 */
int auth_client_set_certificate(auth_client_t ac,
                                const char *ca_file,
                                const char *crt_file,
                                const char *key_file) {
  if (ca_file == NULL || crt_file == NULL || key_file == NULL ||
      !file_exists(ca_file) || !file_exists(crt_file) || !file_exists(key_file)) {
    return AC_ERR_FILE_NOT_FOUND;
  }
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_CAINFO, ca_file));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLKEYTYPE, "PEM"));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLKEY, key_file));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLCERTTYPE, "PEM"));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLCERT, crt_file));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSL_VERIFYPEER, 1L));
  // NOTE(review): host name verification is disabled, so any certificate
  // signed by ca_file is accepted for any server name. Consider 2L once the
  // servers' certificates carry names matching the configured addresses.
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSL_VERIFYHOST, 0L));
  // CURL_SSLVERSION_TLSv1 means "TLS 1.0 or later", not TLS 1.0 only.
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1));
  ac->https = 1;
  return AC_OK;
}

/*
 * Release everything owned by ac: the curl handle, the parsed server list
 * and the client itself. The borrowed service string is not freed.
 * Passing NULL is a no-op, mirroring free().
 */
void auth_client_free(auth_client_t ac) {
  if (ac == NULL) {
    return;
  }
  curl_easy_cleanup(ac->c);
  server_list_free(&ac->server_list);
  free(ac);
}

/*
 * Return a static, human-readable description of an AC_* result code.
 * Codes below AC_ERR_CURL_BASE are treated as wrapped libcurl errors and
 * described by curl_easy_strerror(); unrecognised codes give "Unknown error".
 */
const char *auth_client_strerror(int err) {
  const char *msg = "Unknown error";

  if (err < AC_ERR_CURL_BASE) {
    return curl_easy_strerror(auth_client_err_to_curl(err));
  }

  switch (err) {
  case AC_ERR_AUTHENTICATION_FAILURE:
    msg = "Authentication failure";
    break;
  case AC_ERR_OTP_REQUIRED:
    msg = "OTP required";
    break;
  case AC_ERR_BAD_RESPONSE:
    msg = "Bad server response";
    break;
  case AC_ERR_FILE_NOT_FOUND:
    msg = "Certificate or CA file not found";
    break;
  case AC_ERR_NO_SERVERS:
    msg = "No servers could be reached";
    break;
  default:
    break;
  }
  return msg;
}

/*
 * Percent-encode s for use as an application/x-www-form-urlencoded value.
 *
 * Every byte except the RFC 3986 unreserved set (A-Z a-z 0-9 - . _ ~) is
 * written as %XX. In particular '%', ' ', '#', '+', control characters and
 * non-ASCII bytes are escaped, so arbitrary values such as passwords reach
 * the server unchanged after form decoding.
 *
 * Returns a newly malloc()ed string the caller must free(), or NULL if the
 * allocation fails.
 */
static char *quote(const char *s) {
  static const char hex[] = "0123456789ABCDEF";
  char *out = malloc(strlen(s) * 3 + 1), *optr;

  if (out == NULL) {
    return NULL;
  }
  for (optr = out; *s; s++) {
    // Work on unsigned char so bytes >= 0x80 index hex[] correctly.
    unsigned char c = (unsigned char)*s;
    if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
        (c >= '0' && c <= '9') ||
        c == '-' || c == '.' || c == '_' || c == '~') {
      *optr++ = (char)c;
    } else {
      *optr++ = '%';
      *optr++ = hex[c >> 4];
      *optr++ = hex[c & 0x0F];
    }
  }
  *optr = '\0';
  return out;
}

/*
 * libcurl write callback: append each received chunk of the response body
 * to the struct cbuf passed via CURLOPT_WRITEDATA, and report the whole
 * chunk as consumed.
 */
static size_t responsebuf_callback(void *contents, size_t size, size_t nmemb, void *userp) {
  size_t nbytes = nmemb * size;

  cbuf_append((struct cbuf *)userp, contents, nbytes);
  return nbytes;
}

/*
 * Append "key=value" to form_data, separated from any previous field by '&'.
 *
 * The value is percent-encoded with quote(); the key is appended verbatim
 * and must already be URL-safe. A NULL value is sent as an empty value
 * instead of crashing, and if quote() cannot allocate, "key=" is appended
 * with an empty value.
 */
static void post_field_add(struct cbuf *form_data, const char *key, const char *value) {
  char *quoted_value = quote(value != NULL ? value : "");

  if (form_data->size != 0) {
    cbuf_append(form_data, "&", 1);
  }
  cbuf_append(form_data, key, strlen(key));
  cbuf_append(form_data, "=", 1);
  if (quoted_value != NULL) {
    cbuf_append(form_data, quoted_value, strlen(quoted_value));
    free(quoted_value);
  }
}

int auth_client_authenticate(auth_client_t ac,
                             const char *username,
                             const char *password,
                             const char *otp_token,
                             const char *source_ip,
			     const char *shard) {
  struct curl_slist *headers = NULL;
  struct cbuf form;
  CURLcode res;
  int i, retval = AC_ERR_NO_SERVERS;

  // Build the POST request contents.
  cbuf_init(&form, 256);
  post_field_add(&form, "service", ac->service);
  post_field_add(&form, "username", username);
  if (password) {
    post_field_add(&form, "password", password);
  }
  if (otp_token) {
    post_field_add(&form, "otp", otp_token);
  }
  if (source_ip) {
    post_field_add(&form, "source_ip", source_ip);
  }
  if (shard) {
    post_field_add(&form, "shard", shard);
  }
  curl_easy_setopt(ac->c, CURLOPT_POSTFIELDS, form.buf);
  curl_easy_setopt(ac->c, CURLOPT_POSTFIELDSIZE, form.size);

  // Set request headers.
  curl_slist_append(headers, "Content-Type: application/x-form-www-urlencoded");
  curl_easy_setopt(ac->c, CURLOPT_HTTPHEADER, headers);

  // Iterate over the known servers. We create a new response buffer
  // for each request just in case we get a partial transfer error.
  for (i = 0; i < ac->server_list.n; i++) {
    char *server = ac->server_list.servers[i];
    struct cbuf responsebuf;

    cbuf_init(&responsebuf, 64);
    curl_easy_setopt(ac->c, CURLOPT_WRITEFUNCTION, responsebuf_callback);
    curl_easy_setopt(ac->c, CURLOPT_WRITEDATA, (void *)&responsebuf);
    auth_client_set_url(ac, server);

    res = curl_easy_perform(ac->c);
    if (res == CURLE_OK) {
      // Check the auth server response.
      if (!strncmp(responsebuf.buf, "OK", 2)) {
        retval = AC_OK;
      } else if (!strncmp(responsebuf.buf, "OTP_REQUIRED", 12)) {
        retval = AC_ERR_OTP_REQUIRED;
      } else if (!strncmp(responsebuf.buf, "ERROR", 5)) {
        retval = AC_ERR_AUTHENTICATION_FAILURE;
      } else {
        retval = AC_ERR_BAD_RESPONSE;
      }
      break;
    } else {
      retval = auth_client_err_from_curl(res);
    }

    cbuf_free(&responsebuf);
  }

  cbuf_free(&form);
  curl_slist_free_all(headers);

  return retval;
}