#include "config.h"
#include <stdint.h>
#include <stdio.h>
#ifdef HAVE_STDLIB_H
#include <stdlib.h>
#endif
#include <string.h>
#ifdef HAVE_MEMORY_H
#include <memory.h>
#endif
#include <sys/stat.h>
#include <curl/curl.h>
#include "auth_client.h"

/*
 * Evaluate the libcurl call `x`; if it does not yield CURLE_OK, return
 * from the enclosing function with that code translated by
 * auth_client_err_from_curl().
 *
 * The body is wrapped in do { } while (0) so that `CURL_CHECK(x);` is
 * exactly one statement and stays correct inside an unbraced if/else.
 * Note that the early return skips any cleanup the caller still owes.
 */
#define CURL_CHECK(x) do { \
  CURLcode curl_check_rc_ = (x); \
  if (curl_check_rc_ != CURLE_OK) { \
    return auth_client_err_from_curl(curl_check_rc_); \
  } \
} while (0)

/* Path of the authentication endpoint; appended after the server name when the request URL is built. */
static const char *kAuthApiPath = "/api/1/auth";

/*
 * State of one auth client: a libcurl easy handle plus the service and
 * server names it was created with.
 */
struct auth_client {
  /* libcurl easy handle used for every request made through this client. */
  CURL *c;
  /* Service name sent as the "service" form field. The pointer is stored as
   * passed in (no copy is made), so the string must outlive the client. */
  const char *service;
  /* Server name used to build the request URL. The pointer is stored as
   * passed in (no copy is made), so the string must outlive the client. */
  const char *server;
};

/*
 * Point the handle at "<proto>://<server><kAuthApiPath>".
 *
 * The URL is assembled in a heap buffer sized from its actual parts, so
 * neither a long server name nor a long scheme can overrun it. The buffer
 * is released right after the setopt call (libcurl keeps its own copy of
 * string options), which is why CURL_CHECK is not used here: its early
 * return would leak the buffer.
 *
 * Returns AC_OK on success, or a curl-derived error code: URL_MALFORMAT
 * when the server or scheme is missing, OUT_OF_MEMORY when the buffer
 * cannot be allocated, otherwise whatever curl_easy_setopt() reported.
 */
static int auth_client_set_proto(auth_client_t ac, const char *proto) {
  static const char kSchemeSeparator[] = "://";
  size_t url_size;
  char *url;
  int written;
  CURLcode rc;

  if (proto == NULL || ac->server == NULL) {
    return auth_client_err_from_curl(CURLE_URL_MALFORMAT);
  }

  url_size = strlen(proto) + (sizeof(kSchemeSeparator) - 1) +
             strlen(ac->server) + strlen(kAuthApiPath) + 1;
  url = (char *)malloc(url_size);
  if (url == NULL) {
    return auth_client_err_from_curl(CURLE_OUT_OF_MEMORY);
  }

  written = snprintf(url, url_size, "%s%s%s%s",
                     proto, kSchemeSeparator, ac->server, kAuthApiPath);
  if (written < 0 || (size_t)written >= url_size) {
    free(url);
    return auth_client_err_from_curl(CURLE_URL_MALFORMAT);
  }

  rc = curl_easy_setopt(ac->c, CURLOPT_URL, url);
  free(url);
  if (rc != CURLE_OK) {
    return auth_client_err_from_curl(rc);
  }
  return AC_OK;
}

/*
 * Set to 1 once auth_client_new() has called curl_global_init(), so that
 * the global initialization is not repeated for later clients.
 * NOTE(review): plain int with no locking -- presumably clients must not
 * be created concurrently from several threads; confirm with callers.
 */
static int curl_initialized = 0;

auth_client_t auth_client_new(const char *service, const char *server) {
  auth_client_t ac = (auth_client_t)malloc(sizeof(struct auth_client));

  if (!curl_initialized) {
    curl_global_init(CURL_GLOBAL_DEFAULT);
    curl_initialized = 1;
  }

  ac->service = service;
  ac->server = server;
  ac->c = curl_easy_init();
  curl_easy_setopt(ac->c, CURLOPT_NOSIGNAL, 1);
  curl_easy_setopt(ac->c, CURLOPT_TIMEOUT, 60);
  auth_client_set_proto(ac, "http");
  return ac;
}

/*
 * Report whether `path` can be stat()ed.
 * Returns 1 when stat() succeeds and 0 when it fails for any reason.
 */
static int file_exists(const char *path) {
  struct stat info;
  return (stat(path, &info) >= 0) ? 1 : 0;
}

/*
 * Forward `verbose` to libcurl's CURLOPT_VERBOSE option for this client.
 * The value is widened to long because curl_easy_setopt() is variadic
 * and reads numeric options as long.
 */
