]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/logging.py
band-aid for SQLAlchemy 1.4
[nominatim.git] / nominatim / api / logging.py
index 6bf3ed38f10230f1b1b3b3a48d4be97e4d3bcd39..202d7de9a5baf0abca6ed1807bfb488cb136de05 100644 (file)
@@ -7,7 +7,7 @@
 """
 Functions for specialised logging with HTML output.
 """
 """
 Functions for specialised logging with HTML output.
 """
-from typing import Any, Iterator, Optional, List, Tuple, cast
+from typing import Any, Iterator, Optional, List, Tuple, cast, Union, Mapping, Sequence
 from contextvars import ContextVar
 import datetime as dt
 import textwrap
 from contextvars import ContextVar
 import datetime as dt
 import textwrap
@@ -74,22 +74,34 @@ class BaseLogger:
         """
 
 
         """
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
         """ Print the SQL for the given statement.
         """
 
         """ Print the SQL for the given statement.
         """
 
-    def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> str:
+    def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+                   extra_params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> str:
         """ Return the comiled version of the statement.
         """
         """ Return the comiled version of the statement.
         """
-        try:
-            return str(cast('sa.ClauseElement', statement)
-                         .compile(conn.sync_engine, compile_kwargs={"literal_binds": True}))
-        except sa.exc.CompileError:
-            pass
-        except NotImplementedError:
-            pass
+        compiled = cast('sa.ClauseElement', statement).compile(conn.sync_engine)
 
 
-        return str(cast('sa.ClauseElement', statement).compile(conn.sync_engine))
+        params = dict(compiled.params)
+        if isinstance(extra_params, Mapping):
+            for k, v in extra_params.items():
+                params[k] = str(v)
+        elif isinstance(extra_params, Sequence) and extra_params:
+            for k in extra_params[0]:
+                params[k] = f':{k}'
+
+        sqlstr = str(compiled)
+
+        if '%s' in sqlstr:
+            try:
+                return sqlstr % tuple((repr(compiled.params[name]) for name in compiled.positiontup))
+            except TypeError:
+                return sqlstr
+
+        return str(compiled) % params
 
 
 class HTMLLogger(BaseLogger):
 
 
 class HTMLLogger(BaseLogger):
@@ -178,14 +190,15 @@ class HTMLLogger(BaseLogger):
             self._write(f"rank={res.rank_address}, ")
             self._write(f"osm={format_osm(res.osm_object)}, ")
             self._write(f'cc={res.country_code}, ')
             self._write(f"rank={res.rank_address}, ")
             self._write(f"osm={format_osm(res.osm_object)}, ")
             self._write(f'cc={res.country_code}, ')
-            self._write(f'importance={res.importance or -1:.5f})</dd>')
+            self._write(f'importance={res.importance or float("nan"):.5f})</dd>')
             total += 1
         self._write(f'</dl><b>TOTAL:</b> {total}</p>')
 
 
             total += 1
         self._write(f'</dl><b>TOTAL:</b> {total}</p>')
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
         self._timestamp()
         self._timestamp()
-        sqlstr = self.format_sql(conn, statement)
+        sqlstr = self.format_sql(conn, statement, params)
         if CODE_HIGHLIGHT:
             sqlstr = highlight(sqlstr, PostgresLexer(),
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
         if CODE_HIGHLIGHT:
             sqlstr = highlight(sqlstr, PostgresLexer(),
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
@@ -196,7 +209,7 @@ class HTMLLogger(BaseLogger):
 
     def _python_var(self, var: Any) -> str:
         if CODE_HIGHLIGHT:
 
     def _python_var(self, var: Any) -> str:
         if CODE_HIGHLIGHT:
-            fmt = highlight(repr(var), PythonLexer(), HtmlFormatter(nowrap=True))
+            fmt = highlight(str(var), PythonLexer(), HtmlFormatter(nowrap=True))
             return f'<div class="highlight"><code class="lang-python">{fmt}</code></div>'
 
         return f'<code class="lang-python">{str(var)}</code>'
             return f'<div class="highlight"><code class="lang-python">{fmt}</code></div>'
 
         return f'<code class="lang-python">{str(var)}</code>'
@@ -276,8 +289,9 @@ class TextLogger(BaseLogger):
         self._write(f'TOTAL: {total}\n\n')
 
 
         self._write(f'TOTAL: {total}\n\n')
 
 
-    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
-        sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement), width=78))
+    def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
+            params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
+        sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement, params), width=78))
         self._write(f"| {sqlstr}\n\n")
 
 
         self._write(f"| {sqlstr}\n\n")