Commit 22ab6c65 authored by Christophe Benz's avatar Christophe Benz
Browse files

Rewrite query compilation to support "IN"

parent f42b835f
......@@ -12,6 +12,7 @@ import sqlalchemy as sa
from doltpy.cli import Dolt, DoltException
from doltpy.sql import DoltSQLContext
from sqlalchemy.sql import text
from toolz import valmap
from .csv_utils import write_csv
from .typez import TableImportMode
......@@ -140,14 +141,33 @@ class DoltClient:
process.wait()
def compile_query_with_params(query: str, params: Optional[dict] = None):
def compile_query_with_params(query: str, params: Optional[dict] = None) -> str:
"""Return the query with its params substituted.
>>> compile_query_with_params("SELECT * FROM series")
'SELECT * FROM series'
>>> compile_query_with_params("SELECT * FROM series WHERE series_code = :series_code", {"series_code": "A"})
"SELECT * FROM series WHERE series_code = 'A'"
>>> compile_query_with_params("SELECT * FROM series WHERE series_code IN :series_code", {"series_code": ["A"]})
"SELECT * FROM series WHERE series_code IN ('A')"
>>> compile_query_with_params("SELECT * FROM series WHERE series_code IN :series_code", {"series_code": ["A", "B"]})
"SELECT * FROM series WHERE series_code IN ('A','B')"
"""
if params is None:
return query
statement = text(query)
statement = statement.bindparams(**params)
# Dangerous! We do this only because we don't use extenal input as parameters.
return str(statement.compile(compile_kwargs={"literal_binds": True}))
def quote(s):
return f"'{s}'"
for k, v in params.items():
if isinstance(v, (set, list)):
v = "({})".format(",".join(map(quote, v)))
elif isinstance(v, str):
v = quote(v)
assert isinstance(v, str)
query = query.replace(f":{k}", v)
return query
def iter_csv_rows(fieldnames: list[str], rows: Iterable[dict], write_header: bool = True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment