#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include <ldap.h>
#ifdef HAVE_SASL_H
#include <sasl.h>
#endif
#ifdef HAVE_SASL_SASL_H
#include <sasl/sasl.h>
#endif
#include <signal.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <unistd.h>

#include <iostream>
#include <string>
#include <vector>

#include <arc/ldapquery.h>
#include <arc/notify.h>

#ifdef HAVE_LIBINTL_H
#include <libintl.h>
#define _(A) dgettext("arclib", (A))
#else
#define _(A) (A)
#endif

/*
 * Ignores SIGPIPE for the whole process. The constructor runs once, during
 * static initialization of this file, via the file-scope instance below.
 * Presumably this keeps a write to an LDAP connection closed by the peer
 * from killing the process (SIGPIPE's default action is to terminate).
 *
 * Deliberately defined outside the SASL #if below so that every build,
 * with or without SASL headers, ignores SIGPIPE.
 */
class sigpipe_ignore {
	public:
		sigpipe_ignore();
};

sigpipe_ignore::sigpipe_ignore() {
	signal(SIGPIPE, SIG_IGN);
}

static sigpipe_ignore sigpipe_ignore_instance;

#if defined(HAVE_SASL_H) || defined(HAVE_SASL_SASL_H)
/*
 * Default answers for SASL interaction prompts, handed to
 * my_sasl_interact() as its `defaults_` argument.
 *
 * Empty mech/realm/authcid/authzid values are filled from the
 * corresponding LDAP_OPT_X_SASL_* options of the LDAP handle by the
 * constructor. The password is cleared by my_sasl_interact() after its
 * first use.
 */
class sasl_defaults {
	public:
		sasl_defaults (ldap *ld,
		               const std::string & mech,
		               const std::string & realm,
		               const std::string & authcid,
		               const std::string & authzid,
		               const std::string & passwd);
		~sasl_defaults() {};

	private:
		std::string p_mech;     // SASL mechanism
		std::string p_realm;    // answer for SASL_CB_GETREALM
		std::string p_authcid;  // answer for SASL_CB_AUTHNAME
		std::string p_authzid;  // answer for SASL_CB_USER
		std::string p_passwd;   // answer for SASL_CB_PASS (one-shot)

	friend int my_sasl_interact(ldap *ld,
	                            unsigned int flags,
	                            void * defaults_,
	                            void * interact_);
};

sasl_defaults::sasl_defaults (ldap *ld,
                              const std::string & mech,
                              const std::string & realm,
                              const std::string & authcid,
                              const std::string & authzid,
                              const std::string & passwd) : p_mech    (mech),
							    p_realm   (realm),
							    p_authcid (authcid),
							    p_authzid (authzid),
							    p_passwd  (passwd) {

	if (p_mech.empty()) {
		char * temp;
		ldap_get_option (ld, LDAP_OPT_X_SASL_MECH, &temp);
		if (temp) {
			p_mech = temp;
			free (temp);
		}
	}
	if (p_realm.empty()) {
		char * temp;
		ldap_get_option (ld, LDAP_OPT_X_SASL_REALM, &temp);
		if (temp) {
			p_realm = temp;
			free (temp);
		}
	}
	if (p_authcid.empty()) {
		char * temp;
		ldap_get_option (ld, LDAP_OPT_X_SASL_AUTHCID, &temp);
		if (temp) {
			p_authcid = temp;
			free (temp);
		}
	}
	if (p_authzid.empty()) {
		char * temp;
		ldap_get_option (ld, LDAP_OPT_X_SASL_AUTHZID, &temp);
		if (temp) {
			p_authzid = temp;
			free (temp);
		}
	}
}


/*
 * SASL interaction callback for ldap_sasl_interactive_bind_s().
 *
 * Walks the sasl_interact_t array up to SASL_CB_LIST_END and fills in
 * result/len for each request:
 *  - realm, authcid, authzid and password answers default to the values in
 *    the sasl_defaults passed as defaults_ (may be NULL). Otherwise they
 *    default to whatever the library offered in defresult.
 *  - Unless flags is LDAP_SASL_INTERACTIVE, a request with a default (and
 *    any SASL_CB_USER request) is answered without prompting.
 *  - Otherwise, with LDAP_SASL_QUIET the callback gives up and returns 1.
 *    Else the user is prompted (getpass() for hidden input, std::cin
 *    otherwise), and an empty answer selects the default.
 * The password default is cleared after its first use.
 *
 * Returns 0 on success, 1 if interaction would be needed under
 * LDAP_SASL_QUIET.
 *
 * NOTE(review): result strings are strdup()ed and never freed. SASL reads
 * them after this callback returns and nothing here owns them, so each
 * bind leaks a few bytes. A per-bind string store in sasl_defaults would
 * remove the leak.
 */
