]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/logging.py
simplify handling of SQL lookup code for search_name
[nominatim.git] / nominatim / api / logging.py
index 6c8b1b388224f8a787977b15e7c2e2fce5aaab03..e16e0bd2d3bdbcab64b7f8c074ddbbe72cc4843e 100644 (file)
@@ -7,11 +7,12 @@
 """
 Functions for specialised logging with HTML output.
 """
 """
 Functions for specialised logging with HTML output.
 """
-from typing import Any, Iterator, Optional, List, Tuple, cast
+from typing import Any, Iterator, Optional, List, Tuple, cast, Union, Mapping, Sequence
 from contextvars import ContextVar
 import datetime as dt
 import textwrap
 import io
 from contextvars import ContextVar
 import datetime as dt
 import textwrap
 import io
+import re
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
@@ -74,23 +75,57 @@ class BaseLogger:
         """
 
 
         """
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
         """ Print the SQL for the given statement.
         """
 
         """ Print the SQL for the given statement.
         """
 
-    def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> str:
+    def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+                   extra_params: Union[Mapping[str, Any],
+                                 Sequence[Mapping[str, Any]], None]) -> str:
         """ Return the compiled version of the statement.
         """
         """ Return the compiled version of the statement.
         """
-        try:
-            return str(cast('sa.ClauseElement', statement)
-                         .compile(conn.sync_engine, compile_kwargs={"literal_binds": True}))
-        except sa.exc.CompileError:
-            pass
-        except NotImplementedError:
-            pass
-
-        return str(cast('sa.ClauseElement', statement).compile(conn.sync_engine))
-
+        compiled = cast('sa.ClauseElement', statement).compile(conn.sync_engine)
+
+        params = dict(compiled.params)
+        if isinstance(extra_params, Mapping):
+            for k, v in extra_params.items():
+                if hasattr(v, 'to_wkt'):
+                    params[k] = v.to_wkt()
+                elif isinstance(v, (int, float)):
+                    params[k] = v
+                else:
+                    params[k] = str(v)
+        elif isinstance(extra_params, Sequence) and extra_params:
+            for k in extra_params[0]:
+                params[k] = f':{k}'
+
+        sqlstr = str(compiled)
+
+        if conn.dialect.name == 'postgresql':
+            if sa.__version__.startswith('1'):
+                try:
+                    sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr)
+                    return sqlstr % tuple((repr(params.get(name, None))
+                                          for name in compiled.positiontup)) # type: ignore
+                except TypeError:
+                    return sqlstr
+
+            # Fixes an odd issue with Python 3.7 where percentages are not
+            # quoted correctly.
+            sqlstr = re.sub(r'%(?!\()', '%%', sqlstr)
+            sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', r'%(\1)s', sqlstr)
+            return sqlstr % params
+
+        assert conn.dialect.name == 'sqlite'
+
+        # params in positional order
+        pparams = (repr(params.get(name, None)) for name in compiled.positiontup) # type: ignore
+
+        sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', '?', sqlstr)
+        sqlstr = re.sub(r"\?", lambda m: next(pparams), sqlstr)
+
+        return sqlstr
 
 class HTMLLogger(BaseLogger):
     """ Logger that formats messages in HTML.
 
 class HTMLLogger(BaseLogger):
     """ Logger that formats messages in HTML.
@@ -183,9 +218,10 @@ class HTMLLogger(BaseLogger):
         self._write(f'</dl><b>TOTAL:</b> {total}</p>')
 
 
         self._write(f'</dl><b>TOTAL:</b> {total}</p>')
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
         self._timestamp()
         self._timestamp()
-        sqlstr = self.format_sql(conn, statement)
+        sqlstr = self.format_sql(conn, statement, params)
         if CODE_HIGHLIGHT:
             sqlstr = highlight(sqlstr, PostgresLexer(),
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
         if CODE_HIGHLIGHT:
             sqlstr = highlight(sqlstr, PostgresLexer(),
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
@@ -215,6 +251,10 @@ class TextLogger(BaseLogger):
         self.buffer = io.StringIO()
 
 
         self.buffer = io.StringIO()
 
 
+    def _timestamp(self) -> None:
+        self._write(f'[{dt.datetime.now()}]\n')
+
+
     def get_buffer(self) -> str:
         return self.buffer.getvalue()
 
     def get_buffer(self) -> str:
         return self.buffer.getvalue()
 
@@ -227,6 +267,7 @@ class TextLogger(BaseLogger):
 
 
     def section(self, heading: str) -> None:
 
 
     def section(self, heading: str) -> None:
+        self._timestamp()
         self._write(f"\n# {heading}\n\n")
 
 
         self._write(f"\n# {heading}\n\n")
 
 
@@ -263,6 +304,7 @@ class TextLogger(BaseLogger):
 
 
     def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
 
 
     def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
+        self._timestamp()
         self._write(f'{heading}:\n')
         total = 0
         for rank, res in results:
         self._write(f'{heading}:\n')
         total = 0
         for rank, res in results:
@@ -276,8 +318,10 @@ class TextLogger(BaseLogger):
         self._write(f'TOTAL: {total}\n\n')
 
 
         self._write(f'TOTAL: {total}\n\n')
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
-        sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement), width=78))
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
+        self._timestamp()
+        sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement, params), width=78))
         self._write(f"| {sqlstr}\n\n")
 
 
         self._write(f"| {sqlstr}\n\n")