void auth_client_set_verbose(auth_client_t ac, int verbose) {
  curl_easy_setopt(ac->c, CURLOPT_VERBOSE, (long)verbose);
}

/*
 * Switch the client to HTTPS and configure TLS client authentication.
 *
 *   ca_file   CA bundle used to verify the server certificate
 *   crt_file  client certificate (PEM)
 *   key_file  client private key (PEM)
 *
 * Returns AC_ERR_FILE_NOT_FOUND if any path is NULL or cannot be
 * stat()ed, a curl-derived error if an option is rejected, AC_OK
 * otherwise. The URL is switched to https before the TLS options are
 * applied, so a failure part-way leaves the handle on https.
 */
int auth_client_set_certificate(auth_client_t ac,
                                const char *ca_file,
                                const char *crt_file,
                                const char *key_file) {
  int err;

  if (ca_file == NULL || crt_file == NULL || key_file == NULL) {
    return AC_ERR_FILE_NOT_FOUND;
  }
  if (!file_exists(ca_file) || !file_exists(crt_file) || !file_exists(key_file)) {
    return AC_ERR_FILE_NOT_FOUND;
  }
  err = auth_client_set_proto(ac, "https");
  if (err != AC_OK) {
    return err;
  }
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_CAINFO, ca_file));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLKEYTYPE, "PEM"));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLKEY, key_file));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLCERTTYPE, "PEM"));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLCERT, crt_file));
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSL_VERIFYPEER, 1L));
  // NOTE(review): host name verification is switched off here, so any
  // server certificate that chains to ca_file is accepted whatever name
  // it carries. Left unchanged because deployments may rely on it --
  // confirm this is intended.
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSL_VERIFYHOST, 0L));
  // The CURL_SSLVERSION_* constants are enum values (int); the option is
  // read as long, so widen explicitly.
  CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLVERSION, (long)CURL_SSLVERSION_TLSv1));
  return AC_OK;
}

/*
 * Release a client created by auth_client_new(): cleans up its libcurl
 * handle and frees the client itself. Passing NULL is a no-op, mirroring
 * free(). The service/server strings are not owned and are left alone.
 */
void auth_client_free(auth_client_t ac) {
  if (ac == NULL) {
    return;
  }
  curl_easy_cleanup(ac->c);
  free(ac);
}

/*
 * Map an error code returned by this module to a static message.
 *
 * Codes below AC_ERR_CURL_BASE are converted with
 * auth_client_err_to_curl() and described by curl_easy_strerror(); the
 * remaining codes are looked up in a local table, with "Unknown error"
 * for anything not listed.
 */
const char *auth_client_strerror(int err) {
  static const struct {
    int code;
    const char *text;
  } kMessages[] = {
    { AC_ERR_AUTHENTICATION_FAILURE, "Authentication failure" },
    { AC_ERR_OTP_REQUIRED, "OTP required" },
    { AC_ERR_BAD_RESPONSE, "Bad server response" },
    { AC_ERR_FILE_NOT_FOUND, "Certificate or CA file not found" },
  };
  size_t idx;

  if (err < AC_ERR_CURL_BASE) {
    return curl_easy_strerror(auth_client_err_to_curl(err));
  }

  for (idx = 0; idx < sizeof(kMessages) / sizeof(kMessages[0]); idx++) {
    if (kMessages[idx].code == err) {
      return kMessages[idx].text;
    }
  }
  return "Unknown error";
}

/*
 * A dynamically sized memory buffer that can be appended to and grows
 * as needed. It performs best when the contents fit in the initial
 * allocation, since no reallocation happens in that case.
 *
 * Invariants while buf != NULL: alloc >= size + 1 and buf[size] == '\0',
 * so buf can always be read as a C string.
 *
 * Allocation failures are sticky: the storage is released, buf becomes
 * NULL and later appends are ignored. Callers detect failure by checking
 * buf == NULL once they are done appending.
 */
struct cbuf {
  char *buf;     /* storage, or NULL after a failure / cbuf_free() */
  size_t alloc;  /* bytes allocated for buf */
  size_t size;   /* payload bytes, not counting the terminating NUL */
};

/*
 * Initialize `cbuf` as an empty string with `alloc` bytes of storage
 * (at least 1, for the terminator). On allocation failure buf is NULL.
 */
static void cbuf_init(struct cbuf *cbuf, size_t alloc) {
  if (alloc == 0) {
    // A zero capacity could never be doubled into something usable.
    alloc = 1;
  }
  cbuf->buf = (char *)malloc(alloc);
  cbuf->alloc = (cbuf->buf != NULL) ? alloc : 0;
  cbuf->size = 0;
  if (cbuf->buf != NULL) {
    cbuf->buf[0] = '\0';
  }
}

