/* zero-pinyin-service.c */
#include "zero-pinyin-service.h"
#include "parse-pinyin.h"

void
get_candidates_test (const char* preedit_str,
		     const guint fetch_size,
		     GVariantBuilder *candidates_builder,
		     GVariantBuilder *matched_lengths_builder)
{
	if (g_str_equal (preedit_str, "liyifeng")) {
		const gchar *matches[] = {"李易峰", "利益", "礼仪", "离异", "里", "理", "力"};
		guint matched_lengths[] = {8, 4, 4, 4, 2, 2, 2};
		for (guint i = 0; i < G_N_ELEMENTS (matches); ++i) {
			g_variant_builder_add (candidates_builder, "s", matches[i]);
			g_variant_builder_add (matched_lengths_builder, "u", matched_lengths[i]);
		}
	} else if (g_str_equal (preedit_str, "feng")) {
		const gchar *matches[] = {"风", "封", "疯", "丰", "凤"};
		guint matched_lengths[] = {4, 4, 4, 4, 4, 4};
		for (guint i = 0; i < G_N_ELEMENTS (matches); ++i) {
			g_variant_builder_add (candidates_builder, "s", matches[i]);
			g_variant_builder_add (matched_lengths_builder, "u", matched_lengths[i]);
		}
	} else if (g_str_equal (preedit_str, "yifeng")) {
		const gchar *matches[] = {"一封", "遗风", "艺", "依", "一", "以"};
		guint matched_lengths[] = {6, 6, 2, 2, 2, 2};
		for (guint i = 0; i < G_N_ELEMENTS (matches); ++i) {
			g_variant_builder_add (candidates_builder, "s", matches[i]);
			g_variant_builder_add (matched_lengths_builder, "u", matched_lengths[i]);
		}
	}
}

/**
 * build where clause for build_sql_for_n_pinyin().
 *
 * for the i-th of the first @n pinyin, a non-zero shengmu_i adds
 * "s<i>=<shengmu_i> " and a non-zero yunmu_i adds "y<i>=<yunmu_i> ".
 * every condition after the first one is prefixed with "AND ". a pinyin
 * whose shengmu_i and yunmu_i are both zero adds nothing, so the result is
 * "" when that holds for all @n pinyin.
 *
 * @pylist: the pinyin list (GList of Pinyin*), at least @n items long.
 * @n: number of Pinyin to use in pylist.
 *
 * returns: where_clause, caller should g_free() result after use.
 */
static char*
build_where_clause (GList* pylist,
		    const guint n)
{
	GString* clause = g_string_new ("");
	gboolean need_and = FALSE;
	guint idx = 0;

	for (GList* node = pylist; idx < n; node = node->next, ++idx) {
		g_assert_nonnull (node);
		const Pinyin* py = (const Pinyin*) node->data;

		if (py->shengmu_i) {
			g_string_append_printf (clause,
						need_and ? "AND s%u=%d " : "s%u=%d ",
						idx, py->shengmu_i);
			need_and = TRUE;
		}
		if (py->yunmu_i) {
			g_string_append_printf (clause,
						need_and ? "AND y%u=%d " : "y%u=%d ",
						idx, py->yunmu_i);
			need_and = TRUE;
		}
	}

	/* FALSE keeps the character data and hands it to the caller */
	return g_string_free (clause, FALSE);
}

/**
 * build a SQL to query candidates for first n pinyin in pylist.
 * n can be from 1 to len(pylist).
 *
 * the query unions maindb.py_phrase_<n-1> (user_freq forced to 0) with
 * userdb.py_phrase_<n-1>, both filtered by build_where_clause(), collapses
 * duplicate phrases and orders by user_freq then freq, highest first,
 * returning at most @limit rows. result columns are, in order:
 * user_freq, phrase, freq.
 *
 * @pylist: the pinyin list.
 * @n: number of pinyin (= phrase length) to query for; must be >= 1.
 * @limit: maximum number of rows returned.
 *
 * caller should free result with g_free() after use.
 */