int my_sasl_interact(ldap *ld,
                     unsigned int flags,
                     void * defaults_,
                     void * interact_) {

	sasl_interact_t * interact = (sasl_interact_t *) interact_;
	sasl_defaults *   defaults = (sasl_defaults *)   defaults_;

	if (flags == LDAP_SASL_INTERACTIVE) {
		notify(VERBOSE) << _("SASL Interaction") << std::endl;
	}

	for (; interact->id != SASL_CB_LIST_END; interact++) {

		bool noecho = false;
		bool challenge = false;
		bool use_default = false;

		// Default answer: start from the library's offer, override with
		// ours. Kept in a local instead of overwriting interact->defresult
		// with a strdup()ed copy, which leaked.
		const char * dflt = interact->defresult;

		switch (interact->id) {
		case SASL_CB_GETREALM:
			if (defaults && !defaults->p_realm.empty())
				dflt = defaults->p_realm.c_str();
			break;
		case SASL_CB_AUTHNAME:
			if (defaults && !defaults->p_authcid.empty())
				dflt = defaults->p_authcid.c_str();
			break;
		case SASL_CB_USER:
			if (defaults && !defaults->p_authzid.empty())
				dflt = defaults->p_authzid.c_str();
			break;
		case SASL_CB_PASS:
			if (defaults && !defaults->p_passwd.empty())
				dflt = defaults->p_passwd.c_str();
			noecho = true;
			break;
		case SASL_CB_NOECHOPROMPT:
			noecho = true;
			challenge = true;
			break;
		case SASL_CB_ECHOPROMPT:
			challenge = true;
			break;
		}

		if (flags != LDAP_SASL_INTERACTIVE &&
		   (dflt || interact->id == SASL_CB_USER)) {
			use_default = true;
		}
		else {
			if (flags == LDAP_SASL_QUIET) return 1;

			if (challenge && interact->challenge)
				notify(VERBOSE) << _("Challenge") << ": "
				                << interact->challenge << std::endl;

			// Never print the default of a hidden prompt: it is a secret
			// (e.g. the password).
			if (dflt && !noecho)
				notify(VERBOSE) << _("Default") << ": " << dflt
				                << std::endl;

			std::string prompt = interact->prompt ?
			         std::string (interact->prompt) + ": " : "Interact: ";
			std::string input;

			if (noecho) {
				// getpass() returns NULL on failure; building a std::string
				// from NULL is undefined behaviour.
				const char * secret = getpass (prompt.c_str());
				if (secret)
					input = secret;
			}
			else {
				std::cout << prompt << std::flush;
				// Read a whole line: operator>> skips blank lines, so an
				// empty answer (to accept the default) was impossible.
				std::getline (std::cin, input);
			}

			if (input.empty())
				use_default = true;
			else {
				interact->result = strdup (input.c_str());
				interact->len = input.length();
			}
		}

		if (use_default) {
			const char * value = dflt ? dflt : "";
			interact->result = strdup (value);
			interact->len = strlen (value);
		}

		if (defaults && interact->id == SASL_CB_PASS) {
			// clear default password after first use (the answer above is
			// already an independent copy)
			defaults->p_passwd = "";
		}
	}
	return 0;
}
#endif


/*
 * Record where and how to query. No network activity happens here: the
 * handle stays NULL and no search is pending until Connect()/Query().
 *
 * ldaphost/ldapport - LDAP server to contact
 * anonymous         - use an anonymous simple bind instead of SASL
 * usersn            - SASL authorization id used for non-anonymous binds
 * timeout           - per-operation timeout in seconds
 */
LdapQuery::LdapQuery(const std::string& ldaphost,
                     int ldapport,
                     bool anonymous,
                     const std::string& usersn,
                     int timeout) {

	// Nothing open, nothing pending yet.
	connection = NULL;
	messageid = 0;

	// Server address.
	host = ldaphost;
	port = ldapport;

	// Bind and timing options (parameters shadow the members).
	this->anonymous = anonymous;
	this->usersn = usersn;
	this->timeout = timeout;
}


/* Release the LDAP handle if one is still open. */
LdapQuery::~LdapQuery() {

	if (!connection)
		return;

	ldap_unbind(connection);
	connection = NULL;
}