/*
 * Release the storage and leave `cbuf` empty with buf == NULL, so a
 * second cbuf_free() or a stray cbuf_append() is harmless.
 */
static void cbuf_free(struct cbuf *cbuf) {
  free(cbuf->buf);
  cbuf->buf = NULL;
  cbuf->alloc = 0;
  cbuf->size = 0;
}

/*
 * Append `size` bytes from `data`, growing the storage by doubling when
 * needed, and keep the contents NUL-terminated. If the required size
 * overflows or memory cannot be obtained, the buffer is released and
 * enters the failed state (buf == NULL).
 */
static void cbuf_append(struct cbuf *cbuf, const void *data, size_t size) {
  size_t required_alloc;

  if (cbuf->buf == NULL) {
    // Failed (or freed) buffer: stay failed rather than restart halfway.
    return;
  }
  if (size == 0) {
    // Nothing to copy; the terminator is already in place.
    return;
  }
  if (size >= SIZE_MAX - cbuf->size) {
    // size + cbuf->size + 1 would wrap around.
    cbuf_free(cbuf);
    return;
  }

  // Resize if necessary.
  required_alloc = cbuf->size + size + 1;
  if (required_alloc > cbuf->alloc) {
    size_t new_alloc = (cbuf->alloc != 0) ? cbuf->alloc : 1;
    char *grown;

    while (new_alloc < required_alloc) {
      if (new_alloc > SIZE_MAX / 2) {
        // Doubling would overflow; settle for the exact size.
        new_alloc = required_alloc;
        break;
      }
      new_alloc *= 2;
    }
    // Keep the old pointer until realloc succeeds so it is never leaked.
    grown = (char *)realloc(cbuf->buf, new_alloc);
    if (grown == NULL) {
      cbuf_free(cbuf);
      return;
    }
    cbuf->buf = grown;
    cbuf->alloc = new_alloc;
  }

  // Append data to the buffer.
  memcpy(cbuf->buf + cbuf->size, data, size);
  cbuf->size += size;
  cbuf->buf[cbuf->size] = '\0';
}

/*
 * Percent-encode `s` for use as a key or value in an
 * application/x-www-form-urlencoded body.
 *
 * ASCII letters, digits and "-._~" are copied as they are; every other
 * byte (including '%', space, '&', '=', '+' and bytes >= 0x80) becomes
 * "%XX" with upper-case hex digits, so the receiver decodes exactly the
 * original bytes. NULL is treated as the empty string.
 *
 * Returns a NUL-terminated string allocated with malloc() that the
 * caller must free(), or NULL if memory could not be allocated.
 */
static char *quote(const char *s) {
  static const char kHexDigits[] = "0123456789ABCDEF";
  size_t len;
  char *out, *optr;

  if (s == NULL) {
    s = "";
  }
  len = strlen(s);
  if (len > (SIZE_MAX - 1) / 3) {
    // len * 3 + 1 would overflow.
    return NULL;
  }
  // Worst case every byte expands to three characters.
  out = (char *)malloc(len * 3 + 1);
  if (out == NULL) {
    return NULL;
  }

  for (optr = out; *s; s++) {
    // Work on the unsigned value so bytes >= 0x80 index the hex table
    // correctly whatever the signedness of plain char.
    unsigned char c = (unsigned char)*s;
    int unreserved = (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
                     (c >= '0' && c <= '9') ||
                     c == '-' || c == '.' || c == '_' || c == '~';
    if (unreserved) {
      *optr++ = (char)c;
    } else {
      *optr++ = '%';
      *optr++ = kHexDigits[c >> 4];
      *optr++ = kHexDigits[c & 0x0F];
    }
  }
  *optr = '\0';
  return out;
}

static size_t responsebuf_callback(void *contents, size_t size, size_t nmemb, void *userp) {
  size_t realsize = size * nmemb;
  struct cbuf *cbuf = (struct cbuf *)userp;

  cbuf_append(cbuf, contents, realsize);
  return realsize;
}

/*
 * Append one escaped "key=value" pair to the form body in `form_data`,
 * preceded by '&' when the body already holds data. A NULL key or value
 * is sent as the empty string.
 *
 * Returns 0 on success, -1 if an escaped string could not be produced
 * (nothing is appended in that case).
 */
static int post_field_add(struct cbuf *form_data, const char *key, const char *value) {
  char *quoted_key = quote(key != NULL ? key : "");
  char *quoted_value = quote(value != NULL ? value : "");
  int status = -1;

  if (quoted_key != NULL && quoted_value != NULL) {
    if (form_data->size != 0) {
      cbuf_append(form_data, "&", 1);
    }
    cbuf_append(form_data, quoted_key, strlen(quoted_key));
    cbuf_append(form_data, "=", 1);
    cbuf_append(form_data, quoted_value, strlen(quoted_value));
    status = 0;
  }
  free(quoted_key);
  free(quoted_value);
  return status;
}

/*
 * Send one authentication request and interpret the server's answer.
 *
 * `username` and the client's service name are always sent; `password`,
 * `otp_token`, `source_ip` and `shard` are sent only when non-NULL.
 *
 * Returns, based on how the response body starts:
 *   "OK"            -> AC_OK
 *   "OTP_REQUIRED"  -> AC_ERR_OTP_REQUIRED
 *   "ERROR"         -> AC_ERR_AUTHENTICATION_FAILURE
 *   anything else, or an empty body -> AC_ERR_BAD_RESPONSE
 * A failed transfer or option is reported as a curl-derived code, and
 * running out of memory as the code derived from CURLE_OUT_OF_MEMORY.
 */
int auth_client_authenticate(auth_client_t ac,
                             const char *username,
                             const char *password,
                             const char *otp_token,
                             const char *source_ip,
                             const char *shard) {
  struct curl_slist *headers = NULL;
  struct cbuf form;
  struct cbuf responsebuf;
  CURLcode res;
  int failed = 0;
  int retval;

  cbuf_init(&form, 256);
  cbuf_init(&responsebuf, 64);
  if (form.buf == NULL || responsebuf.buf == NULL) {
    retval = auth_client_err_from_curl(CURLE_OUT_OF_MEMORY);
    goto cleanup;
  }

  // Build the POST request contents.
  failed |= post_field_add(&form, "service", ac->service);
  failed |= post_field_add(&form, "username", username);
  if (password) {
    failed |= post_field_add(&form, "password", password);
  }
  if (otp_token) {
    failed |= post_field_add(&form, "otp", otp_token);
  }
  if (source_ip) {
    failed |= post_field_add(&form, "source_ip", source_ip);
  }
  if (shard) {
    failed |= post_field_add(&form, "shard", shard);
  }
  if (failed || form.buf == NULL) {
    // Never send a partially built form: it could lack credentials.
    retval = auth_client_err_from_curl(CURLE_OUT_OF_MEMORY);
    goto cleanup;
  }

  // Set request headers. curl_slist_append() returns the list head; the
  // list only exists if that return value is kept.
  headers = curl_slist_append(NULL, "Content-Type: application/x-www-form-urlencoded");
  if (headers == NULL) {
    retval = auth_client_err_from_curl(CURLE_OUT_OF_MEMORY);
    goto cleanup;
  }

  // CURLOPT_POSTFIELDS is not copied by libcurl, so form.buf must stay
  // alive until the transfer is over.
  res = curl_easy_setopt(ac->c, CURLOPT_POSTFIELDS, form.buf);
  if (res == CURLE_OK) {
    res = curl_easy_setopt(ac->c, CURLOPT_HTTPHEADER, headers);
  }
  if (res == CURLE_OK) {
    res = curl_easy_setopt(ac->c, CURLOPT_WRITEFUNCTION, responsebuf_callback);
  }
  if (res == CURLE_OK) {
    res = curl_easy_setopt(ac->c, CURLOPT_WRITEDATA, (void *)&responsebuf);
  }
  if (res == CURLE_OK) {
    res = curl_easy_perform(ac->c);
  }

  if (res != CURLE_OK) {
    retval = auth_client_err_from_curl(res);
  } else if (responsebuf.buf == NULL || responsebuf.size == 0) {
    // No body at all: there is nothing valid to compare against.
    retval = AC_ERR_BAD_RESPONSE;
  } else if (!strncmp(responsebuf.buf, "OK", 2)) {
    // Check the auth server response.
    retval = AC_OK;
  } else if (!strncmp(responsebuf.buf, "OTP_REQUIRED", 12)) {
    retval = AC_ERR_OTP_REQUIRED;
  } else if (!strncmp(responsebuf.buf, "ERROR", 5)) {
    retval = AC_ERR_AUTHENTICATION_FAILURE;
  } else {
    retval = AC_ERR_BAD_RESPONSE;
  }

cleanup:
  // Do not leave the handle pointing at memory that is about to be freed
  // or that lives on this stack frame.
  curl_easy_setopt(ac->c, CURLOPT_POSTFIELDS, (void *)NULL);
  curl_easy_setopt(ac->c, CURLOPT_HTTPHEADER, (struct curl_slist *)NULL);
  curl_easy_setopt(ac->c, CURLOPT_WRITEDATA, (void *)NULL);

  cbuf_free(&form);
  cbuf_free(&responsebuf);
  curl_slist_free_all(headers);

  return retval;
}