/* mysql.c (7.44 KB) */
/* GNU Mailutils -- a suite of utilities for electronic mail
   Copyright (C) 2004, 2005 Free Software Foundation, Inc.

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
   License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.

   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Lesser General Public License for more details.

   You should have received a copy of the GNU Lesser General Public
   License along with this library; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA  */

#ifdef HAVE_CONFIG_H
# include <config.h>
#endif

#include <mailutils/mailutils.h>
#include <mailutils/sql.h>

#include <mysql/mysql.h>
#include <mysql/errmsg.h>

/* Private state of the MySQL backend.  One instance is allocated by
   init and stored in the `data' member of the connection object.  */
struct mu_mysql_data
{
  /* Connection handle; allocated in connect, closed and freed in
     disconnect.  */
  MYSQL *mysql;
  /* Result set fetched by store_result and freed by release_result.  */
  MYSQL_RES  *result;
};
  

/* Run QUERY on connection CONN.

   If the server reports CR_SERVER_GONE_ERROR the connection is torn down,
   re-established and the query retried, for at most 10 attempts in total.
   If the reconnection itself fails the loop stops at once: the handle is
   then unconnected (or missing altogether), so issuing the query on it
   again would be pointless at best.

   Returns the status of the last mysql_query call: 0 on success, non-zero
   on failure.  */
static int 
do_mysql_query (mu_sql_connection_t conn, char *query)
{
  int rc;
  int i;

  for (i = 0; i < 10; i++)
    {
      /* Fetch the handle anew on every pass, because a reconnect may
         have installed a different one.  */
      MYSQL *mysql = ((struct mu_mysql_data*) conn->data)->mysql;

      rc = mysql_query (mysql, query);
      if (rc == 0 || mysql_errno (mysql) != CR_SERVER_GONE_ERROR)
        break;

      /* The server went away: try to reconnect.  RC still holds the
         non-zero query status, which is what gets reported if the
         reconnection fails.  */
      mu_sql_disconnect (conn);
      if (mu_sql_connect (conn))
        break;
    }
  return rc;
}

/* ************************************************************************* */
/* Interface routines */

/* Allocate the zero-filled backend state and attach it to CONN.
   Returns 0 on success or ENOMEM if the allocation fails, in which
   case CONN is left untouched.  */
static int
init (mu_sql_connection_t conn)
{
  struct mu_mysql_data *data;

  data = calloc (1, sizeof (struct mu_mysql_data));
  if (data == NULL)
    return ENOMEM;

  conn->data = data;
  return 0;
}

/* Release the backend state attached to CONN: the connection handle
   storage (if any) and the state structure itself.  CONN's data pointer
   is cleared.  Always returns 0.  */
static int
destroy (mu_sql_connection_t conn)
{
  struct mu_mysql_data *data = conn->data;
  MYSQL *handle = data->mysql;

  conn->data = NULL;
  free (handle);
  free (data);
  return 0;
}
  

/* Open a connection to the MySQL server described by CONN.

   If CONN's server string begins with '/', it is taken to be the path of
   a UNIX socket and the host is set to "localhost"; otherwise it is used
   as the host name and no socket path is passed.  A NULL server string is
   handed to mysql_real_connect unchanged.

   Returns 0 on success, ENOMEM if the handle cannot be allocated, or
   MU_ERR_SQL if the connection attempt fails.  In the latter case the
   handle is kept so that errstr can still retrieve the error message.  */
static int
connect (mu_sql_connection_t conn)
{
  struct mu_mysql_data *mp = conn->data;
  const char *host = conn->server;
  /* Must be NULL unless a socket path was given: it is passed to
     mysql_real_connect unconditionally.  */
  const char *socket_name = NULL;

  if (mp->mysql)
    {
      /* A handle is still around, e.g. from an earlier attempt that
         failed.  Release it instead of overwriting (and leaking) it.  */
      mysql_close (mp->mysql);
      free (mp->mysql);
      mp->mysql = NULL;
    }

  mp->mysql = malloc (sizeof (MYSQL));
  if (!mp->mysql)
    return ENOMEM;

  mysql_init (mp->mysql);

  if (host && host[0] == '/')
    {
      socket_name = host;
      host = "localhost";
    }

  if (!mysql_real_connect (mp->mysql,
                           host,
                           conn->login,
                           conn->password,
                           conn->dbname,
                           conn->port,
                           socket_name,
                           0))
    return MU_ERR_SQL;

  return 0;
}

/* Close the connection held by CONN, free the handle storage and
   clear the stored handle pointer.  Always returns 0.  */
static int 
disconnect (mu_sql_connection_t conn)
{
  struct mu_mysql_data *data = conn->data;
  MYSQL *handle = data->mysql;

  data->mysql = NULL;
  mysql_close (handle);
  free (handle);
  return 0;
}

/* Execute QUERY on CONN via do_mysql_query.
   Returns 0 on success, MU_ERR_SQL on any failure.  */
static int
query (mu_sql_connection_t conn, char *query)
{
  return do_mysql_query (conn, query) ? MU_ERR_SQL : 0;
}


/* Fetch the result set of the last query and keep it in the backend
   state.  Returns 0 if a result set was obtained.  Otherwise returns
   MU_ERR_SQL when the handle reports an error, and MU_ERR_NO_RESULT
   when it does not.  */
static int
store_result (mu_sql_connection_t conn)
{
  struct mu_mysql_data *data = conn->data;

  data->result = mysql_store_result (data->mysql);
  if (data->result)
    return 0;

  return mysql_errno (data->mysql) ? MU_ERR_SQL : MU_ERR_NO_RESULT;
}

/* Free the result set stored by store_result.  The stored pointer is
   cleared afterwards so that it never refers to freed memory and a
   repeated call cannot free the same result set twice.
   Always returns 0.  */
static int
release_result (mu_sql_connection_t conn)
{
  struct mu_mysql_data *mp = conn->data;
  mysql_free_result (mp->result);
  mp->result = NULL;
  return 0;
}

/* Store in *NP the number of fields of the current result set.
   Always returns 0.  */
static int
num_columns (mu_sql_connection_t conn, size_t *np)
{
  MYSQL_RES *res = ((struct mu_mysql_data *) conn->data)->result;

  *np = mysql_num_fields (res);
  return 0;
}

/* Store in *NP the number of rows of the current result set.
   Always returns 0.  */
static int
num_tuples (mu_sql_connection_t conn, size_t *np)
{
  MYSQL_RES *res = ((struct mu_mysql_data *) conn->data)->result;

  *np = mysql_num_rows (res);
  return 0;
}

/* Store in *PDATA the value found at row NROW, column NCOL (both
   zero-based) of the current result set.  The pointer refers to the
   row data returned by mysql_fetch_row.

   Returns 0 on success, or MU_ERR_BAD_COLUMN if either index is out of
   range or the row cannot be fetched.  */
static int
get_column (mu_sql_connection_t conn, size_t nrow, size_t ncol, char **pdata)
{
  MYSQL_RES *res = ((struct mu_mysql_data *) conn->data)->result;
  MYSQL_ROW tuple;

  if (nrow >= mysql_num_rows (res))
    return MU_ERR_BAD_COLUMN;
  if (ncol >= mysql_num_fields (res))
    return MU_ERR_BAD_COLUMN;

  /* Position the cursor on the requested row, then read it.  */
  mysql_data_seek (res, nrow);
  tuple = mysql_fetch_row (res);
  if (tuple == NULL)
    return MU_ERR_BAD_COLUMN;

  *pdata = tuple[ncol];
  return 0;
}

/* Return the error message reported by the connection handle of CONN.  */
static const char *
errstr (mu_sql_connection_t conn)
{
  MYSQL *handle = ((struct mu_mysql_data *) conn->data)->mysql;

  return mysql_error (handle);
}


/* MySQL scrambled password support */

/* Map the digit character C to its numeric value: '0'..'9' yield 0..9,
   'A'..'Z' yield 10..35.  Every other character is treated as a
   lower-case letter, i.e. yields C - 'a' + 10 without any validation.  */
static unsigned 
digit_to_number (char c)
{
  int value;

  if (c >= '0' && c <= '9')
    value = c - '0';
  else if (c >= 'A' && c <= 'Z')
    value = c - 'A' + 10;
  else
    value = c - 'a' + 10;

  return (unsigned) value;
}

/* Extract salt value from MySQL scrambled password.

   Each group of 8 hex digits of PASSWORD is converted to one unsigned
   long and stored in the next entry of RES.  The first two entries of
   RES are zeroed beforehand.

   WARNING: The caller must make sure that RES has room for at least two
   entries and for one entry per (possibly incomplete) group of 8
   characters in PASSWORD.

   A trailing group shorter than 8 characters is converted from the
   digits that are present; the string is never read beyond its
   terminating null character.

   For MySQL >= 3.21, strlen(password) == 16 */
static void
get_salt_from_scrambled (unsigned long *res, const char *password)
{
  res[0] = res[1] = 0;
  while (*password)
    {
      unsigned long val = 0;
      unsigned i;

      /* Stop at the terminator: without this check a length that is not
         a multiple of 8 would make the loop run off the end of the
         string.  */
      for (i = 0; i < 8 && *password; i++)
        val = (val << 4) + digit_to_number (*password++);
      *res++ = val;
    }
}

/* Compute the two-word hash of the plaintext PASSWORD and store it in
   RESULT[0] and RESULT[1].  Spaces and tabs in PASSWORD are skipped.
   Both words are reduced to their low 31 bits.  */
