Skip to content
Snippets Groups Projects
Commit e067d5b3 authored by ale's avatar ale
Browse files

client: support multiple auth servers with fallback on error

parent 0c93c3e9
Branches
No related tags found
No related merge requests found
ACLOCAL_AMFLAGS = -I m4
AM_CXXFLAGS = -std=c++11
pamdir = $(PAMDIR)
......
......@@ -17,22 +17,48 @@
static const char *kAuthApiPath = "/api/1/auth";
struct server_list {
int n;
char **servers;
char *buf;
};
typedef struct server_list *server_list_t;
static void server_list_parse(server_list_t sl, const char *s) {
char *saveptr = NULL, *sp, *tok;
sl->n = 0;
sl->servers = (char **)malloc(sizeof(char *));
sl->buf = strdup(s);
for (sp = sl->buf; (tok = strtok_r(sp, ",", &saveptr)) != NULL; sp = NULL) {
sl->servers = (char **)realloc(sl->servers, sizeof(char *) * (sl->n + 1));
sl->servers[sl->n++] = tok;
}
}
static void server_list_free(server_list_t sl) {
free(sl->buf);
free(sl->servers);
}
struct auth_client {
CURL *c;
const char *service;
const char *server;
int https;
//const char *server;
struct server_list server_list;
};
static int auth_client_set_proto(auth_client_t ac, const char *proto) {
char url[strlen(ac->server) + 32];
sprintf(url, "%s://%s%s", proto, ac->server, kAuthApiPath);
static int auth_client_set_url(auth_client_t ac, const char *server) {
char url[strlen(server) + 32];
sprintf(url, "http%s://%s%s", ac->https ? "s" : "", server, kAuthApiPath);
CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_URL, url));
return AC_OK;
}
static int curl_initialized = 0;
auth_client_t auth_client_new(const char *service, const char *server) {
auth_client_t auth_client_new(const char *service, const char *servers) {
auth_client_t ac = (auth_client_t)malloc(sizeof(struct auth_client));
if (!curl_initialized) {
......@@ -40,12 +66,12 @@ auth_client_t auth_client_new(const char *service, const char *server) {
curl_initialized = 1;
}
server_list_parse(&ac->server_list, servers);
ac->service = service;
ac->server = server;
ac->c = curl_easy_init();
ac->https = 0;
curl_easy_setopt(ac->c, CURLOPT_NOSIGNAL, 1);
curl_easy_setopt(ac->c, CURLOPT_TIMEOUT, 60);
auth_client_set_proto(ac, "http");
return ac;
}
......@@ -69,10 +95,6 @@ int auth_client_set_certificate(auth_client_t ac,
if (!file_exists(ca_file) || !file_exists(crt_file) || !file_exists(key_file)) {
return AC_ERR_FILE_NOT_FOUND;
}
err = auth_client_set_proto(ac, "https");
if (err != AC_OK) {
return err;
}
CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_CAINFO, ca_file));
CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLKEYTYPE, "PEM"));
CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLKEY, key_file));
......@@ -81,11 +103,13 @@ int auth_client_set_certificate(auth_client_t ac,
CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSL_VERIFYPEER, 1L));
CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSL_VERIFYHOST, 0L));
CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1));
ac->https = 1;
return AC_OK;
}
void auth_client_free(auth_client_t ac) {
curl_easy_cleanup(ac->c);
server_list_free(&ac->server_list);
free(ac);
}
......@@ -102,6 +126,8 @@ const char *auth_client_strerror(int err) {
return "Bad server response";
case AC_ERR_FILE_NOT_FOUND:
return "Certificate or CA file not found";
case AC_ERR_NO_SERVERS:
return "No servers could be reached";
default:
return "Unknown error";
}
......@@ -159,9 +185,8 @@ int auth_client_authenticate(auth_client_t ac,
const char *shard) {
struct curl_slist *headers = NULL;
struct cbuf form;
struct cbuf responsebuf;
CURLcode res;
int retval;
int i, retval = AC_ERR_NO_SERVERS;
// Build the POST request contents.
cbuf_init(&form, 256);
......@@ -186,28 +211,38 @@ int auth_client_authenticate(auth_client_t ac,
curl_slist_append(headers, "Content-Type: application/x-form-www-urlencoded");
curl_easy_setopt(ac->c, CURLOPT_HTTPHEADER, headers);
cbuf_init(&responsebuf, 64);
curl_easy_setopt(ac->c, CURLOPT_WRITEFUNCTION, responsebuf_callback);
curl_easy_setopt(ac->c, CURLOPT_WRITEDATA, (void *)&responsebuf);
res = curl_easy_perform(ac->c);
if (res == CURLE_OK) {
// Check the auth server response.
if (!strncmp(responsebuf.buf, "OK", 2)) {
retval = AC_OK;
} else if (!strncmp(responsebuf.buf, "OTP_REQUIRED", 12)) {
retval = AC_ERR_OTP_REQUIRED;
} else if (!strncmp(responsebuf.buf, "ERROR", 5)) {
retval = AC_ERR_AUTHENTICATION_FAILURE;
// Iterate over the known servers. We create a new response buffer
// for each request just in case we get a partial transfer error.
for (i = 0; i < ac->server_list.n; i++) {
char *server = ac->server_list.servers[i];
struct cbuf responsebuf;
cbuf_init(&responsebuf, 64);
curl_easy_setopt(ac->c, CURLOPT_WRITEFUNCTION, responsebuf_callback);
curl_easy_setopt(ac->c, CURLOPT_WRITEDATA, (void *)&responsebuf);
auth_client_set_url(ac, server);
res = curl_easy_perform(ac->c);
if (res == CURLE_OK) {
// Check the auth server response.
if (!strncmp(responsebuf.buf, "OK", 2)) {
retval = AC_OK;
} else if (!strncmp(responsebuf.buf, "OTP_REQUIRED", 12)) {
retval = AC_ERR_OTP_REQUIRED;
} else if (!strncmp(responsebuf.buf, "ERROR", 5)) {
retval = AC_ERR_AUTHENTICATION_FAILURE;
} else {
retval = AC_ERR_BAD_RESPONSE;
}
break;
} else {
retval = AC_ERR_BAD_RESPONSE;
retval = auth_client_err_from_curl(res);
}
} else {
retval = auth_client_err_from_curl(res);
cbuf_free(&responsebuf);
}
cbuf_free(&form);
cbuf_free(&responsebuf);
curl_slist_free_all(headers);
return retval;
......
......@@ -11,11 +11,12 @@ typedef struct auth_client* auth_client_t;
#define AC_ERR_OTP_REQUIRED -2
#define AC_ERR_BAD_RESPONSE -3
#define AC_ERR_FILE_NOT_FOUND -4
#define AC_ERR_NO_SERVERS -5
#define AC_ERR_CURL_BASE -100
#define auth_client_err_to_curl(e) (-(e)+(AC_ERR_CURL_BASE))
#define auth_client_err_from_curl(e) ((AC_ERR_CURL_BASE)-(e))
auth_client_t auth_client_new(const char *service, const char *server);
auth_client_t auth_client_new(const char *service, const char *servers);
void auth_client_free(auth_client_t ac);
const char *auth_client_strerror(int err);
void auth_client_set_verbose(auth_client_t ac, int verbose);
......
// Tests for auth_client.c.
#include <string>
#include <stdlib.h>
#include "gtest/gtest.h"
extern "C" {
......@@ -26,16 +27,23 @@ class AuthClientTest
: public ::testing::Test
{
public:
AuthClientTest() {
ac = auth_client_new("service", server);
AuthClientTest(std::string serverprefix)
: server_(serverprefix + std::string(server))
{
ac = auth_client_new("service", server_.c_str());
assert(ac != NULL);
auth_client_set_verbose(ac, 1);
}
AuthClientTest()
: AuthClientTest("")
{}
virtual ~AuthClientTest() {
auth_client_free(ac);
}
std::string server_;
auth_client_t ac;
};
......@@ -123,6 +131,28 @@ TEST_F(AuthClientTest, SSLFailsWithBadCAServerSide) {
EXPECT_NE(AC_OK, result) << "authenticate() didn't fail, server=" << server;
}
class AuthClientServerFallbackTest
: public AuthClientTest
{
public:
AuthClientServerFallbackTest()
: AuthClientTest("127.8.8.8:1024,")
{}
};
// Test RPC fallback if first server is bad.
TEST_F(AuthClientServerFallbackTest, AuthOK) {
int result;
result = auth_client_set_certificate(ac, ssl_ca, ssl_cert, ssl_key);
EXPECT_EQ(AC_OK, result) << "set_certificate() error: " << auth_client_strerror(result);
result = auth_client_authenticate(ac, "user", "pass", NULL, "127.0.0.1", NULL);
EXPECT_EQ(AC_OK, result) << "authenticate() error: " << auth_client_strerror(result)
<< ", server=" << server;
}
int main(int argc, char **argv) {
server = getenv("AUTH_SERVER");
if (server == NULL) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment