git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/api/logging.py
Merge pull request #3383 from lonvia/window-searches
[nominatim.git] / nominatim / api / logging.py
index 37ae7f5f04464241ad0e81062b56d125555cadff..30999a3f31282a085520baabf1a6b308f9d6ed7b 100644 (file)
@@ -13,6 +13,7 @@ import datetime as dt
 import textwrap
 import io
 import re
 import textwrap
 import io
 import re
+import html
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncConnection
@@ -83,33 +84,49 @@ class BaseLogger:
     def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable',
                    extra_params: Union[Mapping[str, Any],
                                  Sequence[Mapping[str, Any]], None]) -> str:
     def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable',
                    extra_params: Union[Mapping[str, Any],
                                  Sequence[Mapping[str, Any]], None]) -> str:
-        """ Return the comiled version of the statement.
+        """ Return the compiled version of the statement.
         """
         compiled = cast('sa.ClauseElement', statement).compile(conn.sync_engine)
 
         params = dict(compiled.params)
         if isinstance(extra_params, Mapping):
             for k, v in extra_params.items():
         """
         compiled = cast('sa.ClauseElement', statement).compile(conn.sync_engine)
 
         params = dict(compiled.params)
         if isinstance(extra_params, Mapping):
             for k, v in extra_params.items():
-                params[k] = str(v)
+                if hasattr(v, 'to_wkt'):
+                    params[k] = v.to_wkt()
+                elif isinstance(v, (int, float)):
+                    params[k] = v
+                else:
+                    params[k] = str(v)
         elif isinstance(extra_params, Sequence) and extra_params:
             for k in extra_params[0]:
                 params[k] = f':{k}'
 
         sqlstr = str(compiled)
 
         elif isinstance(extra_params, Sequence) and extra_params:
             for k in extra_params[0]:
                 params[k] = f':{k}'
 
         sqlstr = str(compiled)
 
-        if sa.__version__.startswith('1'):
-            try:
-                sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr)
-                return sqlstr % tuple((repr(params.get(name, None))
-                                      for name in compiled.positiontup)) # type: ignore
-            except TypeError:
-                return sqlstr
+        if conn.dialect.name == 'postgresql':
+            if sa.__version__.startswith('1'):
+                try:
+                    sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr)
+                    return sqlstr % tuple((repr(params.get(name, None))
+                                          for name in compiled.positiontup)) # type: ignore
+                except TypeError:
+                    return sqlstr
 
 
-        # Fixes an odd issue with Python 3.7 where percentages are not
-        # quoted correctly.
-        sqlstr = re.sub(r'%(?!\()', '%%', sqlstr)
-        sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', r'%(\1)s', sqlstr)
-        return sqlstr % params
+            # Fixes an odd issue with Python 3.7 where percentages are not
+            # quoted correctly.
+            sqlstr = re.sub(r'%(?!\()', '%%', sqlstr)
+            sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', r'%(\1)s', sqlstr)
+            return sqlstr % params
+
+        assert conn.dialect.name == 'sqlite'
+
+        # params in positional order
+        pparams = (repr(params.get(name, None)) for name in compiled.positiontup) # type: ignore
+
+        sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', '?', sqlstr)
+        sqlstr = re.sub(r"\?", lambda m: next(pparams), sqlstr)
+
+        return sqlstr
 
 class HTMLLogger(BaseLogger):
     """ Logger that formats messages in HTML.
 
 class HTMLLogger(BaseLogger):
     """ Logger that formats messages in HTML.
@@ -211,7 +228,7 @@ class HTMLLogger(BaseLogger):
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
             self._write(f'<div class="highlight"><code class="lang-sql">{sqlstr}</code></div>')
         else:
                                HtmlFormatter(nowrap=True, lineseparator='<br />'))
             self._write(f'<div class="highlight"><code class="lang-sql">{sqlstr}</code></div>')
         else:
-            self._write(f'<code class="lang-sql">{sqlstr}</code>')
+            self._write(f'<code class="lang-sql">{html.escape(sqlstr)}</code>')
 
 
     def _python_var(self, var: Any) -> str:
 
 
     def _python_var(self, var: Any) -> str:
@@ -219,7 +236,7 @@ class HTMLLogger(BaseLogger):
             fmt = highlight(str(var), PythonLexer(), HtmlFormatter(nowrap=True))
             return f'<div class="highlight"><code class="lang-python">{fmt}</code></div>'
 
             fmt = highlight(str(var), PythonLexer(), HtmlFormatter(nowrap=True))
             return f'<div class="highlight"><code class="lang-python">{fmt}</code></div>'
 
-        return f'<code class="lang-python">{str(var)}</code>'
+        return f'<code class="lang-python">{html.escape(str(var))}</code>'
 
 
     def _write(self, text: str) -> None:
 
 
     def _write(self, text: str) -> None: