git.openstreetmap.org Git - nominatim.git/commitdiff
Merge pull request #3692 from lonvia/word-lookup-variants
author Sarah Hoffmann <lonvia@denofr.de>
Mon, 31 Mar 2025 14:38:31 +0000 (16:38 +0200)
committer GitHub <noreply@github.com>
Mon, 31 Mar 2025 14:38:31 +0000 (16:38 +0200)
Avoid matching penalty for abbreviated search terms

lib-sql/tokenizer/icu_tokenizer.sql
src/nominatim_db/tokenizer/icu_tokenizer.py
src/nominatim_db/tokenizer/token_analysis/base.py
src/nominatim_db/tokenizer/token_analysis/generic.py
test/python/tokenizer/test_icu.py
test/python/tokenizer/token_analysis/test_generic.py
test/python/tokenizer/token_analysis/test_generic_mutation.py

diff --git a/lib-sql/tokenizer/icu_tokenizer.sql b/lib-sql/tokenizer/icu_tokenizer.sql
index f0c30f1b786471d210cd96b28fd49fd149487ea2..8cf13120a11515db4fbe432c131332fdea054a45 100644
@@ -128,16 +128,14 @@ DECLARE
   partial_terms TEXT[] = '{}'::TEXT[];
   term TEXT;
   term_id INTEGER;
-  term_count INTEGER;
 BEGIN
   SELECT min(word_id) INTO full_token
     FROM word WHERE word = norm_term and type = 'W';
 
   IF full_token IS NULL THEN
     full_token := nextval('seq_word');
-    INSERT INTO word (word_id, word_token, type, word, info)
-      SELECT full_token, lookup_term, 'W', norm_term,
-             json_build_object('count', 0)
+    INSERT INTO word (word_id, word_token, type, word)
+      SELECT full_token, lookup_term, 'W', norm_term
         FROM unnest(lookup_terms) as lookup_term;
   END IF;
 
@@ -150,14 +148,67 @@ BEGIN
 
   partial_tokens := '{}'::INT[];
   FOR term IN SELECT unnest(partial_terms) LOOP
-    SELECT min(word_id), max(info->>'count') INTO term_id, term_count
+    SELECT min(word_id) INTO term_id
+      FROM word WHERE word_token = term and type = 'w';
+
+    IF term_id IS NULL THEN
+      term_id := nextval('seq_word');
+      INSERT INTO word (word_id, word_token, type)
+        VALUES (term_id, term, 'w');
+    END IF;
+
+    partial_tokens := array_merge(partial_tokens, ARRAY[term_id]);
+  END LOOP;
+END;
+$$
+LANGUAGE plpgsql;
+
+
+CREATE OR REPLACE FUNCTION getorcreate_full_word(norm_term TEXT,
+                                                 lookup_terms TEXT[],
+                                                 lookup_norm_terms TEXT[],
+                                                 OUT full_token INT,
+                                                 OUT partial_tokens INT[])
+  AS $$
+DECLARE
+  partial_terms TEXT[] = '{}'::TEXT[];
+  term TEXT;
+  term_id INTEGER;
+BEGIN
+  SELECT min(word_id) INTO full_token
+    FROM word WHERE word = norm_term and type = 'W';
+
+  IF full_token IS NULL THEN
+    full_token := nextval('seq_word');
+    IF lookup_norm_terms IS NULL THEN
+      INSERT INTO word (word_id, word_token, type, word)
+        SELECT full_token, lookup_term, 'W', norm_term
+          FROM unnest(lookup_terms) as lookup_term;
+    ELSE
+      INSERT INTO word (word_id, word_token, type, word, info)
+        SELECT full_token, t.lookup, 'W', norm_term,
+               CASE WHEN norm_term = t.norm THEN null
+               ELSE json_build_object('lookup', t.norm) END
+          FROM unnest(lookup_terms, lookup_norm_terms) as t(lookup, norm);
+    END IF;
+  END IF;
+
+  FOR term IN SELECT unnest(string_to_array(unnest(lookup_terms), ' ')) LOOP
+    term := trim(term);
+    IF NOT (ARRAY[term] <@ partial_terms) THEN
+      partial_terms := partial_terms || term;
+    END IF;
+  END LOOP;
+
+  partial_tokens := '{}'::INT[];
+  FOR term IN SELECT unnest(partial_terms) LOOP
+    SELECT min(word_id) INTO term_id
       FROM word WHERE word_token = term and type = 'w';
 
     IF term_id IS NULL THEN
       term_id := nextval('seq_word');
-      term_count := 0;
-      INSERT INTO word (word_id, word_token, type, info)
-        VALUES (term_id, term, 'w', json_build_object('count', term_count));
+      INSERT INTO word (word_id, word_token, type)
+        VALUES (term_id, term, 'w');
     END IF;
 
     partial_tokens := array_merge(partial_tokens, ARRAY[term_id]);
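
The second hunk keeps the original two-argument getorcreate_full_word and adds a three-argument overload that additionally takes lookup_norm_terms; for each variant whose normalized spelling differs from the canonical norm_term, the word row is stored with info = {'lookup': <normalized variant>}. A minimal sketch of how the new overload might be called from Python (hypothetical helper and data, assuming any DB-API/psycopg-style connection to a Nominatim database):

    from typing import List, Tuple

    def add_full_word(conn, norm_term: str,
                      lookup_terms: List[str],
                      lookup_norm_terms: List[str]) -> Tuple[int, List[int]]:
        # The two arrays are index-aligned: lookup_terms holds the transliterated
        # lookup tokens, lookup_norm_terms the normalized variants behind them.
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM getorcreate_full_word(%s, %s, %s)",
                        (norm_term, lookup_terms, lookup_norm_terms))
            return cur.fetchone()

    # Hypothetical usage: 'great st' differs from the canonical 'great street',
    # so its word row would get info = {'lookup': 'great st'}.
    # full_id, partials = add_full_word(conn, 'great street',
    #                                   ['great street', 'great st'],
    #                                   ['great street', 'great st'])
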
diff --git a/src/nominatim_db/tokenizer/icu_tokenizer.py b/src/nominatim_db/tokenizer/icu_tokenizer.py
index 3da1171f5aed0fee93c12610058e8b1a8edfa143..297c9ef9a401e02413d8c223f9633c0d4ad36e33 100644
@@ -121,10 +121,10 @@ class ICUTokenizer(AbstractTokenizer):
                            SELECT unnest(nameaddress_vector) as id, count(*)
                                  FROM search_name GROUP BY id)
                   SELECT coalesce(a.id, w.id) as id,
-                         (CASE WHEN w.count is null THEN '{}'::JSONB
+                         (CASE WHEN w.count is null or w.count <= 1 THEN '{}'::JSONB
                               ELSE jsonb_build_object('count', w.count) END
                           ||
-                          CASE WHEN a.count is null THEN '{}'::JSONB
+                          CASE WHEN a.count is null or a.count <= 1 THEN '{}'::JSONB
                               ELSE jsonb_build_object('addr_count', a.count) END) as info
                   FROM word_freq w FULL JOIN addr_freq a ON a.id = w.id;
                   """)
@@ -134,9 +134,10 @@ class ICUTokenizer(AbstractTokenizer):
                 drop_tables(conn, 'tmp_word')
                 cur.execute("""CREATE TABLE tmp_word AS
                                 SELECT word_id, word_token, type, word,
-                                       (CASE WHEN wf.info is null THEN word.info
-                                        ELSE coalesce(word.info, '{}'::jsonb) || wf.info
-                                        END) as info
+                                       coalesce(word.info, '{}'::jsonb)
+                                       - 'count' - 'addr_count' ||
+                                       coalesce(wf.info, '{}'::jsonb)
+                                       as info
                                 FROM word LEFT JOIN word_frequencies wf
                                      ON word.word_id = wf.id
                             """)
@@ -584,10 +585,14 @@ class ICUNameAnalyzer(AbstractAnalyzer):
             if word_id:
                 result = self._cache.housenumbers.get(word_id, result)
                 if result[0] is None:
-                    variants = analyzer.compute_variants(word_id)
+                    varout = analyzer.compute_variants(word_id)
+                    if isinstance(varout, tuple):
+                        variants = varout[0]
+                    else:
+                        variants = varout
                     if variants:
                         hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)",
-                                             (word_id, list(variants)))
+                                             (word_id, variants))
                         result = hid, variants[0]
                         self._cache.housenumbers[word_id] = result
 
@@ -632,13 +637,17 @@ class ICUNameAnalyzer(AbstractAnalyzer):
 
             full, part = self._cache.names.get(token_id, (None, None))
             if full is None:
-                variants = analyzer.compute_variants(word_id)
+                varset = analyzer.compute_variants(word_id)
+                if isinstance(varset, tuple):
+                    variants, lookups = varset
+                else:
+                    variants, lookups = varset, None
                 if not variants:
                     continue
 
                 with self.conn.cursor() as cur:
-                    cur.execute("SELECT * FROM getorcreate_full_word(%s, %s)",
-                                (token_id, variants))
+                    cur.execute("SELECT * FROM getorcreate_full_word(%s, %s, %s)",
+                                (token_id, variants, lookups))
                     full, part = cast(Tuple[int, List[int]], cur.fetchone())
 
                 self._cache.names[token_id] = (full, part)
diff --git a/src/nominatim_db/tokenizer/token_analysis/base.py b/src/nominatim_db/tokenizer/token_analysis/base.py
index 52ee801343fb6c29d2425dd356c0b9df495a745b..186f1d3ebc1113575a41929b605cfc32d8f5ecb2 100644
@@ -7,7 +7,7 @@
 """
 Common data types and protocols for analysers.
 """
-from typing import Mapping, List, Any
+from typing import Mapping, List, Any, Union, Tuple
 
 from ...typing import Protocol
 from ...data.place_name import PlaceName
@@ -33,7 +33,7 @@ class Analyzer(Protocol):
                     for example because the character set in use does not match.
         """
 
-    def compute_variants(self, canonical_id: str) -> List[str]:
+    def compute_variants(self, canonical_id: str) -> Union[List[str], Tuple[List[str], List[str]]]:
         """ Compute the transliterated spelling variants for the given
             canonical ID.
 
diff --git a/src/nominatim_db/tokenizer/token_analysis/generic.py b/src/nominatim_db/tokenizer/token_analysis/generic.py
index fa9dc4dfa54c66e6f25a408129ba327d228b38b0..b01cebf75e78081fbd74a91966cfe3a3876da36a 100644
@@ -7,7 +7,7 @@
 """
 Generic processor for names that creates abbreviation variants.
 """
-from typing import Mapping, Dict, Any, Iterable, Iterator, Optional, List, cast
+from typing import Mapping, Dict, Any, Iterable, Optional, List, cast, Tuple
 import itertools
 
 from ...errors import UsageError
@@ -78,7 +78,7 @@ class GenericTokenAnalysis:
         """
         return cast(str, self.norm.transliterate(name.name)).strip()
 
-    def compute_variants(self, norm_name: str) -> List[str]:
+    def compute_variants(self, norm_name: str) -> Tuple[List[str], List[str]]:
         """ Compute the spelling variants for the given normalized name
             and transliterate the result.
         """
@@ -87,18 +87,20 @@ class GenericTokenAnalysis:
         for mutation in self.mutations:
             variants = mutation.generate(variants)
 
-        return [name for name in self._transliterate_unique_list(norm_name, variants) if name]
-
-    def _transliterate_unique_list(self, norm_name: str,
-                                   iterable: Iterable[str]) -> Iterator[Optional[str]]:
-        seen = set()
+        varset = set(map(str.strip, variants))
         if self.variant_only:
-            seen.add(norm_name)
+            varset.discard(norm_name)
+
+        trans = []
+        norm = []
+
+        for var in varset:
+            t = self.to_ascii.transliterate(var).strip()
+            if t:
+                trans.append(t)
+                norm.append(var)
 
-        for variant in map(str.strip, iterable):
-            if variant not in seen:
-                seen.add(variant)
-                yield self.to_ascii.transliterate(variant).strip()
+        return trans, norm
 
     def _generate_word_variants(self, norm_name: str) -> Iterable[str]:
         baseform = '^ ' + norm_name + ' ^'
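
In the generic module, _transliterate_unique_list is folded into compute_variants, which now returns the transliterated lookup terms together with the normalized variants they were derived from, index-aligned and with empty transliterations dropped. Callers that only need the lookup terms take element [0], as the updated tests below do. A stand-alone illustration of the pairing behaviour (with a simplified stand-in for the ICU transliteration done by self.to_ascii):

    def to_ascii(s: str) -> str:
        return s.replace('ß', 'ss')

    variants = {'großer weg', 'großer w'}      # after mutations, deduplicated
    trans, norm = [], []
    for var in variants:
        t = to_ascii(var).strip()
        if t:                                  # variants that transliterate to nothing are dropped
            trans.append(t)
            norm.append(var)
    # trans and norm stay index-aligned, e.g. 'grosser w' pairs with 'großer w'.
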
diff --git a/test/python/tokenizer/test_icu.py b/test/python/tokenizer/test_icu.py
index ce00281cff7e09e850fb4d5ee957fb687f65d809..12cef894f863cb2336b1be26240e28fb4a8ba28e 100644
@@ -230,19 +230,20 @@ def test_update_statistics(word_table, table_factory, temp_db_cursor,
                            tokenizer_factory, test_config):
     word_table.add_full_word(1000, 'hello')
     word_table.add_full_word(1001, 'bye')
+    word_table.add_full_word(1002, 'town')
     table_factory('search_name',
                   'place_id BIGINT, name_vector INT[], nameaddress_vector INT[]',
-                  [(12, [1000], [1001])])
+                  [(12, [1000], [1001]), (13, [1001], [1002]), (14, [1000, 1001], [1002])])
     tok = tokenizer_factory()
 
     tok.update_statistics(test_config)
 
-    assert temp_db_cursor.scalar("""SELECT count(*) FROM word
-                                    WHERE type = 'W' and word_id = 1000 and
-                                          (info->>'count')::int > 0""") == 1
-    assert temp_db_cursor.scalar("""SELECT count(*) FROM word
-                                    WHERE type = 'W' and word_id = 1001 and
-                                          (info->>'addr_count')::int > 0""") == 1
+    assert temp_db_cursor.row_set("""SELECT word_id,
+                                            (info->>'count')::int,
+                                            (info->>'addr_count')::int
+                                     FROM word
+                                     WHERE type = 'W'""") == \
+        {(1000, 2, None), (1001, 2, None), (1002, None, 2)}
 
 
 def test_normalize_postcode(analyzer):
diff --git a/test/python/tokenizer/token_analysis/test_generic.py b/test/python/tokenizer/token_analysis/test_generic.py
index 02870f2445e7b798a7bc29ffa5aa794d2dd1a9a1..48f2483bc8e54cfe2a6ceed3d89a2efe14c5561e 100644
@@ -40,7 +40,7 @@ def make_analyser(*variants, variant_only=False):
 
 def get_normalized_variants(proc, name):
     norm = Transliterator.createFromRules("test_norm", DEFAULT_NORMALIZATION)
-    return proc.compute_variants(norm.transliterate(name).strip())
+    return proc.compute_variants(norm.transliterate(name).strip())[0]
 
 
 def test_no_variants():
diff --git a/test/python/tokenizer/token_analysis/test_generic_mutation.py b/test/python/tokenizer/token_analysis/test_generic_mutation.py
index 2ce2236a3c490762c965c42a75cdcc4344661c54..e0507e4c30ace40f5d4ead0c0889b10cb8eddb6b 100644
@@ -40,7 +40,7 @@ class TestMutationNoVariants:
 
     def variants(self, name):
         norm = Transliterator.createFromRules("test_norm", DEFAULT_NORMALIZATION)
-        return set(self.analysis.compute_variants(norm.transliterate(name).strip()))
+        return set(self.analysis.compute_variants(norm.transliterate(name).strip())[0])
 
     @pytest.mark.parametrize('pattern', ('(capture)', ['a list']))
     def test_bad_pattern(self, pattern):