Newer
Older
#include "zero-pinyin-service.h"
#include "parse-pinyin.h"
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/**
 * Return canned candidates for a few hard-coded preedit strings.
 *
 * Only "liyifeng", "feng" and "yifeng" produce candidates; any other input
 * (including NULL) leaves the builders untouched.
 *
 * @preedit_str: the pinyin preedit str.
 * @fetch_size: unused; all canned candidates are always returned.
 * @candidates_builder: receives one "s" per candidate.
 * @matched_lengths_builder: receives one "u" per candidate.
 */
void
get_candidates_test (const char* preedit_str,
		     const guint fetch_size,
		     GVariantBuilder *candidates_builder,
		     GVariantBuilder *matched_lengths_builder)
{
	/* Keep each candidate next to its matched length so the two can never
	 * drift out of sync the way separate parallel arrays did. */
	typedef struct {
		const gchar *text;
		guint matched_length;
	} TestCandidate;
	static const TestCandidate liyifeng[] = {
		{"李易峰", 8}, {"利益", 4}, {"礼仪", 4}, {"离异", 4},
		{"里", 2}, {"理", 2}, {"力", 2},
	};
	static const TestCandidate feng[] = {
		{"风", 4}, {"封", 4}, {"疯", 4}, {"丰", 4}, {"凤", 4},
	};
	static const TestCandidate yifeng[] = {
		{"一封", 6}, {"遗风", 6}, {"艺", 2}, {"依", 2}, {"一", 2}, {"以", 2},
	};
	const TestCandidate *table = NULL;
	gsize table_len = 0;

	(void) fetch_size;
	if (preedit_str == NULL) {
		return;
	}
	if (g_str_equal (preedit_str, "liyifeng")) {
		table = liyifeng;
		table_len = G_N_ELEMENTS (liyifeng);
	} else if (g_str_equal (preedit_str, "feng")) {
		table = feng;
		table_len = G_N_ELEMENTS (feng);
	} else if (g_str_equal (preedit_str, "yifeng")) {
		table = yifeng;
		table_len = G_N_ELEMENTS (yifeng);
	}
	for (gsize i = 0; i < table_len; ++i) {
		g_variant_builder_add (candidates_builder, "s", table[i].text);
		g_variant_builder_add (matched_lengths_builder, "u",
				       table[i].matched_length);
	}
}
/**
 * build where clause for build_sql_for_n_pinyin().
 *
 * Emits "s<i>=<shengmu> " / "y<i>=<yunmu> " terms joined by "AND " for
 * every non-zero shengmu_i / yunmu_i among the first n Pinyin. If no term
 * is produced, the clause is "1 " so the surrounding "WHERE ..." stays valid
 * SQL.
 *
 * @pylist: the pinyin list; must contain at least n entries.
 * @n: number of Pinyin to use in pylist.
 *
 * returns: where_clause, caller should g_free() result after use.
 */
static char*
build_where_clause (GList* pylist,
		    const guint n)
{
	/* the GString must be allocated before any append */
	GString* s = g_string_new (NULL);
	GList* iter = pylist;
	gboolean first_condition_done = FALSE;

	for (guint i = 0; i < n; ++i) {
		g_assert_nonnull (iter);
		const Pinyin* thispy = (const Pinyin*) iter->data;
		if (thispy->shengmu_i) {
			g_string_append_printf (s, "%ss%u=%d ",
						first_condition_done ? "AND " : "",
						i, thispy->shengmu_i);
			first_condition_done = TRUE;
		}
		if (thispy->yunmu_i) {
			g_string_append_printf (s, "%sy%u=%d ",
						first_condition_done ? "AND " : "",
						i, thispy->yunmu_i);
			first_condition_done = TRUE;
		}
		iter = iter->next;
	}
	if (! first_condition_done) {
		/* an empty clause would yield "WHERE UNION ..." */
		g_string_append (s, "1 ");
	}
	return g_string_free (s, FALSE);
}
/**
 * Build the column list ", s0, y0, s1, y1, ... " for n pinyin.
 *
 * @n: number of s/y column pairs; must be >= 1.
 *
 * returns: newly allocated string ending with a single space;
 * caller should g_free() result after use.
 */
