git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/logging.py
when a country is in the results, restrict further searches to places
[nominatim.git] / nominatim / api / logging.py
index a4da357d651bccb358dec2638cbf76b60bcfdbd4..e16e0bd2d3bdbcab64b7f8c074ddbbe72cc4843e 100644 (file)
@@ -7,11 +7,12 @@
 """
 Functions for specialised logging with HTML output.
 """
 """
 Functions for specialised logging with HTML output.
 """
-from typing import Any, Iterator, Optional, List, Tuple, cast
+from typing import Any, Iterator, Optional, List, Tuple, cast, Union, Mapping, Sequence
 from contextvars import ContextVar
 import datetime as dt
 import textwrap
 import io
 from contextvars import ContextVar
 import datetime as dt
 import textwrap
 import io
+import re
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
@@ -74,23 +75,57 @@ class BaseLogger:
         """
 
 
         """
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
         """ Print the SQL for the given statement.
         """
 
         """ Print the SQL for the given statement.
         """
 
-    def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> str:
+    def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+                   extra_params: Union[Mapping[str, Any],
+                                 Sequence[Mapping[str, Any]], None]) -> str:
         """ Return the compiled version of the statement.
         """
         """ Return the compiled version of the statement.
         """
-        try:
-            return str(cast('sa.ClauseElement', statement)
-                         .compile(conn.sync_engine, compile_kwargs={"literal_binds": True}))
-        except sa.exc.CompileError:
-            pass
-        except NotImplementedError:
-            pass
-
-        return str(cast('sa.ClauseElement', statement).compile(conn.sync_engine))
-
+        compiled = cast('sa.ClauseElement', statement).compile(conn.sync_engine)
+
+        params = dict(compiled.params)
+        if isinstance(extra_params, Mapping):
+            for k, v in extra_params.items():
+                if hasattr(v, 'to_wkt'):
+                    params[k] = v.to_wkt()
+                elif isinstance(v, (int, float)):
+                    params[k] = v
+                else:
+                    params[k] = str(v)
+        elif isinstance(extra_params, Sequence) and extra_params:
+            for k in extra_params[0]:
+                params[k] = f':{k}'
+
+        sqlstr = str(compiled)
+
+        if conn.dialect.name == 'postgresql':
+            if sa.__version__.startswith('1'):
+                try:
+                    sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr)
+                    return sqlstr % tuple((repr(params.get(name, None))
+                                          for name in compiled.positiontup)) # type: ignore
+                except TypeError:
+                    return sqlstr
+
+            # Fixes an odd issue with Python 3.7 where percent signs are not
+            # escaped correctly.
+            sqlstr = re.sub(r'%(?!\()', '%%', sqlstr)
+            sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', r'%(\1)s', sqlstr)
+            return sqlstr % params
+
+        assert conn.dialect.name == 'sqlite'
+
+        # params in positional order
+        pparams = (repr(params.get(name, None)) for name in compiled.positiontup) # type: ignore
+
+        sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', '?', sqlstr)
+        sqlstr = re.sub(r"\?", lambda m: next(pparams), sqlstr)
+
+        return sqlstr
 
 class HTMLLogger(BaseLogger):
     """ Logger that formats messages in HTML.
 
 class HTMLLogger(BaseLogger):
     """ Logger that formats messages in HTML.
@@ -178,14 +213,15 @@ class HTMLLogger(BaseLogger):
             self._write(f"rank={res.rank_address}, ")
             self._write(f"osm={format_osm(res.osm_object)}, ")
             self._write(f'cc={res.country_code}, ')
             self._write(f"rank={res.rank_address}, ")
             self._write(f"osm={format_osm(res.osm_object)}, ")
             self._write(f'cc={res.country_code}, ')
-            self._write(f'importance={res.importance or -1:.5f})</dd>')
+            self._write(f'importance={res.importance or float("nan"):.5f})</dd>')
             total += 1
         self._write(f'</dl><b>TOTAL:</b> {total}</p>')
 
 
             total += 1
         self._write(f'</dl><b>TOTAL:</b> {total}</p>')
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
         self._timestamp()
         self._timestamp()
-        sqlstr = self.format_sql(conn, statement)
+        sqlstr = self.format_sql(conn, statement, params)
         if CODE_HIGHLIGHT:
             sqlstr = highlight(sqlstr, PostgresLexer(),
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
         if CODE_HIGHLIGHT:
             sqlstr = highlight(sqlstr, PostgresLexer(),
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
@@ -215,6 +251,10 @@ class TextLogger(BaseLogger):
         self.buffer = io.StringIO()
 
 
         self.buffer = io.StringIO()
 
 
+    def _timestamp(self) -> None:
+        self._write(f'[{dt.datetime.now()}]\n')
+
+
     def get_buffer(self) -> str:
         return self.buffer.getvalue()
 
     def get_buffer(self) -> str:
         return self.buffer.getvalue()
 
@@ -227,6 +267,7 @@ class TextLogger(BaseLogger):
 
 
     def section(self, heading: str) -> None:
 
 
     def section(self, heading: str) -> None:
+        self._timestamp()
         self._write(f"\n# {heading}\n\n")
 
 
         self._write(f"\n# {heading}\n\n")
 
 
@@ -263,6 +304,7 @@ class TextLogger(BaseLogger):
 
 
     def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
 
 
     def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
+        self._timestamp()
         self._write(f'{heading}:\n')
         total = 0
         for rank, res in results:
         self._write(f'{heading}:\n')
         total = 0
         for rank, res in results:
@@ -276,8 +318,10 @@ class TextLogger(BaseLogger):
         self._write(f'TOTAL: {total}\n\n')
 
 
         self._write(f'TOTAL: {total}\n\n')
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
-        sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement), width=78))
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
+        self._timestamp()
+        sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement, params), width=78))
         self._write(f"| {sqlstr}\n\n")
 
 
         self._write(f"| {sqlstr}\n\n")