#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#ifndef _BSD_SOURCE
#define _BSD_SOURCE
#endif

#include <iostream>
#include <string>
#include <vector>
#include <string.h>
#include <unistd.h>
#include <sys/time.h>

#include <ldap.h>
#ifdef HAVE_SASL_H
#include <sasl.h>
#endif
#ifdef HAVE_SASL_SASL_H
#include <sasl/sasl.h>
#endif

#include "LdapQuery.h"


#if defined(HAVE_SASL_H) || defined(HAVE_SASL_SASL_H)
#define SASLMECH "GSI-GSSAPI"


// Default values for a SASL interactive bind.  An instance is passed as the
// `defaults` argument of ldap_sasl_interactive_bind_s() and read back by
// my_sasl_interact().  An empty string means "no default supplied".
class sasl_defaults {

  // Per-item defaults (mechanism, realm, authentication id, authorization
  // id, password).
  std::string p_mech;
  std::string p_realm;
  std::string p_authcid;
  std::string p_authzid;
  std::string p_passwd;

  // The interaction callback reads these fields directly and clears
  // p_passwd once it has been consumed.
  friend int my_sasl_interact (LDAP * ld,
                               unsigned int flags,
                               void * defaults_,
                               void * interact_);

public:
  sasl_defaults (LDAP * ld,
                 const std::string & mech,
                 const std::string & realm,
                 const std::string & authcid,
                 const std::string & authzid,
                 const std::string & passwd);
  ~sasl_defaults() {}
};


sasl_defaults::sasl_defaults (LDAP * ld,
			      const std::string & mech,
			      const std::string & realm,
			      const std::string & authcid,
			      const std::string & authzid,
			      const std::string & passwd) : p_mech    (mech),
						       p_realm   (realm),
						       p_authcid (authcid),
						       p_authzid (authzid),
						       p_passwd  (passwd) {

  if (p_mech.empty()) {
    char * temp;
    ldap_get_option (ld, LDAP_OPT_X_SASL_MECH, &temp);
    if (temp) { p_mech = temp; free (temp); }
  }
  if (p_realm.empty()) {
    char * temp;
    ldap_get_option (ld, LDAP_OPT_X_SASL_REALM, &temp);
    if (temp) { p_realm = temp; free (temp); }
  }
  if (p_authcid.empty()) {
    char * temp;
    ldap_get_option (ld, LDAP_OPT_X_SASL_AUTHCID, &temp);
    if (temp) { p_authcid = temp; free (temp); }
  }
  if (p_authzid.empty()) {
    char * temp;
    ldap_get_option (ld, LDAP_OPT_X_SASL_AUTHZID, &temp);
    if (temp) { p_authzid = temp; free (temp); }
  }
}


