diff --git a/pam/Makefile.am b/pam/Makefile.am index 3dfff694ba2e20d416a1ac914c96156f1ab57ccc..ff8f0ba05b20272a64b05817c8e01f4506ce93d6 100644 --- a/pam/Makefile.am +++ b/pam/Makefile.am @@ -1,5 +1,6 @@ ACLOCAL_AMFLAGS = -I m4 +AM_CXXFLAGS = -std=c++11 pamdir = $(PAMDIR) diff --git a/pam/auth_client.c b/pam/auth_client.c index 373d318a362298677e75e702b3f07a8c2be59004..98789be6782576b9a68563b57f9f0560d8c82dcd 100644 --- a/pam/auth_client.c +++ b/pam/auth_client.c @@ -17,22 +17,48 @@ static const char *kAuthApiPath = "/api/1/auth"; +struct server_list { + int n; + char **servers; + char *buf; +}; + +typedef struct server_list *server_list_t; + +static void server_list_parse(server_list_t sl, const char *s) { + char *saveptr = NULL, *sp, *tok; + sl->n = 0; + sl->servers = (char **)malloc(sizeof(char *)); + sl->buf = strdup(s); + for (sp = sl->buf; (tok = strtok_r(sp, ",", &saveptr)) != NULL; sp = NULL) { + sl->servers = (char **)realloc(sl->servers, sizeof(char *) * (sl->n + 1)); + sl->servers[sl->n++] = tok; + } +} + +static void server_list_free(server_list_t sl) { + free(sl->buf); + free(sl->servers); +} + struct auth_client { CURL *c; const char *service; - const char *server; + int https; + //const char *server; + struct server_list server_list; }; -static int auth_client_set_proto(auth_client_t ac, const char *proto) { - char url[strlen(ac->server) + 32]; - sprintf(url, "%s://%s%s", proto, ac->server, kAuthApiPath); +static int auth_client_set_url(auth_client_t ac, const char *server) { + char url[strlen(server) + 32]; + sprintf(url, "http%s://%s%s", ac->https ? "s" : "", server, kAuthApiPath); CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_URL, url)); return AC_OK; } static int curl_initialized = 0; -auth_client_t auth_client_new(const char *service, const char *server) { +auth_client_t auth_client_new(const char *service, const char *servers) { auth_client_t ac = (auth_client_t)malloc(sizeof(struct auth_client)); if (!curl_initialized) { @@ -40,12 +66,12 @@ auth_client_t auth_client_new(const char *service, const char *server) { curl_initialized = 1; } + server_list_parse(&ac->server_list, servers); ac->service = service; - ac->server = server; ac->c = curl_easy_init(); + ac->https = 0; curl_easy_setopt(ac->c, CURLOPT_NOSIGNAL, 1); curl_easy_setopt(ac->c, CURLOPT_TIMEOUT, 60); - auth_client_set_proto(ac, "http"); return ac; } @@ -69,10 +95,6 @@ int auth_client_set_certificate(auth_client_t ac, if (!file_exists(ca_file) || !file_exists(crt_file) || !file_exists(key_file)) { return AC_ERR_FILE_NOT_FOUND; } - err = auth_client_set_proto(ac, "https"); - if (err != AC_OK) { - return err; - } CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_CAINFO, ca_file)); CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLKEYTYPE, "PEM")); CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLKEY, key_file)); @@ -81,11 +103,13 @@ int auth_client_set_certificate(auth_client_t ac, CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSL_VERIFYPEER, 1L)); CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSL_VERIFYHOST, 0L)); CURL_CHECK(curl_easy_setopt(ac->c, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1)); + ac->https = 1; return AC_OK; } void auth_client_free(auth_client_t ac) { curl_easy_cleanup(ac->c); + server_list_free(&ac->server_list); free(ac); } @@ -102,6 +126,8 @@ const char *auth_client_strerror(int err) { return "Bad server response"; case AC_ERR_FILE_NOT_FOUND: return "Certificate or CA file not found"; + case AC_ERR_NO_SERVERS: + return "No servers could be reached"; default: return "Unknown error"; } @@ -159,9 +185,8 @@ int auth_client_authenticate(auth_client_t ac, const char *shard) { struct curl_slist *headers = NULL; struct cbuf form; - struct cbuf responsebuf; CURLcode res; - int retval; + int i, retval = AC_ERR_NO_SERVERS; // Build the POST request contents. cbuf_init(&form, 256); @@ -186,28 +211,38 @@ int auth_client_authenticate(auth_client_t ac, curl_slist_append(headers, "Content-Type: application/x-form-www-urlencoded"); curl_easy_setopt(ac->c, CURLOPT_HTTPHEADER, headers); - cbuf_init(&responsebuf, 64); - curl_easy_setopt(ac->c, CURLOPT_WRITEFUNCTION, responsebuf_callback); - curl_easy_setopt(ac->c, CURLOPT_WRITEDATA, (void *)&responsebuf); - - res = curl_easy_perform(ac->c); - if (res == CURLE_OK) { - // Check the auth server response. - if (!strncmp(responsebuf.buf, "OK", 2)) { - retval = AC_OK; - } else if (!strncmp(responsebuf.buf, "OTP_REQUIRED", 12)) { - retval = AC_ERR_OTP_REQUIRED; - } else if (!strncmp(responsebuf.buf, "ERROR", 5)) { - retval = AC_ERR_AUTHENTICATION_FAILURE; + // Iterate over the known servers. We create a new response buffer + // for each request just in case we get a partial transfer error. + for (i = 0; i < ac->server_list.n; i++) { + char *server = ac->server_list.servers[i]; + struct cbuf responsebuf; + + cbuf_init(&responsebuf, 64); + curl_easy_setopt(ac->c, CURLOPT_WRITEFUNCTION, responsebuf_callback); + curl_easy_setopt(ac->c, CURLOPT_WRITEDATA, (void *)&responsebuf); + auth_client_set_url(ac, server); + + res = curl_easy_perform(ac->c); + if (res == CURLE_OK) { + // Check the auth server response. + if (!strncmp(responsebuf.buf, "OK", 2)) { + retval = AC_OK; + } else if (!strncmp(responsebuf.buf, "OTP_REQUIRED", 12)) { + retval = AC_ERR_OTP_REQUIRED; + } else if (!strncmp(responsebuf.buf, "ERROR", 5)) { + retval = AC_ERR_AUTHENTICATION_FAILURE; + } else { + retval = AC_ERR_BAD_RESPONSE; + } + break; } else { - retval = AC_ERR_BAD_RESPONSE; + retval = auth_client_err_from_curl(res); } - } else { - retval = auth_client_err_from_curl(res); + + cbuf_free(&responsebuf); } cbuf_free(&form); - cbuf_free(&responsebuf); curl_slist_free_all(headers); return retval; diff --git a/pam/auth_client.h b/pam/auth_client.h index a7bfa53754337071071685854926cff176f69635..bdbe5669a82293526831a96fbf4c1e9c9ef10f92 100644 --- a/pam/auth_client.h +++ b/pam/auth_client.h @@ -11,11 +11,12 @@ typedef struct auth_client* auth_client_t; #define AC_ERR_OTP_REQUIRED -2 #define AC_ERR_BAD_RESPONSE -3 #define AC_ERR_FILE_NOT_FOUND -4 +#define AC_ERR_NO_SERVERS -5 #define AC_ERR_CURL_BASE -100 #define auth_client_err_to_curl(e) (-(e)+(AC_ERR_CURL_BASE)) #define auth_client_err_from_curl(e) ((AC_ERR_CURL_BASE)-(e)) -auth_client_t auth_client_new(const char *service, const char *server); +auth_client_t auth_client_new(const char *service, const char *servers); void auth_client_free(auth_client_t ac); const char *auth_client_strerror(int err); void auth_client_set_verbose(auth_client_t ac, int verbose); diff --git a/pam/auth_client_test.cc b/pam/auth_client_test.cc index 8f99bcae7b671817e6daeff22646e8c4ed263384..787af99f8bbbb11b7db2033bc0cf635f69050188 100644 --- a/pam/auth_client_test.cc +++ b/pam/auth_client_test.cc @@ -1,5 +1,6 @@ // Tests for auth_client.c. +#include <string> #include <stdlib.h> #include "gtest/gtest.h" extern "C" { @@ -26,16 +27,23 @@ class AuthClientTest : public ::testing::Test { public: - AuthClientTest() { - ac = auth_client_new("service", server); + AuthClientTest(std::string serverprefix) + : server_(serverprefix + std::string(server)) + { + ac = auth_client_new("service", server_.c_str()); assert(ac != NULL); auth_client_set_verbose(ac, 1); } + AuthClientTest() + : AuthClientTest("") + {} + virtual ~AuthClientTest() { auth_client_free(ac); } + std::string server_; auth_client_t ac; }; @@ -123,6 +131,28 @@ TEST_F(AuthClientTest, SSLFailsWithBadCAServerSide) { EXPECT_NE(AC_OK, result) << "authenticate() didn't fail, server=" << server; } +class AuthClientServerFallbackTest + : public AuthClientTest +{ +public: + AuthClientServerFallbackTest() + : AuthClientTest("127.8.8.8:1024,") + {} +}; + +// Test RPC fallback if first server is bad. +TEST_F(AuthClientServerFallbackTest, AuthOK) { + int result; + + result = auth_client_set_certificate(ac, ssl_ca, ssl_cert, ssl_key); + EXPECT_EQ(AC_OK, result) << "set_certificate() error: " << auth_client_strerror(result); + + result = auth_client_authenticate(ac, "user", "pass", NULL, "127.0.0.1", NULL); + EXPECT_EQ(AC_OK, result) << "authenticate() error: " << auth_client_strerror(result) + << ", server=" << server; +} + + int main(int argc, char **argv) { server = getenv("AUTH_SERVER"); if (server == NULL) {