static void
scramble_password (unsigned long *result, const char *password)
{
  const unsigned long mask31 = ((unsigned long) 1L << 31) - 1L;
  unsigned long hash1 = 1345345333L;
  unsigned long hash2 = 0x12345671L;
  unsigned long step = 7;
  const char *p;

  for (p = password; *p; p++)
    {
      if (*p != ' ' && *p != '\t')
        {
          /* Go through unsigned char so that bytes above 127 are not
             sign-extended.  */
          unsigned long ch = (unsigned long) (unsigned char) *p;

          hash1 ^= (((hash1 & 63) + step) * ch) + (hash1 << 8);
          hash2 += (hash2 << 8) ^ hash1;
          step += ch;
        }
    }

  result[0] = hash1 & mask31;
  result[1] = hash2 & mask31;
}

#if 0
/* Everything down to the matching #endif is excluded from compilation.
   It sketches a check for the MySQL 4.x password format: a '*' followed
   by the hexadecimal form of the SHA1 of the SHA1 of the plaintext.

   NOTE(review): this code cannot be enabled as is.  In
   mu_check_mysql_4x_password `to' is declared as an array, yet it is
   tested against NULL and incremented with `*to++'; the final memcmp
   would also start after the '*' rather than at it.  The sha1_* routines
   and the uint8 type are not declared by any header included above --
   confirm where they are expected to come from.  */

/* Write to TO the hexadecimal representation of the LEN bytes at STR,
   two characters per byte, followed by a terminating null character.
   TO must have room for 2*LEN + 1 characters.  Only the first 16 entries
   of the digit table can be selected, since each index is a 4-bit
   value.  */
static void
octet2hex (char *to, const unsigned char *str, unsigned len)
{
  const char *str_end= str + len;
  static char d[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";

  for ( ; str != str_end; ++str)
    {
      *to++ = d[(*str & 0xF0) >> 4];   /* high nibble */
      *to++ = d[*str & 0x0F];          /* low nibble */
    }
  *to= '\0';
}

#define SHA1_HASH_SIZE 20
/* Check whether the plaintext MESSAGE matches the 4.x style password
   SCRAMBLED.  Returns the result of memcmp over strlen (SCRAMBLED)
   bytes, i.e. 0 when they compare equal.  */
int
mu_check_mysql_4x_password (const char *scrambled, const char *message)
{
  struct sha1_ctx sha1_context;
  uint8 hash_stage2[SHA1_HASH_SIZE];
  char to[2*SHA1_HASH_SIZE + 2];

  if (!to)
    return 1;
  
  /* stage 1: hash password */
  sha1_init_ctx (&sha1_context);
  sha1_process_bytes (message, strlen (message), &sha1_context);
  sha1_finish_ctx (&sha1_context, to);

  /* stage 2: hash stage1 output */
  sha1_init_ctx (&sha1_context);
  sha1_process_bytes (to, SHA1_HASH_SIZE, &sha1_context);
  sha1_finish_ctx (&sha1_context, hash_stage2);

  /* convert hash_stage2 to hex string */
  *to++= '*';
  octet2hex (to, hash_stage2, SHA1_HASH_SIZE);

  /* Compare both strings */
  return memcmp (to, scrambled, strlen (scrambled));
}
#endif

/* Check whether the plaintext password MESSAGE matches the MySQL
   scrambled password SCRAMBLED.

   SCRAMBLED must consist of 16 characters, optionally followed by
   whitespace, which is ignored.

   Returns 0 if the passwords match and 1 otherwise, including the case
   when SCRAMBLED does not have the required length.  */
int
mu_check_mysql_scrambled_password (const char *scrambled, const char *message)
{
  unsigned long hash_pass[2], hash_message[2];
  char buf[17];
  size_t len = strlen (scrambled);

  if (len < 16)
    return 1;
  if (len > 16)
    {
      const char *p;

      /* Try to normalize it by cutting off trailing whitespace.  The
         cast keeps isspace well defined for characters above 127.  */
      for (p = scrambled + len - 1;
           p > scrambled && isspace ((unsigned char) *p); p--)
        ;
      if (p - scrambled != 15)
        return 1;
      memcpy (buf, scrambled, 16);
      /* Terminate right after the 16 copied characters; index 16 is the
         last element of BUF.  */
      buf[16] = 0;
      scrambled = buf;
    }

  get_salt_from_scrambled (hash_pass, scrambled);
  scramble_password (hash_message, message);
  return !(hash_message[0] == hash_pass[0]
           && hash_message[1] == hash_pass[1]);
}


/* Register module: the dispatch table of the MySQL backend.

   The initializer is positional; the slot names are given by the
   MU_DECL_SQL_DISPATCH_T declaration, which is not visible in this
   file.
   NOTE(review): num_tuples is listed before num_columns, the reverse of
   the order in which they are defined above -- presumably this matches
   the slot order; confirm against <mailutils/sql.h>.  */
MU_DECL_SQL_DISPATCH_T(mysql) = {
  "mysql",          /* presumably the interface name */
  3306,             /* presumably the default server port */
  init,
  destroy,
  connect,
  disconnect,
  query,
  store_result,
  release_result,
  num_tuples,
  num_columns,
  get_column,
  errstr,
};