static char*
build_sql_for_n_pinyin (GList* pylist,
			const guint n,
			const guint limit)
{
	/* table suffix is n - 1; n == 0 would wrap around to G_MAXUINT */
	g_assert_cmpuint (n, >=, 1);

	gchar* where_clause = build_where_clause (pylist, n);
	g_assert_nonnull (where_clause);
	g_message ("where_clause=%s", where_clause);

	/* A phrase can appear in both maindb (user_freq 0) and userdb.
	 * With a plain "user_freq" column under GROUP BY, SQLite takes the
	 * value from an arbitrary row of the group, which can hide the user's
	 * frequency. MAX() picks the highest user_freq. SQLite also takes the
	 * bare "freq" column from that same row. */
	GString* sql = g_string_new ("SELECT MAX(user_freq) AS user_freq, phrase, freq FROM (");
	g_string_append_printf (
		sql, "SELECT 0 AS user_freq, phrase, freq FROM "
		"maindb.py_phrase_%u WHERE ", n - 1);
	g_string_append (sql, where_clause);
	g_string_append (sql, " UNION ");
	g_string_append_printf (
		sql, "SELECT user_freq, phrase, freq FROM "
		"userdb.py_phrase_%u WHERE ", n - 1);
	g_string_append (sql, where_clause);
	g_string_append (sql, ") GROUP BY phrase ORDER BY user_freq DESC, freq DESC ");
	g_string_append_printf (sql, "LIMIT %u", limit);

	g_free (where_clause);
	/* FALSE keeps the character data and hands it to the caller */
	return g_string_free (sql, FALSE);
}

/**
 * fetch candidates for a fixed word length.
 *
 * @db: sqlite3 connection. the generated SQL reads maindb.* and userdb.*
 *      tables, so both must be reachable through this connection.
 * @pylist: the pinyin list (GList of Pinyin*).
 * @group_size: the fixed word length. use this many pinyin from pinyin list.
 *              must be between 1 and g_list_length (pylist).
 * @limit: fetch this many result is enough for user. more is not a problem
 *         though; at least DEFAULT_LIMIT rows are requested.
 * @candidates: out parameter. set to a GList of Candidate* in query order,
 *              or NULL when nothing is found or the query fails. caller
 *              owns the list, each Candidate and each Candidate->str.
 *
 * returns: how many candidates fetched (length of *candidates).
 */
static guint
get_candidates_for_n_pinyin (sqlite3* db,
			     GList* pylist,
			     const guint group_size,
			     const guint limit,
			     GList** candidates)
{
	const guint DEFAULT_LIMIT = 50;
	GList* result = NULL;	/* GList of Candidate */
	guint candidates_count = 0;
	guint matched_py_length = 0;
	sqlite3_stmt* stmt = NULL;
	const char* unused = NULL;
	gint r = 0;

	*candidates = NULL;

	g_assert_cmpuint (group_size, >=, 1);
	g_assert_cmpuint (group_size, <=, g_list_length (pylist));

	/* every row of this query consumes the same amount of preedit: the
	 * total length of the first group_size pinyin. */
	GList* iter = pylist;
	for (guint i = 0; i < group_size; ++i) {
		matched_py_length += ((Pinyin*) iter->data)->length;
		iter = iter->next;
	}

	/* build SQL and run SQL query */
	char* sql = build_sql_for_n_pinyin (pylist, group_size, MAX (limit, DEFAULT_LIMIT));
	g_message ("build_sql_for_n_pinyin result SQL:\n\n%s\n", sql);

	r = sqlite3_prepare_v2 (db, sql, -1, &stmt, &unused);
	if (r != SQLITE_OK) {
		/* stmt is NULL on failure; stepping it would be API misuse */
		g_warning ("sqlite3_prepare_v2 error: %d (%s)",
			   r, sqlite3_errmsg (db));
		g_free (sql);
		return 0;
	}
	/* unused points into sql, so it must be inspected before sql is freed */
	if (unused != NULL && *unused != '\0') {
		g_warning ("part of sql is unused \"%s\" length=%zu",
			   unused, strlen (unused));
	}
	g_free (sql);
	unused = NULL;

	while ((r = sqlite3_step (stmt)) == SQLITE_ROW) {
		/* sql SELECT should select these 3 columns in order */
		const char* phrase = (const char*) sqlite3_column_text (stmt, 1);
		if (phrase == NULL) {
			/* SQL NULL phrase: nothing to offer, and g_utf8_validate()
			 * must not be given NULL */
			g_warning ("ignore NULL phrase");
			continue;
		}
		if (!g_utf8_validate (phrase, -1, NULL)) {
			g_warning ("ignore non utf8 phrase: %s", phrase);
			continue;
		}

		Candidate* c = g_new (Candidate, 1);
		c->user_freq = sqlite3_column_int (stmt, 0);
		c->str = g_strdup (phrase);
		c->freq = sqlite3_column_int (stmt, 2);
		c->matched_py_length = matched_py_length;

		result = g_list_prepend (result, c);
		candidates_count++;
	}
	if (r == SQLITE_BUSY) {
		g_warning ("sqlite3_step got SQLITE_BUSY");
	} else if (r != SQLITE_DONE) {
		g_warning ("sqlite3_step error: %d (%s)",
			   r, sqlite3_errmsg (db));
	}

	r = sqlite3_finalize (stmt);
	if (r != SQLITE_OK) {
		g_debug ("sqlite3_finalize error: %d (%s)", r, sqlite3_errmsg (db));
	}

	/* rows were prepended; reverse to restore the SQL ORDER BY order */
	*candidates = g_list_reverse (result);
	return candidates_count;
}

/**
 * look up candidates for @preedit_str and append them to the builders.
 *
 * the preedit is parsed into a pinyin list. phrases are then queried from
 * the longest length (all parsed pinyin) down to single pinyin. querying
 * stops once at least @fetch_size candidates have been collected or the
 * lengths run out. the check happens only between lengths, so more than
 * @fetch_size candidates may be added.
 *
 * @db: sqlite3 connection passed through to get_candidates_for_n_pinyin().
 * @preedit_str: the pinyin typed by user.
 * @fetch_size: how many candidates the caller wants at least.
 * @candidates_builder: receives each candidate phrase as an "s" value.
 * @matched_lengths_builder: receives, in the same order, the preedit length
 *                           each candidate matched, as a "u" value.
 */
void
get_candidates (sqlite3* db,
		const char* preedit_str,
		const guint fetch_size,
		GVariantBuilder *candidates_builder,
		GVariantBuilder *matched_lengths_builder)
{
	/* NOTE(review): 15 is passed straight to parse_pinyin(); presumably a
	 * cap on how many pinyin are parsed — confirm in parse-pinyin.h */
	GList* pinyins = parse_pinyin (preedit_str, 15);
	guint total = 0;

	for (guint len = g_list_length (pinyins);
	     total < fetch_size && len > 0;
	     --len) {
		GList* found = NULL;
		guint count;

		g_message ("phrase length=%u", len);
		count = get_candidates_for_n_pinyin (db, pinyins, len,
						     fetch_size - total, &found);

		for (GList* node = found; node != NULL; node = node->next) {
			Candidate* cand = (Candidate*) node->data;
			g_variant_builder_add (candidates_builder, "s", cand->str);
			g_free (cand->str);
			g_variant_builder_add (matched_lengths_builder, "u",
					       cand->matched_py_length);
		}
		/* releases the Candidate structs; their strings were freed above */
		g_list_free_full (found, g_free);

		g_message ("%u candidates found", count);
		total += count;
	}

	g_message ("returning %u candidates", total);
	g_list_free_full (pinyins, g_free);
}