char*
build_s_y_fields (const guint n)
{
	g_assert_cmpint (n, >=, 1);
	GString *fields = g_string_new (NULL);
	guint col = 0;
	while (col < n) {
		g_string_append_printf (fields, ", s%u, y%u", col, col);
		col++;
	}
	g_string_append (fields, " ");
	return g_string_free (fields, FALSE);
}
/**
 * build a SQL to query candidates for first n pinyin in pylist.
 * n can be from 1 to len(pylist).
 *
 * The query unions maindb.py_phrase_<n-1> (reported with user_freq 0) and
 * userdb.py_phrase_<n-1>, drops phrases listed in userdb.not_phrase, groups
 * by phrase and orders by user_freq then freq. Selected columns are, in
 * order: user_freq, phrase, freq, s0, y0, s1, y1, ...
 *
 * @pylist: the pinyin list.
 * @n: number of pinyin to match.
 * @limit: value for the LIMIT clause.
 *
 * caller should free result with g_free() after use.
 */
static char*
build_sql_for_n_pinyin (GList* pylist,
			const guint n,
			const guint limit)
{
	gchar* s_y_fields = build_s_y_fields (n);
	gchar* where_clause = build_where_clause (pylist, n);
	g_assert_nonnull (where_clause);

	/* always keep one space after current term */
	GString* sql = g_string_new ("SELECT MAX(user_freq) AS user_freq, "
				     "phrase, MAX(freq) AS freq");
	/* fixed text goes through g_string_append: it is not a format string */
	g_string_append (sql, s_y_fields);
	g_string_append (sql, "FROM (");

	g_string_append (sql, "SELECT 0 AS user_freq, phrase, freq");
	g_string_append (sql, s_y_fields);
	g_string_append_printf (sql, "FROM maindb.py_phrase_%u WHERE ", n - 1);
	g_string_append (sql, where_clause);

	g_string_append (sql, "UNION ");

	g_string_append (sql, "SELECT user_freq, phrase, freq");
	g_string_append (sql, s_y_fields);
	g_string_append_printf (sql, "FROM userdb.py_phrase_%u WHERE ", n - 1);
	g_string_append (sql, where_clause);

	g_string_append (sql,
			 ") "
			 "WHERE phrase NOT IN (SELECT phrase FROM userdb.not_phrase) "
			 "GROUP BY phrase "
			 "ORDER BY user_freq DESC, freq DESC ");
	g_string_append_printf (sql, "LIMIT %u;", limit);

	g_free (s_y_fields);
	g_free (where_clause);
	return g_string_free (sql, FALSE);
}
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
/**
 * For a candidate of length group_size, calculate the matched py length.
 *
 * This is part of get_candidates_for_n_pinyin().
 *
 * Sums the ->length of the first group_size Pinyin in pylist. Any run of '
 * characters in preedit_str directly before or after each used Pinyin is
 * counted too.
 *
 * @preedit_str: the pinyin preedit str, may contain '.
 * @pylist: the pinyin list; must contain at least group_size entries.
 * @group_size: how many Pinyin to consume; must be >= 1.
 *
 * returns: number of bytes of preedit_str covered.
 */
