git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/logging.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / api / logging.py
index 6bf3ed38f10230f1b1b3b3a48d4be97e4d3bcd39..5b6d0e4dbbc03573b643e88cba43a871d7bd038a 100644 (file)
@@ -7,11 +7,12 @@
 """
 Functions for specialised logging with HTML output.
 """
 """
 Functions for specialised logging with HTML output.
 """
-from typing import Any, Iterator, Optional, List, Tuple, cast
+from typing import Any, Iterator, Optional, List, Tuple, cast, Union, Mapping, Sequence
 from contextvars import ContextVar
 import datetime as dt
 import textwrap
 import io
 from contextvars import ContextVar
 import datetime as dt
 import textwrap
 import io
+import re
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
@@ -74,23 +75,41 @@ class BaseLogger:
         """
 
 
         """
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
         """ Print the SQL for the given statement.
         """
 
         """ Print the SQL for the given statement.
         """
 
-    def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> str:
+    def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+                   extra_params: Union[Mapping[str, Any],
+                                 Sequence[Mapping[str, Any]], None]) -> str:
         """ Return the compiled version of the statement.
         """
         """ Return the compiled version of the statement.
         """
-        try:
-            return str(cast('sa.ClauseElement', statement)
-                         .compile(conn.sync_engine, compile_kwargs={"literal_binds": True}))
-        except sa.exc.CompileError:
-            pass
-        except NotImplementedError:
-            pass
-
-        return str(cast('sa.ClauseElement', statement).compile(conn.sync_engine))
-
+        compiled = cast('sa.ClauseElement', statement).compile(conn.sync_engine)
+
+        params = dict(compiled.params)
+        if isinstance(extra_params, Mapping):
+            for k, v in extra_params.items():
+                params[k] = str(v)
+        elif isinstance(extra_params, Sequence) and extra_params:
+            for k in extra_params[0]:
+                params[k] = f':{k}'
+
+        sqlstr = str(compiled)
+
+        if sa.__version__.startswith('1'):
+            try:
+                sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr)
+                return sqlstr % tuple((repr(params.get(name, None))
+                                      for name in compiled.positiontup)) # type: ignore
+            except TypeError:
+                return sqlstr
+
+        # Works around an odd issue with Python 3.7 where percent signs
+        # are not escaped correctly.
+        sqlstr = re.sub(r'%(?!\()', '%%', sqlstr)
+        sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', r'%(\1)s', sqlstr)
+        return sqlstr % params
 
 class HTMLLogger(BaseLogger):
     """ Logger that formats messages in HTML.
 
 class HTMLLogger(BaseLogger):
     """ Logger that formats messages in HTML.
@@ -178,14 +197,15 @@ class HTMLLogger(BaseLogger):
             self._write(f"rank={res.rank_address}, ")
             self._write(f"osm={format_osm(res.osm_object)}, ")
             self._write(f'cc={res.country_code}, ')
             self._write(f"rank={res.rank_address}, ")
             self._write(f"osm={format_osm(res.osm_object)}, ")
             self._write(f'cc={res.country_code}, ')
-            self._write(f'importance={res.importance or -1:.5f})</dd>')
+            self._write(f'importance={res.importance or float("nan"):.5f})</dd>')
             total += 1
         self._write(f'</dl><b>TOTAL:</b> {total}</p>')
 
 
             total += 1
         self._write(f'</dl><b>TOTAL:</b> {total}</p>')
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
         self._timestamp()
         self._timestamp()
-        sqlstr = self.format_sql(conn, statement)
+        sqlstr = self.format_sql(conn, statement, params)
         if CODE_HIGHLIGHT:
             sqlstr = highlight(sqlstr, PostgresLexer(),
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
         if CODE_HIGHLIGHT:
             sqlstr = highlight(sqlstr, PostgresLexer(),
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
@@ -196,7 +216,7 @@ class HTMLLogger(BaseLogger):
 
     def _python_var(self, var: Any) -> str:
         if CODE_HIGHLIGHT:
 
     def _python_var(self, var: Any) -> str:
         if CODE_HIGHLIGHT:
-            fmt = highlight(repr(var), PythonLexer(), HtmlFormatter(nowrap=True))
+            fmt = highlight(str(var), PythonLexer(), HtmlFormatter(nowrap=True))
             return f'<div class="highlight"><code class="lang-python">{fmt}</code></div>'
 
         return f'<code class="lang-python">{str(var)}</code>'
             return f'<div class="highlight"><code class="lang-python">{fmt}</code></div>'
 
         return f'<code class="lang-python">{str(var)}</code>'
@@ -276,8 +296,9 @@ class TextLogger(BaseLogger):
         self._write(f'TOTAL: {total}\n\n')
 
 
         self._write(f'TOTAL: {total}\n\n')
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
-        sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement), width=78))
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
+        sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement, params), width=78))
         self._write(f"| {sqlstr}\n\n")
 
 
         self._write(f"| {sqlstr}\n\n")