/*
 * Create an LDAP handle for host:port (LDAPv3) and configure and bind it
 * through SetConnectionOptions().
 *
 * Throws LdapQueryError if a handle is already open, if ldap_init() fails,
 * or if SetConnectionOptions() fails. Whatever SetConnectionOptions()
 * throws, the new handle is released and connection reset to NULL before
 * the exception is rethrown unchanged.
 */
void LdapQuery::Connect() {

	const int version = LDAP_VERSION3;

	notify(DEBUG) << _("LdapQuery: Initializing connection to") << ": "
	              << host << ":" << port << std::endl;

	if (connection)
		throw LdapQueryError(
			_("Ldap connection already open to") + (" " + host));

	connection = ldap_init(host.c_str(), port);

	if (!connection)
		throw LdapQueryError(
			_("Could not open ldap connection to") + (" " + host));

	try {
		SetConnectionOptions(version);
	}
	catch (...) {
		// Catch everything (not only a by-value copy of LdapQueryError) so
		// the handle never leaks; `throw;` rethrows the original object.
		ldap_unbind(connection);
		connection = NULL;
		throw;
	}
}


/*
 * Configure the freshly created handle and bind it.
 *
 * Applies `timeout` as both the network timeout and the LDAP time limit,
 * selects protocol `version`, then binds:
 *  - anonymously (simple bind, no DN/password) when `anonymous` is set or
 *    when this file was built without SASL headers;
 *  - otherwise with a SASL interactive bind on SASLMECH, with `usersn` as
 *    authorization id. It may prompt only when notify level >= DEBUG.
 *
 * Throws LdapQueryError (message suffixed with the host) on any failure.
 */
void LdapQuery::SetConnectionOptions(int version) {

	timeval net_timeout;
	net_timeout.tv_sec = timeout;
	net_timeout.tv_usec = 0;

	if (ldap_set_option (connection, LDAP_OPT_NETWORK_TIMEOUT, &net_timeout)
	    != LDAP_OPT_SUCCESS)
		throw LdapQueryError(
			_("Could not set ldap network timeout") + (" (" + host + ")"));

	if (ldap_set_option (connection, LDAP_OPT_TIMELIMIT, &timeout)
	    != LDAP_OPT_SUCCESS)
		throw LdapQueryError(
			_("Could not set ldap timelimit") + (" (" + host + ")"));

	if (ldap_set_option (connection, LDAP_OPT_PROTOCOL_VERSION, &version)
	    != LDAP_OPT_SUCCESS)
		throw LdapQueryError(
			_("Could not set ldap protocol version") + (" (" + host + ")"));

	int bind_status;
#if defined(HAVE_SASL_H) || defined(HAVE_SASL_SASL_H)
	if (!anonymous) {
		// Quiet (never prompt) unless debugging output is enabled.
		const int sasl_flags = (GetNotifyLevel() >= DEBUG)
		                       ? LDAP_SASL_AUTOMATIC
		                       : LDAP_SASL_QUIET;
		sasl_defaults defaults (connection, SASLMECH, "", "", usersn, "");
		bind_status = ldap_sasl_interactive_bind_s (connection,
		                                            NULL,
		                                            SASLMECH,
		                                            NULL,
		                                            NULL,
		                                            sasl_flags,
		                                            my_sasl_interact,
		                                            &defaults);
	}
	else
		bind_status = ldap_simple_bind_s (connection, NULL, NULL);
#else
	// No SASL available: always an anonymous simple bind.
	bind_status = ldap_simple_bind_s (connection, NULL, NULL);
#endif

	if (bind_status == LDAP_SUCCESS)
		return;

	std::string error_msg(ldap_err2string (bind_status));
	error_msg += " (" + host + ")";
	throw LdapQueryError(error_msg);
}


/*
 * Connect to the server and start an asynchronous search. Fetch the
 * results afterwards with Result().
 *
 * base       - search base DN
 * filter     - LDAP search filter (may be empty)
 * attributes - attributes to request; when empty, NULL is passed to
 *              ldap_search_ext() (all attributes, per the LDAP C API)
 * scope      - search scope
 *
 * Throws LdapQueryError if Connect() fails or if the search cannot be
 * started; in the latter case the connection is closed again.
 */