static guint
get_matched_py_length (const char* preedit_str,
		       GList* pylist,
		       const guint group_size)
{
	g_assert_cmpint (group_size, >=, 1);
	guint pos = 0;
	guint remaining = group_size;
	for (GList* node = pylist; remaining > 0; node = node->next, --remaining) {
		/* leading ' separators */
		while (preedit_str[pos] == '\'') {
			++pos;
		}
		pos += ((Pinyin*) node->data)->length;
		/* trailing ' separators */
		while (preedit_str[pos] == '\'') {
			++pos;
		}
	}
	return pos;
}
/**
* fetch candidates for a fixed word length.
*
* Builds a query with build_sql_for_n_pinyin() over the first group_size
* entries of pylist, runs it, and returns one Candidate per result row whose
* phrase is valid UTF-8.
*
* @db: sqlite3 db handler. must not be NULL.
* @preedit_str: the pinyin preedit str. can contain '. This is needed to
* calculate matched_py_length.
* NOTE(review): the body reads preedit_str, but it is not in the parameter
* list below, while get_candidates() passes it as the 2nd argument. The
* signature likely needs "const char* preedit_str" after db — confirm.
* @pylist: the pinyin list. must have at least group_size entries.
* @group_size: the fixed word length. use this many pinyin from pinyin list.
* @limit: fetch this many result is enough for user. more is not a problem though.
* The query actually uses MAX (limit, DEFAULT_LIMIT), i.e. at least 50 rows.
* @candidates: the result candidate list. caller should free this after use.
* Set to a GList of newly allocated Candidate*, in query order.
*
* returns: how many candidates fetched (the length of *candidates).
*/
static guint
get_candidates_for_n_pinyin (sqlite3* db,
GList* pylist,
const guint group_size,
const guint limit,
GList** candidates)
{
const guint DEFAULT_LIMIT = 50;
GList* result = NULL; /* GList of Candidate */
g_assert_nonnull (db);
g_assert_cmpint (group_size, >=, 1);
g_assert_cmpint (group_size, <=, g_list_length (pylist));
gint candidates_count = 0;
gint r = 0;
/* build SQL and run SQL query */
/* NOTE(review): sql is never g_free()d in this function — looks like a leak. */
char* sql = NULL;
sql = build_sql_for_n_pinyin (pylist, group_size, MAX (limit, DEFAULT_LIMIT));
g_debug ("build_sql_for_n_pinyin result SQL:\n\n%s\n", sql);
guint matched_py_length = get_matched_py_length (preedit_str, pylist, group_size);
sqlite3_stmt* stmt = NULL;
const char* unused;
Candidate* c = NULL;
/* NOTE(review): r is not checked here; if prepare fails, stmt stays NULL
 * and the loop below steps a NULL statement. */
r = sqlite3_prepare_v2 (db, sql, -1, &stmt, &unused);
/* unused points at any SQL text after the first statement, which sqlite
 * would not run. With asserts enabled the g_assert_cmpstr fires before the
 * warning below is reached. */
g_assert_nonnull (unused);
g_assert_cmpstr (unused, ==, "");
if (strlen (unused)) {
g_warning ("part of sql is unused \"%s\" length=%zu",
unused, strlen (unused));
}
while (TRUE) {
r = sqlite3_step (stmt);
if (r == SQLITE_DONE) {
break;
} else if (r == SQLITE_ROW) {
c = g_new0 (Candidate, 1);
/* sql SELECT should select these columns in order */
c->user_freq = sqlite3_column_int (stmt, 0);
c->str = g_strdup ((const char*) sqlite3_column_text (stmt, 1));
c->freq = sqlite3_column_int (stmt, 2);
c->matched_py_length = matched_py_length;
c->char_len = group_size;
c->py_indices = g_malloc0 (sizeof (Pinyin*) * group_size);
/* columns 3.. are s0, y0, s1, y1, ... as built by build_s_y_fields() */
for (guint i = 0; i < group_size; ++i) {
c->py_indices[i] = g_new0 (Pinyin, 1);
c->py_indices[i]->shengmu_i = sqlite3_column_int (stmt, 3 + i * 2);
c->py_indices[i]->yunmu_i = sqlite3_column_int (stmt, 4 + i * 2);
/* we don't care about ->length field */
}
/* NOTE(review): the numbered lines below look like placeholders for code
 * not visible in this view — confirm against the real file. */
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
if (g_utf8_validate (c->str, -1, NULL)) {
result = g_list_prepend (result, c);
candidates_count++;
} else {
/* NOTE(review): the rejected Candidate (c, c->str, c->py_indices and
 * each py_indices[i]) is not freed on this path — possible leak. */
g_warning ("ignore non utf8 phrase: %s", c->str);
}
} else if (r == SQLITE_BUSY) {
g_warning ("sqlite3_step got SQLITE_BUSY");
break;
} else {
g_warning ("sqlite3_step error: %d (%s)",
r, sqlite3_errmsg (db));
break;
}
}
r = sqlite3_finalize (stmt);
if (r != SQLITE_OK) {
g_debug ("sqlite3_finalize error: %d (%s)", r, sqlite3_errmsg (db));
}
/* store query result in a new GList */
/* rows were prepended, so reverse to restore query (ORDER BY) order */
*candidates = g_list_reverse (result);
return candidates_count;
}
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
/**
 * Append one candidate to the three output builders and release its fields.
 *
 * Adds c->str as "s", c->matched_py_length as "u", and the c->char_len
 * (shengmu_i, yunmu_i) pairs as one "a(ii)". Afterwards c->str,
 * c->py_indices and every c->py_indices[i] are freed; c itself is NOT freed
 * and must be released by the caller.
 */
