Typed SQL Client & Query DSL in Python
The current SQL database clients in the Python ecosystem either lack static types completely or use Any in their return types.
In this post I’m going to explore an API for a statically typed db client.
Plain SQL w/ Shape Argument
Using strings with our SQL client is the most basic and straightforward approach, but composing SQL strings is tricky.
result = db.sql(
    "SELECT id, name FROM table_name;", shape=List[Tuple[int, str]]
)
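One way a client could honor the shape argument at runtime is to check every returned row against it. The following is a minimal sketch, not any existing library's API: check_shape is a hypothetical helper, and it assumes the shape is always a List[Tuple[...]] of plain Python types.

```python
from typing import Any, List, Tuple, get_args, get_origin


def check_shape(rows: Any, shape: Any) -> Any:
    """Hypothetical helper: validate rows against a List[Tuple[...]] shape."""
    if get_origin(shape) is not list:
        raise TypeError("this sketch only supports List[Tuple[...]] shapes")
    (row_type,) = get_args(shape)
    if get_origin(row_type) is not tuple:
        raise TypeError("this sketch only supports List[Tuple[...]] shapes")
    column_types = get_args(row_type)

    checked = []
    for row in rows:
        if len(row) != len(column_types):
            raise TypeError(
                f"expected {len(column_types)} columns, got {len(row)}"
            )
        for value, expected in zip(row, column_types):
            if not isinstance(value, expected):
                raise TypeError(f"expected {expected.__name__}, got {value!r}")
        checked.append(tuple(row))
    return checked


# Rows as a driver might return them for "SELECT id, name FROM table_name;"
rows = check_shape([(1, "foo"), (2, "bar")], List[Tuple[int, str]])
```

A check like this only catches mismatches when the query runs; the static type of the result still comes entirely from the shape the caller wrote down.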
Query Builder w/ Shape Argument
Related: SQLAlchemy Core, deuterium, ObjectiveSql
Instead of using strings, it would be better to have an in-language DSL for defining queries, like SQLAlchemy Core but with static typing, so that queries are easier to manipulate.
Defining the table schema is easy enough in Python with static types:
class Post:
    id = BigSerial(primary_key=True)
    name = Text()
    description = Text()
    created = Datetime()
    karma = Integer()
And then we can use the class attributes in the query DSL:
query = (
    select(Post.id, Post.name)
    .where((Post.id >= 25) & (Post.name == "foo"))
    .where(Post.karma >= 25)
    .order_by(Post.name, desc(Post.description))
    .limit(5)
    .skip(10)
)
if has_lots_of_karma:
    # easier than messing around with SQL strings
    query = query.where(Post.karma > 10_000)
result = db.query(query, shape=List[Tuple[int, str]])
Compared to plain strings, the SQL that the query DSL actually executes isn’t always obvious, so we can watch the database logs or have the db client write the generated queries to the console.
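To make the generated SQL visible, the query object could expose a method that renders it before execution. This is a rough sketch under stated assumptions: SketchQuery is a hypothetical, simplified stand-in for the DSL's Query that stores column and table names as strings, and the %s placeholder style is just one common convention.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class SketchQuery:
    """Hypothetical stand-in for the DSL's Query, holding pre-named parts."""

    columns: List[str]
    table: str
    # Each condition is ("column operator", bound value), e.g. ("karma >=", 25).
    conditions: List[Tuple[str, object]] = field(default_factory=list)
    limit: Optional[int] = None

    def to_sql(self) -> Tuple[str, List[object]]:
        """Render a parameterized SQL string plus its bound arguments."""
        sql = f"SELECT {', '.join(self.columns)} FROM {self.table}"
        params: List[object] = []
        if self.conditions:
            clauses = []
            for expression, value in self.conditions:
                clauses.append(f"{expression} %s")
                params.append(value)
            sql += " WHERE " + " AND ".join(clauses)
        if self.limit is not None:
            sql += " LIMIT %s"
            params.append(self.limit)
        return sql, params


query = SketchQuery(["id", "name"], "post", [("karma >=", 25)], limit=5)
sql, params = query.to_sql()
print(sql, params)  # what the client could write to the console
```

Keeping values as bound parameters rather than formatting them into the string means the logged query is the same text the database receives.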
Query Builder w/o Specific Shape Argument
The query DSL approach is more flexible than the string approach, but can we remove the need to provide a type entirely?
Django’s ORM uses the table definition class for the output type of the queries, which alleviates the need to specify the query’s result type.
However, things get a little tricky once you select specific columns with .only(), as the same table class is used for the result. The fields that weren't selected are deferred rather than absent: the type checker still sees them, and reading one at runtime triggers an extra query to load it.
# Only Post.id and Post.name are loaded by this query.
# Post.karma, Post.description, and Post.created are deferred; accessing
# one of them issues another query.
result = Post.objects.filter(karma__gte=10).only("id", "name")
While this functionality can be confusing, it type checks, but the pattern doesn’t transfer to more complicated queries that can’t be expressed in the ORM. For those we’ll need to use the underlying psycopg cursor with SQL strings.
So how can we avoid having to write out the type of the query result?
We could try making the argument generic with a TypeVar so that select(T) returns Query[T].
The problem with this approach is that the types are wrong. In select(Post.id, Post.name) the types of Post.id and Post.name aren’t int and str, but instead BigSerial and Text. And we can’t change these types because we use them for building queries.
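The mismatch is easy to reproduce in a small sketch. The definitions below are simplified stand-ins written for this illustration (a single-column select and an empty BigSerial class), not the stubs from the end of the post.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class BigSerial:
    """Stand-in for the query builder's column type."""


class Query(Generic[T]):
    def __init__(self, column: T) -> None:
        self.column = column


def select(column: T) -> Query[T]:
    # T binds to the type of the argument, i.e. the column object.
    return Query(column)


class Post:
    id = BigSerial()


query = select(Post.id)
# A type checker infers Query[BigSerial] here, not Query[int]: the TypeVar
# captures the builder type, not the Python type of the values in the rows.
```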
Another option is adding @overloads for select()’s __init__, but that would require an overload for every possible argument count, along with every argument type used, as we need the overloads to map from the query builder types (BigSerial, Text, Integer, etc.) to their corresponding Python types. Additionally, we’d require even more overloads for things like json_agg.
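To get a feel for how quickly the overloads pile up, here is a sketch that covers only two column types and at most two selected columns. The classes are simplified stand-ins for this illustration, and select is written as a plain function rather than Query's __init__ to keep it short.

```python
from typing import Any, Generic, Tuple, TypeVar, overload

T = TypeVar("T")


class BigSerial:
    """Stand-in column type that maps to int."""


class Text:
    """Stand-in column type that maps to str."""


class Query(Generic[T]):
    def __init__(self, *columns: Any) -> None:
        self.columns = columns


# One overload per combination of argument count and column type.
@overload
def select(a: BigSerial, /) -> Query[Tuple[int]]: ...
@overload
def select(a: Text, /) -> Query[Tuple[str]]: ...
@overload
def select(a: BigSerial, b: BigSerial, /) -> Query[Tuple[int, int]]: ...
@overload
def select(a: BigSerial, b: Text, /) -> Query[Tuple[int, str]]: ...
@overload
def select(a: Text, b: BigSerial, /) -> Query[Tuple[str, int]]: ...
@overload
def select(a: Text, b: Text, /) -> Query[Tuple[str, str]]: ...
def select(*columns: Any) -> Query[Any]:
    return Query(*columns)


# Statically typed as Query[Tuple[int, str]] by the fourth overload.
query = select(BigSerial(), Text())
```

Two column types and two positions already need six overloads; with n column types and up to k selected columns the count grows as n + n² + … + nᵏ.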
We might be able to cover the functionality with a mypy plugin, but this wouldn’t transfer to other type checkers like Pyright.
Code gen might work. We could generate the overloads based on how select() is used in the project, but that seems tricky to implement.
Related
Prisma is a TypeScript ORM that correctly types the return result when selecting specific fields, a la Django ORM’s .only() method, but the client is limited in functionality, so for more advanced queries you’ll need to use prisma.$queryRaw(), which returns any.
Essentially, Prisma is a better typed Django ORM, but doesn’t support typing the returns of arbitrary SQL queries.
Conclusion
It seems the best we can do is a query DSL with a specific shape argument to the db client.
Code
Basic stubs for the API:
from __future__ import annotations
from typing import Any, Iterable, List, Type, TypeVar, Union, Tuple
from decimal import Decimal
class Field:
    def __ge__(self, other: Union[Any, Field]) -> ComparisonResult:
        ...


class ComparisonResult:
    def __and__(self, other: ComparisonResult) -> ComparisonResult:
        ...


class Integer(Field):
    def __init__(self, *, primary_key: bool = False) -> None:
        ...

    def __ge__(self, other: Union[int, Field]) -> ComparisonResult:
        ...

    def __gt__(self, other: Union[int, Field]) -> ComparisonResult:
        ...


class BigSerial(Field):
    def __init__(self, *, primary_key: bool = False) -> None:
        ...

    def __ge__(self, other: Union[int, Field]) -> ComparisonResult:
        ...


class Text(Field):
    # Deliberately narrower than object.__eq__, so the DSL can build
    # comparisons instead of returning a bool.
    def __eq__(self, other: Union[str, Field]) -> ComparisonResult:  # type: ignore[override]
        ...

    def __add__(self, other: Any) -> ComparisonResult:
        ...


class Datetime(Field):
    pass


class Post:
    id = BigSerial(primary_key=True)
    name = Text()
    description = Text()
    created = Datetime()
    karma = Integer()


class Query:
    def __init__(self, *args: Field) -> None:
        ...

    def where(self, *args: ComparisonResult) -> Query:
        ...

    def order_by(self, *args: Union[Field, desc]) -> Query:
        ...

    def limit(self, count: int) -> Query:
        ...

    def count(self) -> Query:
        ...

    def skip(self, count: int) -> Query:
        ...

    def returning(self, *args: Any) -> Query:
        ...

    def label(self, name: str) -> Query:
        ...


class desc:
    def __init__(self, *args: Field) -> None:
        ...


def and_(*args: ComparisonResult) -> ComparisonResult:
    ...


T = TypeVar("T")


class Postgres:
    def query(self, query: Query, shape: Type[T]) -> T:
        ...

    def sql(
        self,
        query: str,
        *,
        args: Iterable[Union[bool, str, int, float, Decimal, None, bytes]] = (),
        shape: Type[T],
    ) -> T:
        ...


select = Query


def example_1(db: Postgres) -> None:
    result = db.sql("SELECT id, name FROM table_name;", shape=List[Tuple[int, str]])


def example_2(db: Postgres, has_lots_of_karma: bool) -> None:
    query = (
        select(Post.id, Post.name)
        .where((Post.id >= 25) & (Post.name == "foo"))
        .where(Post.karma >= 25)
        .order_by(Post.name, desc(Post.description))
        .limit(5)
        .skip(10)
    )
    if has_lots_of_karma:
        # easier than messing around with SQL strings
        query = query.where(Post.karma > 10000)
    result = db.query(query, shape=List[Tuple[int, str]])