void LdapQuery::Query(const std::string& base,
                      const std::string& filter,
                      const std::vector <std::string>& attributes,
                      Scope scope) throw (LdapQueryError) {

	Connect();

	notify(DEBUG) << _("LdapQuery: Querying") << " " << host << std::endl;

	notify(VERBOSE) << "  " << _("base dn") << ": " << base << std::endl;
	if (!filter.empty())
		notify(VERBOSE) << "  " << _("filter") << ": " << filter << std::endl;
	if (!attributes.empty()) {
		notify(VERBOSE) << "  " << _("attributes") << ":" << std::endl;
		for (std::vector<std::string>::const_iterator vs = attributes.begin();
		                                              vs != attributes.end();
		                                              ++vs)
			notify(VERBOSE) << "    " << *vs << std::endl;
	}

	timeval tout;
	tout.tv_sec = timeout;
	tout.tv_usec = 0;

	// NULL-terminated attribute array for ldap_search_ext(), owned by a
	// vector instead of a raw new[]/delete[] pair. The char pointers borrow
	// from `attributes`, which outlives the call; the const_cast is only
	// for the C API's non-const char ** parameter.
	std::vector<char *> attrs;
	char ** attrlist = NULL;
	if (!attributes.empty()) {
		attrs.reserve(attributes.size() + 1);
		for (std::vector<std::string>::const_iterator vs = attributes.begin();
		     vs != attributes.end(); ++vs)
			attrs.push_back(const_cast<char *>(vs->c_str()));
		attrs.push_back(NULL);
		attrlist = &attrs[0];
	}

	int ldresult = ldap_search_ext (connection,
	                                base.c_str(),
	                                scope,
	                                const_cast<char *>(filter.c_str()),
	                                attrlist,
	                                0,
	                                NULL,
	                                NULL,
	                                &tout,
	                                0,
	                                &messageid);

	if (ldresult != LDAP_SUCCESS) {
		std::string error_msg(ldap_err2string (ldresult));
		error_msg += " (" + host + ")";
		ldap_unbind (connection);
		connection = NULL;
		throw LdapQueryError(error_msg);
	}
}


/*
 * Deliver the results of the search started by Query() to
 * callback(attr, value, ref) via HandleResult(), then close the connection.
 *
 * The connection is closed and messageid reset whether or not
 * HandleResult() succeeds; an LdapQueryError from it is rethrown.
 */
void LdapQuery::Result(ldap_callback callback, void* ref)
throw(LdapQueryError) {

	try {
		HandleResult(callback, ref);
	} catch (const LdapQueryError&) {
		// Clean up and re-throw. connection may be NULL here (e.g. Result()
		// called without a successful Query()), and ldap_unbind() must not
		// be given a NULL handle.
		if (connection) {
			ldap_unbind (connection);
			connection = NULL;
		}
		messageid = 0;
		throw;
	}
	// Since C++ doesnt have finally(), here we are again
	if (connection) {
		ldap_unbind (connection);
		connection = NULL;
	}
	messageid = 0;
}


/*
 * Read the replies to the pending search (messageid) one message at a
 * time. Each entry goes to HandleSearchEntry(), until the final
 * search-result message arrives. Each ldap_result() call is given
 * `timeout` seconds.
 *
 * Throws LdapQueryError if no search is pending, if ldap_result() times
 * out, or if it reports an error. Anything thrown while processing a
 * message (e.g. by the user callback) propagates after the message is
 * freed.
 */
void LdapQuery::HandleResult(ldap_callback callback, void* ref) {

	notify(DEBUG) << _("LdapQuery: Getting results from") << " " << host
	             << std::endl;

	if (!messageid || !connection)
		throw LdapQueryError(
			_("Error: no ldap query started to") + (" " + host));

	timeval tout;
	tout.tv_sec = timeout;
	tout.tv_usec = 0;

	bool done = false;
	int ldresult = 0;
	LDAPMessage * res = NULL;

	while (!done && (ldresult = ldap_result (connection,
	                                         messageid,
	                                         LDAP_MSG_ONE,
	                                         &tout,
	                                         &res)) > 0) {
		try {
			for (LDAPMessage * msg = ldap_first_message (connection, res); msg;
			     msg = ldap_next_message (connection, msg)) {

				switch (ldap_msgtype(msg)) {
				case LDAP_RES_SEARCH_ENTRY:
					HandleSearchEntry(msg, callback, ref);
					break;
				case LDAP_RES_SEARCH_RESULT:
					done = true;
					break;
				} // switch
			} // for
		}
		catch (...) {
			// Do not leak the message if processing it throws.
			ldap_msgfree (res);
			throw;
		}
		ldap_msgfree (res);
		res = NULL;
	}

	if (ldresult == 0)
		throw LdapQueryError(_("Ldap query timed out") + (": " + host));

	if (ldresult == -1) {
		// ldap_result() signals failure with -1 and leaves the real result
		// code in the handle. Feeding -1 itself to ldap_err2string() reports
		// the same message whatever the cause. Fall back to -1 only if the
		// code cannot be read.
		int errcode = ldresult;
		ldap_get_option (connection, LDAP_OPT_ERROR_NUMBER, &errcode);
		std::string error_msg(ldap_err2string (errcode));
		error_msg += " (" + host + ")";
		throw LdapQueryError(error_msg);
	}
}