static void
add_candidate_to_builders (Candidate *c,
			   GVariantBuilder *candidates_builder,
			   GVariantBuilder *matched_lengths_builder,
			   GVariantBuilder *candidates_pinyin_indices)
{
	g_variant_builder_add (candidates_builder, "s", c->str);
	g_variant_builder_add (matched_lengths_builder, "u", c->matched_py_length);

	GVariantBuilder *pairs = g_variant_builder_new (G_VARIANT_TYPE ("a(ii)"));
	for (guint k = 0; k < c->char_len; ++k) {
		Pinyin *py = c->py_indices[k];
		g_variant_builder_add (pairs, "(ii)", py->shengmu_i, py->yunmu_i);
		g_debug ("adding (ii) %d %d", py->shengmu_i, py->yunmu_i);
		g_free (py);
	}
	g_debug ("adding a(ii) to aa(ii)");
	g_variant_builder_add (candidates_pinyin_indices, "a(ii)", pairs);
	g_variant_builder_unref (pairs);

	g_free (c->str);
	g_free (c->py_indices);
}
/**
 * Look up candidates for preedit_str and append them to the builders.
 *
 * Longer phrases come first: candidates using all parsed pinyin are queried,
 * then one fewer pinyin, and so on, until at least fetch_size candidates
 * were collected or single-pinyin candidates have been tried.
 *
 * @db: sqlite3 db handler. when NULL, a warning is logged and nothing is added.
 * @preedit_str: the pinyin preedit str, may contain '. when NULL, a warning
 * is logged and nothing is added.
 * @fetch_size: minimum number of candidates wanted; more may be appended.
 * @candidates_builder: receives one "s" per candidate.
 * @matched_lengths_builder: receives one "u" per candidate.
 * @candidates_pinyin_indices: receives one "a(ii)" per candidate.
 */
void
get_candidates (sqlite3* db,
		const char* preedit_str,
		const guint fetch_size,
		GVariantBuilder *candidates_builder,
		GVariantBuilder *matched_lengths_builder,
		GVariantBuilder *candidates_pinyin_indices)
{
	if (! db) {
		g_warning ("No db connection, can't get candidates.");
		return;
	}
	if (! preedit_str) {
		g_warning ("preedit_str should not be NULL. can't get candidates.");
		return;
	}
	/* NOTE(review): the meaning of 15 is defined by parse_pinyin() —
	 * presumably a cap on parsed syllables; confirm. */
	GList* pylist = parse_pinyin (preedit_str, 15);
	guint fetched_size = 0;

	for (guint group_size = g_list_length (pylist);
	     fetched_size < fetch_size && group_size > 0;
	     --group_size) {
		g_message ("phrase length=%u", group_size);
		GList* candidates = NULL;
		guint r = get_candidates_for_n_pinyin (
			db, preedit_str, pylist, group_size,
			fetch_size - fetched_size, &candidates);
		for (GList* iter = candidates; iter != NULL; iter = iter->next) {
			/* appends the candidate, then frees its str/py_indices */
			add_candidate_to_builders (
				(Candidate*) iter->data, candidates_builder,
				matched_lengths_builder,
				candidates_pinyin_indices);
		}
		/* frees the Candidate structs themselves (no-op for NULL) */
		g_list_free_full (candidates, g_free);
		g_message ("%u candidates found", r);
		fetched_size += r;
	}
	g_message ("returning %u candidates", fetched_size);
	g_list_free_full (pylist, g_free);
}
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
/**
 * sub function for commit_candidate()
 *
 * Inserts candidate into userdb.py_phrase_<len-1> (INSERT OR IGNORE, so an
 * existing row is kept), then increments user_freq of the row matching both
 * the phrase and every s<i>/y<i> pinyin column.
 *
 * @db: sqlite3 db handler.
 * @candidate: the committed phrase.
 * @candidate_pinyin_indices: GVariant of type "a(ii)", one
 *   (shengmu_i, yunmu_i) pair per character of candidate. Values of another
 *   type, or with a pair count different from len, are rejected with a
 *   warning.
 * @len: utf8 length of candidate; must be >= 1.
 */
