]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/search/query.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / api / search / query.py
index bc1f542d10148aeb79f3ee48240d5e4f7040dbc0..ad1b69ef521dfd303ae1d5b95aa8629859feb10b 100644 (file)
@@ -7,7 +7,7 @@
 """
 Datastructures for a tokenized query.
 """
 """
 Datastructures for a tokenized query.
 """
-from typing import List, Tuple, Optional, NamedTuple
+from typing import List, Tuple, Optional, Iterator
 from abc import ABC, abstractmethod
 import dataclasses
 import enum
 from abc import ABC, abstractmethod
 import dataclasses
 import enum
@@ -46,7 +46,7 @@ class TokenType(enum.Enum):
     """ Country name or reference. """
     QUALIFIER = enum.auto()
     """ Special term used together with name (e.g. _Hotel_ Bellevue). """
     """ Country name or reference. """
     QUALIFIER = enum.auto()
     """ Special term used together with name (e.g. _Hotel_ Bellevue). """
-    CATEGORY = enum.auto()
+    NEAR_ITEM = enum.auto()
     """ Special term used as searchable object(e.g. supermarket in ...). """
 
 
     """ Special term used as searchable object(e.g. supermarket in ...). """
 
 
@@ -70,14 +70,16 @@ class PhraseType(enum.Enum):
     COUNTRY = enum.auto()
     """ Contains the country name or code. """
 
     COUNTRY = enum.auto()
     """ Contains the country name or code. """
 
-    def compatible_with(self, ttype: TokenType) -> bool:
+    def compatible_with(self, ttype: TokenType,
+                        is_full_phrase: bool) -> bool:
         """ Check if the given token type can be used with the phrase type.
         """
         if self == PhraseType.NONE:
         """ Check if the given token type can be used with the phrase type.
         """
         if self == PhraseType.NONE:
-            return True
+            return not is_full_phrase or ttype != TokenType.QUALIFIER
         if self == PhraseType.AMENITY:
         if self == PhraseType.AMENITY:
-            return ttype in (TokenType.WORD, TokenType.PARTIAL,
-                             TokenType.QUALIFIER, TokenType.CATEGORY)
+            return ttype in (TokenType.WORD, TokenType.PARTIAL)\
+                   or (is_full_phrase and ttype == TokenType.NEAR_ITEM)\
+                   or (not is_full_phrase and ttype == TokenType.QUALIFIER)
         if self == PhraseType.STREET:
             return ttype in (TokenType.WORD, TokenType.PARTIAL, TokenType.HOUSENUMBER)
         if self == PhraseType.POSTCODE:
         if self == PhraseType.STREET:
             return ttype in (TokenType.WORD, TokenType.PARTIAL, TokenType.HOUSENUMBER)
         if self == PhraseType.POSTCODE:
@@ -107,13 +109,47 @@ class Token(ABC):
             category objects.
         """
 
             category objects.
         """
 
-
-class TokenRange(NamedTuple):
+@dataclasses.dataclass
+class TokenRange:
     """ Indexes of query nodes over which a token spans.
     """
     start: int
     end: int
 
     """ Indexes of query nodes over which a token spans.
     """
     start: int
     end: int
 
+    def __lt__(self, other: 'TokenRange') -> bool:
+        return self.end <= other.start
+
+
+    def __le__(self, other: 'TokenRange') -> bool:
+        return NotImplemented
+
+
+    def __gt__(self, other: 'TokenRange') -> bool:
+        return self.start >= other.end
+
+
+    def __ge__(self, other: 'TokenRange') -> bool:
+        return NotImplemented
+
+
+    def replace_start(self, new_start: int) -> 'TokenRange':
+        """ Return a new token range with the new start.
+        """
+        return TokenRange(new_start, self.end)
+
+
+    def replace_end(self, new_end: int) -> 'TokenRange':
+        """ Return a new token range with the new end.
+        """
+        return TokenRange(self.start, new_end)
+
+
+    def split(self, index: int) -> Tuple['TokenRange', 'TokenRange']:
+        """ Split the span into two spans at the given index.
+            The index must be within the span.
+        """
+        return self.replace_end(index), self.replace_start(index)
+
 
 @dataclasses.dataclass
 class TokenList:
 
 @dataclasses.dataclass
 class TokenList:
@@ -124,6 +160,13 @@ class TokenList:
     tokens: List[Token]
 
 
     tokens: List[Token]
 
 
+    def add_penalty(self, penalty: float) -> None:
+        """ Add the given penalty to all tokens in the list.
+        """
+        for token in self.tokens:
+            token.penalty += penalty
+
+
 @dataclasses.dataclass
 class QueryNode:
     """ A node of the querry representing a break between terms.
 @dataclasses.dataclass
 class QueryNode:
     """ A node of the querry representing a break between terms.
@@ -144,7 +187,10 @@ class QueryNode:
             and ending at the node 'end'. Returns 'None' if no such
             tokens exist.
         """
             and ending at the node 'end'. Returns 'None' if no such
             tokens exist.
         """
-        return next((t.tokens for t in self.starting if t.end == end and t.ttype == ttype), None)
+        for tlist in self.starting:
+            if tlist.end == end and tlist.ttype == ttype:
+                return tlist.tokens
+        return None
 
 
 @dataclasses.dataclass
 
 
 @dataclasses.dataclass
@@ -200,7 +246,9 @@ class QueryStruct:
             be added to, then the token is silently dropped.
         """
         snode = self.nodes[trange.start]
             be added to, then the token is silently dropped.
         """
         snode = self.nodes[trange.start]
-        if snode.ptype.compatible_with(ttype):
+        full_phrase = snode.btype in (BreakType.START, BreakType.PHRASE)\
+                      and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)
+        if snode.ptype.compatible_with(ttype, full_phrase):
             tlist = snode.get_tokens(trange.end, ttype)
             if tlist is None:
                 snode.starting.append(TokenList(trange.end, ttype, [token]))
             tlist = snode.get_tokens(trange.end, ttype)
             if tlist is None:
                 snode.starting.append(TokenList(trange.end, ttype, [token]))
@@ -226,6 +274,14 @@ class QueryStruct:
                           for i in range(trange.start, trange.end)]
 
 
                           for i in range(trange.start, trange.end)]
 
 
+    def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
+        """ Iterator over all token lists in the query.
+        """
+        for i, node in enumerate(self.nodes):
+            for tlist in node.starting:
+                yield i, node, tlist
+
+
     def find_lookup_word_by_id(self, token: int) -> str:
         """ Find the first token with the given token ID and return
             its lookup word. Returns 'None' if no such token exists.
     def find_lookup_word_by_id(self, token: int) -> str:
         """ Find the first token with the given token ID and return
             its lookup word. Returns 'None' if no such token exists.