/*
 * Report one search entry to the callback: first ("dn", <entry DN>, ref),
 * then (attribute, value, ref) once per value of every attribute.
 * A missing DN or a NULL value is reported as an empty string.
 */
void LdapQuery::HandleSearchEntry(LDAPMessage* msg,
                                  ldap_callback callback,
                                  void* ref) {
	// ldap_get_dn() returns NULL on error; guard it like the values below
	// instead of handing the callback a NULL pointer.
	char *dn = ldap_get_dn(connection, msg);
	callback("dn", (dn ? dn : ""), ref);
	if (dn) ldap_memfree(dn);

	BerElement *ber = NULL;
	for (char *attr = ldap_first_attribute (connection, msg, &ber);
	     attr; attr = ldap_next_attribute (connection, msg, ber)) {
		BerValue **bval = ldap_get_values_len (connection, msg, attr);
		if (bval) {
			for (int i = 0; bval[i]; i++) {
				callback (attr, (bval[i]->bv_val ? bval[i]->bv_val : ""), ref);
			}
			// Documented counterpart of ldap_get_values_len().
			ldap_value_free_len(bval);
		}
		ldap_memfree(attr);
	}
	if (ber)
		ber_free(ber, 0);
}


std::string LdapQuery::Host() {
	return host;
}


/*
 * Store the parameters shared by all worker threads. Also prepare the URL
 * cursor and the mutex that DoLdapQuery() uses to claim URLs and to
 * serialize result delivery.
 */
ParallelLdapQueries::ParallelLdapQueries(std::list<URL> clusters,
                                         std::string filter,
                                         std::vector<std::string> attrs,
                                         ldap_callback callback,
                                         void* object,
                                         LdapQuery::Scope scope,
                                         const std::string& usersn,
                                         bool anonymous,
                                         int timeout) {

	// Search and connection parameters for every LdapQuery.
	this->filter = filter;
	this->attrs = attrs;
	this->scope = scope;
	this->usersn = usersn;
	this->anonymous = anonymous;
	this->timeout = timeout;

	// Where results are delivered.
	this->callback = callback;
	this->object = object;

	// URL list plus cursor into it; the cursor must be taken from the
	// member copy, i.e. after it has been assigned.
	this->clusters = clusters;
	urlit = this->clusters.begin();

	pthread_mutex_init(&lock, NULL);
}


/* Release the mutex initialised by the constructor. */
ParallelLdapQueries::~ParallelLdapQueries() {
	pthread_mutex_t *mutex = &this->lock;
	pthread_mutex_destroy(mutex);
}


void ParallelLdapQueries::Query() {
	const int numqueries = clusters.size();
	pthread_t threads[numqueries];
	int res;

	for (unsigned int i = 0; i<clusters.size(); i++) {
		res = pthread_create(&threads[i],
		                     NULL,
		                     ParallelLdapQueries::DoLdapQuery,
		                     (void*)this);
		if (res!=0)
			throw LdapQueryError(
				_("Thread creation in ParallelLdapQueries failed"));
	}

	void* result;
	for (unsigned int i = 0; i<clusters.size(); i++) {
		res = pthread_join(threads[i], &result);
		if (res!=0)
			throw LdapQueryError(
				_("Thread joining in ParallelLdapQueries failed"));
	}
}


void* ParallelLdapQueries::DoLdapQuery(void* arg) {
	ParallelLdapQueries* plq = (ParallelLdapQueries*)arg;

	pthread_mutex_lock(&plq->lock);
	URL qurl = *(plq->urlit);
	plq->urlit++;
	pthread_mutex_unlock(&plq->lock);

	LdapQuery ldapq(qurl.Host(),
	                qurl.Port(),
	                plq->anonymous,
	                plq->usersn,
	                plq->timeout);

	try {
		ldapq.Query(qurl.BaseDN(), plq->filter, plq->attrs, plq->scope);
	} catch (LdapQueryError e) {
		notify(DEBUG) << _("Warning") << ": " << e.what() << std::endl;
		pthread_exit(NULL);
	}	

	pthread_mutex_lock(&plq->lock);

	try {
		ldapq.Result(plq->callback, plq->object);
	} catch (LdapQueryError e) {
		notify(DEBUG) << _("Warning") << ": " << e.what() << std::endl;
	}

	pthread_mutex_unlock(&plq->lock);	
	pthread_exit(NULL);
}