static void
_update_userdb_py_phrase (sqlite3 *db,
			  const gchar *candidate,
			  GVariant *candidate_pinyin_indices,
			  guint len) /* utf8 length of candidate char */
{
	GString *sql = NULL;
	GVariantIter iter = {0};
	gint x = 0;
	gint y = 0;
	guint count = 0;
	char *s = NULL;
	gboolean rb = FALSE;

	g_assert_nonnull (db);
	g_assert_nonnull (candidate);
	g_assert_nonnull (candidate_pinyin_indices);
	/* the value is supplied by commit_candidate()'s caller; iterating any
	 * other type below would trigger GLib criticals */
	if (! g_variant_is_of_type (candidate_pinyin_indices,
				    G_VARIANT_TYPE ("a(ii)"))) {
		g_warning ("candidate_pinyin_indices type is %s, expect a(ii). won't commit candidate.",
			   g_variant_get_type_string (candidate_pinyin_indices));
		return;
	}

	/* insert candidate maybe */
	sql = g_string_new (NULL);
	g_string_append_printf (sql, "INSERT OR IGNORE INTO userdb.py_phrase_%u (user_freq, phrase, freq", len - 1);
	gchar* s_y_fields = build_s_y_fields (len);
	g_string_append (sql, s_y_fields);
	g_free (s_y_fields);
	/* %Q quotes and escapes candidate for SQL */
	s = sqlite3_mprintf (") VALUES (0, %Q, 0", candidate);
	g_string_append (sql, s);
	sqlite3_free (s);
	/* g_variant_iter_next unpacks into x/y directly, so no child GVariant
	 * is left to unref */
	g_variant_iter_init (&iter, candidate_pinyin_indices);
	count = 0;
	while (g_variant_iter_next (&iter, "(ii)", &x, &y)) {
		g_string_append_printf (sql, ", %d, %d", x, y);
		count++;
	}
	if (count != len) {
		/* reject instead of aborting: this input comes from the caller */
		g_warning ("candidate length=%u, a(ii) length=%u, mismatch!",
			   len, count);
		g_string_free (sql, TRUE);
		return;
	}
	g_string_append (sql, ");");
	rb = sqlite3_exec_simple (db, sql->str);
	if (! rb) {
		g_warning ("INSERT candidate to userdb failed");
	} else if (sqlite3_changes (db) == 1) {
		g_message ("candidate %s inserted to userdb", candidate);
	}
	g_string_free (sql, TRUE);

	/* increment user_freq field for candidate */
	sql = g_string_new (NULL);
	g_string_append_printf (sql, "UPDATE userdb.py_phrase_%u "
				"SET user_freq = user_freq + 1 ", len - 1);
	s = sqlite3_mprintf ("WHERE phrase = %Q ", candidate);
	g_string_append (sql, s);
	sqlite3_free (s);
	g_variant_iter_init (&iter, candidate_pinyin_indices);
	count = 0;
	while (g_variant_iter_next (&iter, "(ii)", &x, &y)) {
		g_string_append_printf (sql, "AND s%u=%d AND y%u=%d ",
					count, x, count, y);
		count++;
	}
	g_string_append (sql, ";");
	rb = sqlite3_exec_simple (db, sql->str);
	if (! rb) {
		g_warning ("UPDATE candidate user_freq failed");
	} else if (sqlite3_changes (db) == 1) {
		g_message ("candidate %s user_freq incremented", candidate);
	} else {
		g_warning ("UPDATE candidate user_freq failed, no match");
	}
	g_string_free (sql, TRUE);
}
/**
 * Delete candidate from userdb.not_phrase, if present.
 *
 * Logs a warning when the DELETE fails, and a message when exactly one row
 * was removed.
 */
static void
_update_userdb_not_phrase (sqlite3 *db,
			   const gchar *candidate)
{
	g_assert_nonnull (db);
	g_assert_nonnull (candidate);
	/* %Q quotes and escapes candidate for SQL */
	char *stmt_text = sqlite3_mprintf ("DELETE FROM userdb.not_phrase WHERE phrase = %Q;", candidate);
	if (sqlite3_exec_simple (db, stmt_text) == FALSE) {
		g_warning ("DELETE candidate from not_phrase failed");
	} else if (sqlite3_changes (db) == 1) {
		g_message ("candidate %s removed from not_phrase", candidate);
	}
	sqlite3_free (stmt_text);
}
/**
 * Record that the user committed candidate.
 *
 * For candidates longer than one UTF-8 character, the phrase is inserted or
 * its user_freq bumped in userdb, and it is removed from userdb.not_phrase.
 * Single characters are a logged no-op. NULL arguments are logged and
 * ignored.
 *
 * @db: sqlite3 db handler.
 * @candidate: the committed phrase.
 * @candidate_pinyin_indices: "a(ii)" pinyin pairs, one per character.
 */
void
commit_candidate (sqlite3 *db,
		  const gchar *candidate,
		  GVariant *candidate_pinyin_indices)
{
	if (db == NULL) {
		g_warning ("No db connection, can't commit candidates.");
	} else if (candidate == NULL) {
		g_warning ("candidate should not be NULL. won't commit candidate.");
	} else if (candidate_pinyin_indices == NULL) {
		g_warning ("candidate_pinyin_indices should not be NULL. won't commit candidate.");
	} else {
		const guint len = g_utf8_strlen (candidate, -1);
		if (len > 1) {
			_update_userdb_py_phrase (db, candidate, candidate_pinyin_indices, len);
			_update_userdb_not_phrase (db, candidate);
		} else {
			g_message ("commit single character %s is a no-op", candidate);
		}
	}
}