Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 94 additions & 5 deletions duckdb_engine/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import typing
from typing import Any, Callable, Dict, Optional, Type
from typing import Any, Callable, Dict, Optional, Sequence, Type

import duckdb
from packaging.version import Version
Expand All @@ -18,7 +18,7 @@
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes, type_api
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.types import BigInteger, Integer, SmallInteger, String
from sqlalchemy.types import ARRAY, BigInteger, Integer, SmallInteger, String

# INTEGER INT4, INT, SIGNED -2147483648 2147483647
# SMALLINT INT2, SHORT -32768 32767
Expand Down Expand Up @@ -119,7 +119,7 @@ class Struct(TypeEngine):

Table(
'hello',
Column('name', Struct({'first': String, 'last': String})
Column('name', Struct({'first': String, 'last': String}))
)
```

Expand All @@ -142,7 +142,7 @@ class Map(TypeEngine):

Table(
'hello',
Column('name', Map(String, String)
Column('name', Map(String, String))
)
```
"""
Expand Down Expand Up @@ -183,7 +183,7 @@ class Union(TypeEngine):

Table(
'hello',
Column('name', Union({"name": String, "age": String})
Column('name', Union({"name": String, "age": String}))
)
```
"""
Expand All @@ -195,6 +195,81 @@ def __init__(self, fields: Dict[str, TV]):
self.fields = fields


class List(ARRAY):
    """
    Variable-length LIST column type for DuckDB.

    Built on SQLAlchemy's ``ARRAY``; every constructor argument is handed
    straight to the parent class. The type is compiled through the
    ``"list"`` visit name.

    ```python
    from duckdb_engine.datatypes import List
    from sqlalchemy import Table, Column, String

    Table(
        'hello',
        Column('name', List(String))
    )
    ```
    """

    __visit_name__ = "list"
    item_type: TV

    def __init__(
        self,
        item_type: TV,
        as_tuple: bool = False,
        dimensions: Optional[int] = None,
        zero_indexes: bool = False,
    ):
        """
        Create a LIST type whose elements are of ``item_type``.

        ``as_tuple``, ``dimensions`` and ``zero_indexes`` are passed on
        unchanged to ``ARRAY.__init__``.
        """
        # Collect the pass-through options once, then delegate to ARRAY.
        array_options = {
            "as_tuple": as_tuple,
            "dimensions": dimensions,
            "zero_indexes": zero_indexes,
        }
        super().__init__(item_type, **array_options)


class Array(ARRAY):
    """
    Fixed-size ARRAY column type for DuckDB.

    Built on SQLAlchemy's ``ARRAY`` with an additional ``size`` giving the
    length of each dimension. The type is compiled through the ``"array"``
    visit name.

    ```python
    from duckdb_engine.datatypes import Array
    from sqlalchemy import Table, Column, String

    Table(
        'hello',
        Column('name', Array(String, 3))
    )
    ```
    """

    __visit_name__ = "array"
    item_type: TV

    def __init__(
        self,
        item_type: TV,
        size: typing.Union[int, Sequence[int]],
        as_tuple: bool = False,
        dimensions: Optional[int] = None,
        zero_indexes: bool = False,
    ):
        """
        Create an ARRAY type whose elements are of ``item_type``.

        ``size`` is either a single int (stored as a one-element list) or a
        sequence with one entry per dimension (stored as given). When
        ``dimensions`` is not supplied and ``size`` is non-empty, the
        dimension count is taken from ``len(size)``. ``as_tuple``,
        ``dimensions`` and ``zero_indexes`` are passed on unchanged to
        ``ARRAY.__init__``.

        Raises:
            ValueError: if the number of sizes differs from ``dimensions``.
        """
        super().__init__(
            item_type,
            as_tuple=as_tuple,
            dimensions=dimensions,
            zero_indexes=zero_indexes,
        )

        # A bare int is shorthand for a one-dimensional array.
        self.size = [size] if isinstance(size, int) else size

        # Derive the dimension count from the sizes when it was left unset.
        if self.dimensions is None and self.size:
            self.dimensions = len(self.size)

        # Guard clause: nothing more to do when sizes and dimensions agree.
        if len(self.size) == self.dimensions:
            return

        raise ValueError(
            f"length of size must be equal to dimensions ({len(self.size)} != {self.dimensions})"
        )


ISCHEMA_NAMES = {
"hugeint": HugeInteger,
"uhugeint": UHugeInteger,
Expand Down Expand Up @@ -282,3 +357,17 @@ def visit_map(instance: Map, compiler: PGTypeCompiler, **kw: Any) -> str:
process_type(instance.key_type, compiler, **kw),
process_type(instance.value_type, compiler, **kw),
)


@compiles(List, "duckdb")  # type: ignore[misc]
def visit_list(instance: List, compiler: PGTypeCompiler, **kw: Any) -> str:
    """
    Render the DDL type string for a ``List`` column.

    The compiled element type is followed by one ``[]`` per dimension; a
    missing (or otherwise falsy) ``dimensions`` is treated as 1.
    """
    element_sql = process_type(instance.item_type, compiler, **kw)
    depth = instance.dimensions or 1
    return element_sql + "[]" * depth


@compiles(Array, "duckdb")  # type: ignore[misc]
def visit_array(instance: Array, compiler: PGTypeCompiler, **kw: Any) -> str:
    """
    Render the DDL type string for an ``Array`` column.

    The compiled element type is followed by one ``[n]`` suffix for each
    entry of ``instance.size``, in order.
    """
    element_sql = process_type(instance.item_type, compiler, **kw)
    brackets = [f"[{extent}]" for extent in instance.size]
    return element_sql + "".join(brackets)
19 changes: 16 additions & 3 deletions duckdb_engine/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from sqlalchemy.types import FLOAT, JSON

from .._supports import duckdb_version, has_uhugeint_support
from ..datatypes import Map, Struct, types
from ..datatypes import Array, List, Map, Struct, types


@mark.parametrize("coltype", types)
Expand Down Expand Up @@ -196,19 +196,32 @@ class Entry(base):
struct = Column(Struct(fields={"name": String}))
map = Column(Map(String, Integer))
# union = Column(Union(fields={"name": String, "age": Integer}))
array = Column(Array(String, 3))
list = Column(List(String))

base.metadata.create_all(bind=engine)

struct_data = {"name": "Edgar"}
map_data = {"one": 1, "two": 2}

session.add(Entry(struct=struct_data, map=map_data)) # type: ignore[call-arg]
array_data = ["one", "two", "three"]
list_data = ["one", "two", "three"]

session.add(
Entry(
struct=struct_data, # type: ignore[call-arg]
map=map_data,
array=array_data,
list=list_data,
)
)
session.commit()

result = session.query(Entry).one()

assert result.struct == struct_data
assert result.map == map_data
assert result.array == array_data
assert result.list == list_data


def test_double_nested_types(engine: Engine, session: Session) -> None:
Expand Down