// SASL interaction callback for ldap_sasl_interactive_bind_s().
//
// Walks the sasl_interact_t array up to SASL_CB_LIST_END and fills in
// result/len for every requested item.
//
// ld        - LDAP handle (unused).
// flags     - LDAP_SASL_AUTOMATIC, LDAP_SASL_INTERACTIVE or LDAP_SASL_QUIET.
//             Unless interactive, items that have a default (and
//             SASL_CB_USER, which may be left empty) are answered without
//             prompting; in quiet mode any item that would need a prompt
//             makes the callback return 1.
// defaults_ - sasl_defaults * (may be NULL) providing realm, authcid,
//             authzid and password defaults; a non-empty value overrides
//             the library-supplied interact->defresult.  The password
//             default is cleared after it has been consumed.
// interact_ - sasl_interact_t * array to fill in.
//
// Returns 0 on success, 1 when a prompt would be required in quiet mode.
//
// NOTE(review): every result is a strdup()'ed copy that is never freed
// here; presumably it must outlive this call for the SASL library.
int my_sasl_interact (LDAP * ld,
                      unsigned int flags,
                      void * defaults_,
                      void * interact_) {

  sasl_interact_t * interact = static_cast <sasl_interact_t *> (interact_);
  sasl_defaults *   defaults = static_cast <sasl_defaults *>   (defaults_);

  if (flags == LDAP_SASL_INTERACTIVE) {
    std::cerr << "SASL Interaction" << std::endl;
  }

  while (interact->id != SASL_CB_LIST_END) {

    bool noecho = false;
    bool challenge = false;
    bool use_default = false;

    // Default answer for this item: the library's own default unless the
    // caller's defaults provide a non-empty value.  Kept in a local rather
    // than overwriting the library-owned interact->defresult with a leaked
    // strdup() copy.
    const char * dflt = interact->defresult;

    switch (interact->id) {
    case SASL_CB_GETREALM:
      if (defaults && !defaults->p_realm.empty())
        dflt = defaults->p_realm.c_str();
      break;
    case SASL_CB_AUTHNAME:
      if (defaults && !defaults->p_authcid.empty())
        dflt = defaults->p_authcid.c_str();
      break;
    case SASL_CB_USER:
      if (defaults && !defaults->p_authzid.empty())
        dflt = defaults->p_authzid.c_str();
      break;
    case SASL_CB_PASS:
      if (defaults && !defaults->p_passwd.empty())
        dflt = defaults->p_passwd.c_str();
      noecho = true;
      break;
    case SASL_CB_NOECHOPROMPT:
      noecho = true;
      challenge = true;
      break;
    case SASL_CB_ECHOPROMPT:
      challenge = true;
      break;
    }

    if (flags != LDAP_SASL_INTERACTIVE &&
        (dflt || interact->id == SASL_CB_USER)) {
      use_default = true;
    }

    else {

      if (flags == LDAP_SASL_QUIET) return 1;

      if (challenge && interact->challenge)
        std::cerr << "Challenge: " << interact->challenge << std::endl;

      if (dflt)
        std::cerr << "Default: " << dflt << std::endl;

      std::string prompt = interact->prompt ?
        std::string (interact->prompt) + ": " : "Interact: ";
      std::string input;

      if (noecho) {
        // getpass() returns NULL on failure (e.g. no terminal available);
        // assigning NULL to a std::string is undefined behaviour.
        const char * pass = getpass (prompt.c_str());
        if (pass) input = pass;
      }
      else {
        std::cerr << prompt;
        // Read the whole line: operator>> skips blank lines (so an empty
        // answer could never select the default) and stops at whitespace.
        std::getline (std::cin, input);
      }
      if (input.empty())
        use_default = true;
      else {
        interact->result = strdup (input.c_str());
        interact->len = input.length();
      }
    }

    if (use_default) {
      const char * answer = dflt ? dflt : "";
      interact->result = strdup (answer);
      interact->len = strlen (answer);
    }

    if (defaults && interact->id == SASL_CB_PASS) {
      // clear default password after first use (dflt is not used after this)
      defaults->p_passwd = "";
    }

    interact++;
  }

  return 0;
}
#endif


// A new query object holds no LDAP handle and no outstanding search;
// Connect() and Query() fill these in later.
LdapQuery::LdapQuery ()
  : connection (NULL),
    messageid (0) {}


// Open and bind an LDAP connection to ldaphost:ldapport.
//
// usersn    - authorization id handed to the SASL bind (unused for
//             anonymous binds and in builds without SASL support).
// anonymous - true: simple bind without credentials; false: SASL bind with
//             SASLMECH ("GSI-GSSAPI"), or a simple bind when built without
//             SASL support.
// timeout   - seconds; applied both as LDAP_OPT_NETWORK_TIMEOUT and as
//             LDAP_OPT_TIMELIMIT.
// debug     - >0 announces the connection, >1 selects LDAP_SASL_AUTOMATIC
//             instead of LDAP_SASL_QUIET, >2 enables LBER/LDAP debug output.
//
// Returns 0 on success.  On failure prints a diagnostic, unbinds whatever
// handle is held in `connection`, and returns 1.
// NOTE(review): when a connection is already open, the failure path also
// unbinds that existing connection.
int LdapQuery::Connect (const std::string & ldaphost,
                        int ldapport,
                        const std::string & usersn,
                        bool anonymous,
                        int timeout,
                        int debug) {

  host = ldaphost;
  port = ldapport;

  const int debuglevel = 255;
  const int version = LDAP_VERSION3;

  if (debug) std::cout << "Initializing LDAP connection to " << host << std::endl;

  // A NULL handle targets the libraries' global settings.
  if (debug > 2) {
    if (ber_set_option (NULL, LBER_OPT_DEBUG_LEVEL, &debuglevel) != LBER_OPT_SUCCESS)
      std::cerr << "Warning: Could not set LBER_OPT_DEBUG_LEVEL " << debuglevel
                << " (" << host << ")" << std::endl;
    if (ldap_set_option (NULL, LDAP_OPT_DEBUG_LEVEL, &debuglevel) != LDAP_OPT_SUCCESS)
      std::cerr << "Warning: Could not set LDAP_OPT_DEBUG_LEVEL " << debuglevel
                << " (" << host << ")" << std::endl;
  }

  // Single-pass loop: each step either succeeds or prints a diagnostic and
  // breaks out, leaving `bound` false so the shared cleanup below runs.
  bool bound = false;
  do {
    if (connection) {
      std::cerr << "Error: LDAP connection to " << host << " already open" << std::endl;
      break;
    }

    connection = ldap_init (host.c_str(), port);
    if (!connection) {
      std::cerr << "Warning: Could not open LDAP connection to " << host << std::endl;
      break;
    }

    timeval tout;
    tout.tv_sec = timeout;
    tout.tv_usec = 0;
    if (ldap_set_option (connection, LDAP_OPT_NETWORK_TIMEOUT, &tout) != LDAP_OPT_SUCCESS) {
      std::cerr << "Error: Could not set LDAP network timeout" << " (" << host << ")"
                << std::endl;
      break;
    }

    if (ldap_set_option (connection, LDAP_OPT_TIMELIMIT, &timeout) != LDAP_OPT_SUCCESS) {
      std::cerr << "Error: Could not set LDAP timelimit" << " (" << host << ")"
                << std::endl;
      break;
    }

    if (ldap_set_option (connection, LDAP_OPT_PROTOCOL_VERSION, &version) != LDAP_OPT_SUCCESS) {
      std::cerr << "Error: Could not set LDAP protocol version" << " (" << host << ")"
                << std::endl;
      break;
    }

    int ldresult;
    if (anonymous) {
      ldresult = ldap_simple_bind_s (connection, NULL, NULL);
    }
    else {
      int ldapflag = (debug > 1) ? LDAP_SASL_AUTOMATIC : LDAP_SASL_QUIET;
#if defined(HAVE_SASL_H) || defined(HAVE_SASL_SASL_H)
      sasl_defaults defaults = sasl_defaults (connection, SASLMECH,
                                              "", "", usersn, "");
      ldresult = ldap_sasl_interactive_bind_s (connection, NULL, SASLMECH, NULL,
                                               NULL, ldapflag, my_sasl_interact,
                                               &defaults);
#else
      ldresult = ldap_simple_bind_s (connection, NULL, NULL);
#endif
    }

    if (ldresult != LDAP_SUCCESS) {
      std::cerr << "Warning: " << ldap_err2string (ldresult) << " (" << host << ")"
                << std::endl;
      break;
    }

    bound = true;
  } while (false);

  if (bound) return 0;

  if (connection) {
    ldap_unbind (connection);
    connection = NULL;
  }
  return 1;
}


// Start an asynchronous search on the open connection; the message id is
// stored in `messageid` for a later Result() call.
//
// base       - search base DN.
// filter     - LDAP filter; an empty string passes NULL to the library.
// attributes - attributes to request; an empty vector passes NULL.
// scope      - search scope, forwarded to ldap_search_ext().
// timeout    - seconds, forwarded as the search timeout.
// debug      - >0 announces the query, >1 also prints its parameters.
//
// Returns 0 when the request was sent, 1 otherwise.  If sending fails the
// connection is unbound and reset to NULL.
int LdapQuery::Query (const std::string & base,
                      const std::string & filter,
                      const std::vector <std::string> & attributes,
                      Scope scope,
                      int timeout,
                      int debug) {

  if (debug) std::cout << "Initializing LDAP query to " << host << std::endl;
  if (debug > 1) {
    std::cout << "  base dn: " << base << std::endl;
    if (!filter.empty())
      std::cout << "  filter: " << filter << std::endl;
    if (!attributes.empty()) {
      std::cout << "  attributes:" << std::endl;
      for (std::vector <std::string>::const_iterator vsi = attributes.begin();
           vsi != attributes.end(); ++vsi)
        std::cout << "    " << *vsi << std::endl;
    }
  }

  if (!connection) {
    std::cerr << "Warning: no LDAP connection to " << host << std::endl;
    return 1;
  }

  timeval tout;
  tout.tv_sec = timeout;
  tout.tv_usec = 0;

  // Some LDAP C API headers declare these parameters as non-const char
  // pointers; the strings are only read.
  char * filt = filter.empty() ? NULL : const_cast <char *> (filter.c_str());

  // NULL-terminated attribute list.  Owning it in a vector releases it
  // correctly on every path (memory from new[] must be freed with delete[],
  // never plain delete).
  std::vector <char *> attrlist;
  if (!attributes.empty()) {
    attrlist.reserve (attributes.size() + 1);
    for (std::vector <std::string>::const_iterator vsi = attributes.begin();
         vsi != attributes.end(); ++vsi)
      attrlist.push_back (const_cast <char *> (vsi->c_str()));
    attrlist.push_back (NULL);
  }
  char ** attrs = attrlist.empty() ? NULL : &attrlist[0];

  int ldresult = ldap_search_ext (connection, base.c_str(), scope, filt,
                                  attrs, 0, NULL, NULL, &tout, 0, &messageid);

  if (ldresult != LDAP_SUCCESS) {
    std::cerr << "Warning: " << ldap_err2string (ldresult) << " (" << host << ")"
              << std::endl;
    ldap_unbind (connection);
    connection = NULL;
    return 1;
  }
  return 0;
}


int LdapQuery::Result (void callback (const std::string & attr,
				      const std::string & value,
				      void * ref),
		       void * ref,
		       int timeout,
		       int debug) {

  if (debug) std::cout << "Getting LDAP query results from " << host << std::endl;

  if (!connection) {
    std::cerr << "Warning: no LDAP connection to " << host << std::endl;
    return 1;
  }

  if (!messageid) {
    std::cerr << "Error: no LDAP query started to " << host << std::endl;
    return 1;
  }

  timeval tout;
  tout.tv_sec = timeout;
  tout.tv_usec = 0;

  bool done = false;
  int ldresult;
  LDAPMessage * res = NULL;

  while (!done && (ldresult = ldap_result (connection, messageid,
					   LDAP_MSG_ONE, &tout, &res)) > 0) {
    for (LDAPMessage * msg = ldap_first_message (connection, res); msg;
	 msg = ldap_next_message (connection, msg)) {
      BerElement * ber = NULL;
      switch (ldap_msgtype (msg)) {
      char * dn;
      case LDAP_RES_SEARCH_ENTRY:
	dn = ldap_get_dn (connection, msg);
	callback ("dn", dn, ref);
	if (dn) ldap_memfree (dn);
	for (char * attr = ldap_first_attribute (connection, msg, &ber); attr;
	     attr = ldap_next_attribute (connection, msg, ber)) {
	  BerValue ** bval;
	  if (bval = ldap_get_values_len (connection, msg, attr)) {
	    for (int i = 0; bval[i]; i++)
	      callback (attr, (bval[i]->bv_val ? bval[i]->bv_val : ""), ref);
	    ber_bvecfree (bval);
	  }
	  ldap_memfree (attr);
	}
	if (ber) ber_free (ber, 0);
	break;
      case LDAP_RES_SEARCH_RESULT:
	done = true;
	break;
      }
    }
    ldap_msgfree (res);
  }

  int retval = 0;

  if (ldresult == 0) {
    std::cerr << "Warning: LDAP query to " << host << " timed out" << std::endl;
    retval = 1;
  }

  if (ldresult == -1) {
    std::cerr << "Warning: " << ldap_err2string (ldresult) << " (" << host << ")"
	 << std::endl;
    retval = 1;
  }

  ldap_unbind (connection);
  connection = NULL;
  messageid = 0;
  return retval;
}


// One-shot search: connect to ldaphost:ldapport, send the query and stream
// every result through `callback`.  Returns 1 as soon as connecting or
// sending the query fails; otherwise returns the status from Result().
int LdapQuery::Find (const std::string & ldaphost,
                     int ldapport,
                     const std::string & usersn,
                     const std::string & base,
                     const std::string & filter,
                     const std::vector <std::string> & attributes,
                     Scope scope,
                     void callback (const std::string & attr,
                                    const std::string & value,
                                    void * ref),
                     void * ref,
                     bool anonymous,
                     int timeout,
                     int debug) {

  // && short-circuits, so Query() only runs after a successful Connect().
  const bool started =
    Connect (ldaphost, ldapport, usersn, anonymous, timeout, debug) == 0 &&
    Query (base, filter, attributes, scope, timeout, debug) == 0;

  if (!started) return 1;

  return Result (callback, ref, timeout, debug);
}
