diff --git a/docs/advanced/change-default-type-mapping.md b/docs/advanced/change-default-type-mapping.md new file mode 100644 index 0000000000..a9bb9f299e --- /dev/null +++ b/docs/advanced/change-default-type-mapping.md @@ -0,0 +1,70 @@ +# Change default Pydantic to SQLAlchemy mappings + +In most cases, you do not need to know how SQLAlchemy transforms the Python types to the type suitable for storing data +in the database, and you can use the default mapping. + +But in some cases you may need to have possibility to change default mapping provided by SQLmodel. For example to use +mssql dialect with some UTF-8 data you should use NVARCHAR field (sa.Unicode) + +Now changing default mapping is simple to use - see example bellow: + +```python +import sqlmodel.main +import sqlalchemy as sa +from sqlmodel import Field, SQLModel +from typing import Optional + +sqlmodel.main.sa_types_map[str] = lambda type_, meta, annotation: sa.Unicode( + length=getattr(meta, "max_length", None) +) + + +class Hero(SQLModel, table=True): + id: int = Field(default=None, primary_key=True) + name: str = Field(max_length=255) + history: Optional[str] + + +assert isinstance(Hero.name.type, sa.Unicode) +``` + +# Some details + +Let's get little deeper to process of mapping. At the `sqlmodel.main` module defined the `sa_types_map` dictionary, +which uses the Python types as keys, and the sqlalchemy type or callable that takes the input of 3 parameters and +returns the sqlalchemy type as values. + +Callable format present bellow: + +```python +def map_python_type_to_sa_type(type_: "PythonType", meta: "PydanticMeta", annotation: "FieldAnnotatedType"): + return sqlalchemyType(length=getattr(meta, "max_length", None)) +``` + +* `type_` - used to pass python type, provided by pydantic annotation, cleared from Union/Optional and other wrappers. + Can be passed to sa.Enum type to properly store enumerated data. +* `meta` - pydantic metadata used to store field params e.g. length of str field or precision of decimal field +* `annotation` - original annotation given by pydantic. Used to provide `type` parameter for PydanticJSONType + +## Current mapping + +| Python type | SqlAlchemy type | +|-----------------------|------------------------------------------------------------------------------------------------------------------------------------| +| Enum | `lambda type_, meta, annotation: sa_Enum(type_)` | +| str | `lambda type_, meta, annotation: AutoString(length=getattr(meta, "max_length", None))` | +| float | Float | +| bool | Boolean | +| int | Integer | +| datetime | DateTime | +| date | Date | +| timedelta | Interval | +| time | Time | +| bytes | LargeBinary | +| Decimal | `lambda type_, meta, annotation: Numeric(precision=getattr(meta, "max_digits", None),scale=getattr(meta, "decimal_places", None))` | +| ipaddress.IPv4Address | AutoString | +| ipaddress.IPv4Network | AutoString | +| ipaddress.IPv6Address | AutoString | +| ipaddress.IPv6Network | AutoString | +| Path | AutoString | +| uuid.UUID | Uuid | +| BaseModel | `lambda type_, meta, annotation: PydanticJSONType(type=annotation)` | diff --git a/docs/advanced/index.md b/docs/advanced/index.md index f6178249ce..b034574fba 100644 --- a/docs/advanced/index.md +++ b/docs/advanced/index.md @@ -4,7 +4,7 @@ The **Advanced User Guide** is gradually growing, you can already read about som At some point it will include: -* How to use `async` and `await` with the async session. + * How to use `async` and `await` with the async session. * How to run migrations. * How to combine **SQLModel** models with SQLAlchemy. * ...and more. πŸ€“ diff --git a/docs/advanced/pydantic-json-type.md b/docs/advanced/pydantic-json-type.md new file mode 100644 index 0000000000..ae54a26d26 --- /dev/null +++ b/docs/advanced/pydantic-json-type.md @@ -0,0 +1,60 @@ +# Storing Pydantic models at database + +In some cases you might need to be able to store Pydantic models as a JSON data instead of create new table for them. You can do it now with new SqlAlchemy type `PydanticJSonType` mapped to BaseModel inside SqlModel. + +For example let's add some stats to our heroes and save them at database as JSON data. + +At first, we need to create new class `Stats` inherited from pydantic `BaseModel` or even `SqlModel`: + +```{.python .annotate } +{!./docs_src/advanced/pydantic_json_type/tutorial001.py[ln:8-14]!} +``` + +Then create new field `stats` to `Hero` model + +```{.python .annotate hl_lines="6" } +{!./docs_src/advanced/pydantic_json_type/tutorial001.py[ln:17-22]!} +``` +And... that's all of you need to do to store pydantic data as JSON at database. + +/// details | πŸ‘€ Full tutorial preview +```Python +{!./docs_src/advanced/pydantic_json_type/tutorial001.py!} +``` + +/// + +Here we define new Pydantic model `Stats` contains statistics of our hero and map this model to SqlModel class `Hero`. + +Then we create new instances of Hero model with random generated stats and save it at database. + + +# How to watch for mapped model changes at runtime + +In previous example we have one *non bug but feature* - `Stats` model isn't mutable and if we try to load our Hero form database and then change some stats and call `session.commit()` there no changes will be saved. + +Let's see how to avoid it. + +At first, we need to inherit our Stats model from `sqlalchemy.ext.mutable.Mutable`: +```{.python .annotate hl_lines="1" } +{!./docs_src/advanced/pydantic_json_type/tutorial002.py[ln:10-19]!} +``` + +Then map Stats to Hero as shown bellow: +```{.python .annotate hl_lines="1-4" } +{!./docs_src/advanced/pydantic_json_type/tutorial002.py[ln:36-39]!} +``` + +After all of these actions we can change mutated model, and it will be saved to database after we call `session.commit()` + +```{.python .annotate hl_lines="4" } +{!./docs_src/advanced/pydantic_json_type/tutorial002.py[ln:76-94]!} +``` + +/// details | πŸ‘€ Full tutorial preview + +```Python +{!./docs_src/advanced/pydantic_json_type/tutorial002.py!} +``` + +/// diff --git a/docs_src/advanced/pydantic_json_type/__init__.py b/docs_src/advanced/pydantic_json_type/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs_src/advanced/pydantic_json_type/tutorial001.py b/docs_src/advanced/pydantic_json_type/tutorial001.py new file mode 100644 index 0000000000..9b6c5841c6 --- /dev/null +++ b/docs_src/advanced/pydantic_json_type/tutorial001.py @@ -0,0 +1,85 @@ +import random +from typing import Optional + +from pydantic import BaseModel +from sqlmodel import Field, Session, SQLModel, create_engine, select + + +class Stats(BaseModel): + strength: int + dexterity: int + constitution: int + intelligence: int + wisdom: int + charisma: int + + +class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] + stats: Optional[Stats] + + +sqlite_file_name = "database.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +engine = create_engine(sqlite_url, echo=True) + + +def create_db_and_tables(): + SQLModel.metadata.create_all(engine) + + +def random_stat(): + random.seed() + + return Stats( + strength=random.randrange(1, 20, 2), + dexterity=random.randrange(1, 20, 2), + constitution=random.randrange(1, 20, 2), + intelligence=random.randrange(1, 20, 2), + wisdom=random.randrange(1, 20, 2), + charisma=random.randrange(1, 20, 2), + ) + + +def create_heroes(): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson", stats=random_stat()) + hero_2 = Hero( + name="Spider-Boy", secret_name="Pedro Parqueador", stats=random_stat() + ) + hero_3 = Hero( + name="Rusty-Man", secret_name="Tommy Sharp", age=48, stats=random_stat() + ) + + with Session(engine) as session: + session.add(hero_1) + session.add(hero_2) + session.add(hero_3) + + session.commit() + + +def select_heroes(): + with Session(engine) as session: + statement = select(Hero).where(Hero.name == "Deadpond") + results = session.exec(statement) + hero_1 = results.one() + print("Hero 1:", hero_1) + + statement = select(Hero).where(Hero.name == "Rusty-Man") + results = session.exec(statement) + hero_2 = results.one() + print("Hero 2:", hero_2) + + +def main(): + create_db_and_tables() + create_heroes() + select_heroes() + + +if __name__ == "__main__": + main() diff --git a/docs_src/advanced/pydantic_json_type/tutorial002.py b/docs_src/advanced/pydantic_json_type/tutorial002.py new file mode 100644 index 0000000000..cbfb5a93b2 --- /dev/null +++ b/docs_src/advanced/pydantic_json_type/tutorial002.py @@ -0,0 +1,102 @@ +import random +from typing import Any, Optional + +from pydantic import BaseModel +from sqlalchemy import Column +from sqlalchemy.ext.mutable import Mutable +from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.sql.sqltypes import PydanticJSONType + + +class Stats(BaseModel, Mutable): + strength: int + dexterity: int + constitution: int + intelligence: int + wisdom: int + charisma: int + + @classmethod + def coerce(cls, key: str, value: Any) -> Optional[Any]: + return value + + def __setattr__(self, key, value): + # set the attribute + object.__setattr__(self, key, value) + + # alert all parents to the change + self.changed() + + +class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] + stats: Stats = Field( + default_factory=None, + sa_column=Column(Stats.as_mutable(PydanticJSONType(type=Stats))), + ) + + +sqlite_file_name = "database.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +engine = create_engine(sqlite_url, echo=True) + + +def create_db_and_tables(): + SQLModel.metadata.create_all(engine) + + +def random_stat(): + random.seed() + + return Stats( + strength=random.randrange(1, 20, 2), + dexterity=random.randrange(1, 20, 2), + constitution=random.randrange(1, 20, 2), + intelligence=random.randrange(1, 20, 2), + wisdom=random.randrange(1, 20, 2), + charisma=random.randrange(1, 20, 2), + ) + + +def create_hero(): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson", stats=random_stat()) + + with Session(engine) as session: + session.add(hero_1) + + session.commit() + + +def mutate_hero(): + with Session(engine) as session: + statement = select(Hero).where(Hero.name == "Deadpond") + results = session.exec(statement) + hero_1 = results.one() + + print("Hero 1:", hero_1.stats) + + hero_1.stats.strength = 100500 + session.commit() + + with Session(engine) as session: + statement = select(Hero).where(Hero.name == "Deadpond") + results = session.exec(statement) + hero_1 = results.one() + + print("Hero 1 strength:", hero_1.stats.strength) + + print("Hero 1:", hero_1) + + +def main(): + create_db_and_tables() + create_hero() + mutate_hero() + + +if __name__ == "__main__": + main() diff --git a/mkdocs.yml b/mkdocs.yml index ce98f1524e..3c70f080cc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -98,6 +98,8 @@ nav: - Advanced User Guide: - advanced/index.md - advanced/decimal.md + - advanced/change-default-type-mapping.md + - advanced/pydantic-json-type.md - alternatives.md - help.md - contributing.md diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 2a2caca3e8..e3ee510210 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -17,6 +17,7 @@ Union, ) +from annotated_types import BaseMetadata from pydantic import VERSION as PYDANTIC_VERSION from pydantic.fields import FieldInfo from typing_extensions import get_args, get_origin @@ -190,8 +191,9 @@ def get_type_from_field(field: Any) -> Any: def get_field_metadata(field: Any) -> Any: for meta in field.metadata: - if isinstance(meta, PydanticMetadata): + if isinstance(meta, (PydanticMetadata, BaseMetadata)): return meta + return FakeMetadata() def post_init_field_info(field_info: FieldInfo) -> None: diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 10064c7116..17675d19b0 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -5,6 +5,7 @@ from decimal import Decimal from enum import Enum from pathlib import Path +from types import FunctionType from typing import ( AbstractSet, Any, @@ -50,7 +51,7 @@ from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData -from sqlalchemy.sql.sqltypes import LargeBinary, Time +from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid from typing_extensions import Literal, deprecated, get_origin from ._compat import ( # type: ignore[attr-defined] @@ -78,13 +79,42 @@ sqlmodel_init, sqlmodel_validate, ) -from .sql.sqltypes import GUID, AutoString +from .sql.sqltypes import AutoString, PydanticJSONType _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any], None] +sa_types_map: Dict[Any, Union[Any, Callable[[Any, Any, Any], Any]]] = { + Enum: lambda type_, meta, annotation: sa_Enum(type_), + str: lambda type_, meta, annotation: AutoString( + length=getattr(meta, "max_length", None) + ), + float: Float, + bool: Boolean, + int: Integer, + datetime: DateTime, + date: Date, + timedelta: Interval, + time: Time, + bytes: LargeBinary, + Decimal: lambda type_, meta, annotation: Numeric( + precision=getattr(meta, "max_digits", None), + scale=getattr(meta, "decimal_places", None), + ), + ( + ipaddress.IPv4Address, + ipaddress.IPv4Network, + ipaddress.IPv6Address, + ipaddress.IPv6Network, + Path, + ): AutoString, + uuid.UUID: Uuid, # Custom GUID type is no longer required at SQLAlchemy v2 + BaseModel: lambda type_, meta, annotation: PydanticJSONType(type=annotation), +} + + def __dataclass_transform__( *, eq_default: bool = True, @@ -556,54 +586,24 @@ def get_sqlalchemy_type(field: Any) -> Any: field_info = field else: field_info = field.field_info + sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009 + if sa_type is not Undefined: return sa_type type_ = get_type_from_field(field) metadata = get_field_metadata(field) - # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI - if issubclass(type_, Enum): - return sa_Enum(type_) - if issubclass(type_, str): - max_length = getattr(metadata, "max_length", None) - if max_length: - return AutoString(length=max_length) - return AutoString - if issubclass(type_, float): - return Float - if issubclass(type_, bool): - return Boolean - if issubclass(type_, int): - return Integer - if issubclass(type_, datetime): - return DateTime - if issubclass(type_, date): - return Date - if issubclass(type_, timedelta): - return Interval - if issubclass(type_, time): - return Time - if issubclass(type_, bytes): - return LargeBinary - if issubclass(type_, Decimal): - return Numeric( - precision=getattr(metadata, "max_digits", None), - scale=getattr(metadata, "decimal_places", None), - ) - if issubclass(type_, ipaddress.IPv4Address): - return AutoString - if issubclass(type_, ipaddress.IPv4Network): - return AutoString - if issubclass(type_, ipaddress.IPv6Address): - return AutoString - if issubclass(type_, ipaddress.IPv6Network): - return AutoString - if issubclass(type_, Path): - return AutoString - if issubclass(type_, uuid.UUID): - return GUID + for expected_type, sa_type in sa_types_map.items(): + if not issubclass(type_, expected_type): + continue + + if not isinstance(sa_type, FunctionType): + return sa_type + + return sa_type(type_, metadata, field.annotation) + raise ValueError(f"{type_} has no matching SQLAlchemy type") @@ -869,3 +869,32 @@ def _calculate_keys( exclude_unset=exclude_unset, update=update, ) + + if IS_PYDANTIC_V2: + + def __eq__(self, other: Any) -> bool: + # Should override comparison method because of the base method uses some additional + # pydantic related logic e.g. comparison of __pydantic_private__ and __pydantic_extra__ + # of models presented. This fields only available if __init__ was called and missed if + # a model instantiated from database + # Ref: https://github.com/pydantic/pydantic/commit/62ed0aabd716fba99a5b1cc0cdc26d75d2d8447a + if isinstance(other, SQLModel): + # When comparing instances of generic types for equality, as long as all field values are equal, + # only require their generic origin types to be equal, rather than exact type equality. + # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). + self_type = ( + self.__pydantic_generic_metadata__["origin"] or self.__class__ + ) + other_type = ( + other.__pydantic_generic_metadata__["origin"] or other.__class__ + ) + + dict1 = self.__dict__.copy() + dict2 = other.__dict__.copy() + + _ = dict1.pop("_sa_instance_state", None) + _ = dict2.pop("_sa_instance_state", None) + + return self_type == other_type and dict1 == dict2 + else: + return super().__eq__(other) diff --git a/sqlmodel/sql/pydantic_v1_json.py b/sqlmodel/sql/pydantic_v1_json.py new file mode 100644 index 0000000000..d4f6c81d4d --- /dev/null +++ b/sqlmodel/sql/pydantic_v1_json.py @@ -0,0 +1,57 @@ +"""ΠšΠ°ΡΡ‚ΠΎΠΌΠ½Ρ‹ΠΉ Ρ‚ΠΈΠΏ для сСриализации/дСсСриализации JSON""" + +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, parse_obj_as +from sqlalchemy import types +from sqlalchemy.engine import Dialect + + +class PydanticJSONv1(types.TypeDecorator): # type: ignore + """Type to store Pydantic models as JSON data at database""" + + impl = types.JSON + cache_ok = True + + def __init__(self, *args: Any, **kwargs: Any): + """ """ + if "type" not in kwargs: + raise AttributeError( + "Provide 'type' kwarg inherited from BaseModel or List of BaseModel or UnionType" + ) + + self.__type = kwargs.pop("type") + self.__ensure_ascii = kwargs.pop("ensure_ascii", False) + + super().__init__(*args, **kwargs) + + def process_result_value(self, value: Any, dialect: Dialect) -> Any: + if value is None: + return None + return parse_obj_as(self.__type, value) + + @staticmethod + def __pydantic_model_to_json(model: BaseModel, ensure_ascii: bool) -> str: + return model.json(by_alias=True, exclude_none=True, ensure_ascii=ensure_ascii) + + def serialize( + self, value: Optional[Union[Dict[Any, Any], List[BaseModel], BaseModel]] + ) -> Any: + if value is None: + return None + + if isinstance(value, list): + values = [ + self.__pydantic_model_to_json(v, self.__ensure_ascii) + for v in parse_obj_as(self.__type, value) + ] + return f"[{','.join(values)}]" + + return self.__pydantic_model_to_json( + parse_obj_as(self.__type, value), self.__ensure_ascii + ) + + def bind_processor(self, dialect: Dialect) -> Optional[Any]: + string_process = self._str_impl.bind_processor(dialect) + + return self._make_bind_processor(string_process, self.serialize) diff --git a/sqlmodel/sql/pydantic_v2_json.py b/sqlmodel/sql/pydantic_v2_json.py new file mode 100644 index 0000000000..dc9e9bd3da --- /dev/null +++ b/sqlmodel/sql/pydantic_v2_json.py @@ -0,0 +1,44 @@ +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, TypeAdapter +from sqlalchemy import types +from sqlalchemy.engine.interfaces import Dialect + + +class PydanticJSONv2(types.TypeDecorator): # type: ignore + """Type to store Pydantic models as JSON data at database""" + + impl = types.JSON + cache_ok = False + + def __init__(self, *args: Any, **kwargs: Any): + """ """ + if "type" not in kwargs: + raise AttributeError( + "Provide 'type' kwarg inherited from BaseModel or List of BaseModel or UnionType" + ) + + self._type = TypeAdapter(kwargs.pop("type")) + + super().__init__(*args, **kwargs) + + def process_result_value(self, value: Any, dialect: Dialect) -> Any: + if value is None: + return None + + return self._type.validate_python(value) + + def serialize( + self, value: Optional[Union[Dict[Any, Any], List[BaseModel], BaseModel]] + ) -> Optional[str]: + if value is None: + return None + + return self._type.dump_json( + self._type.validate_python(value), by_alias=True, exclude_none=True + ).decode() + + def bind_processor(self, dialect: Dialect) -> Optional[Any]: + string_process = self._str_impl.bind_processor(dialect) + + return self._make_bind_processor(string_process, self.serialize) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 5a4bb04ef1..fba3500c5c 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,10 +1,16 @@ -import uuid -from typing import Any, Optional, cast +from typing import Any, cast -from sqlalchemy import CHAR, types -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy import Uuid, types from sqlalchemy.engine.interfaces import Dialect -from sqlalchemy.sql.type_api import TypeEngine +from sqlmodel._compat import IS_PYDANTIC_V2 +from typing_extensions import deprecated + +if IS_PYDANTIC_V2: + from .pydantic_v2_json import PydanticJSONv2 as PydanticJSON +else: + from .pydantic_v1_json import PydanticJSONv1 as PydanticJSON # type: ignore + +PydanticJSONType = PydanticJSON class AutoString(types.TypeDecorator): # type: ignore @@ -12,48 +18,13 @@ class AutoString(types.TypeDecorator): # type: ignore cache_ok = True mysql_default_length = 255 - def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": + def load_dialect_impl(self, dialect: Dialect) -> types.TypeEngine[Any]: impl = cast(types.String, self.impl) if impl.length is None and dialect.name == "mysql": return dialect.type_descriptor(types.String(self.mysql_default_length)) return super().load_dialect_impl(dialect) -# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type -# with small modifications -class GUID(types.TypeDecorator): # type: ignore - """Platform-independent GUID type. - - Uses PostgreSQL's UUID type, otherwise uses - CHAR(32), storing as stringified hex values. - - """ - - impl = CHAR - cache_ok = True - - def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: - if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) - else: - return dialect.type_descriptor(CHAR(32)) - - def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: - if value is None: - return value - elif dialect.name == "postgresql": - return str(value) - else: - if not isinstance(value, uuid.UUID): - return uuid.UUID(value).hex - else: - # hexstring - return value.hex - - def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]: - if value is None: - return value - else: - if not isinstance(value, uuid.UUID): - value = uuid.UUID(value) - return cast(uuid.UUID, value) +@deprecated("SQlAlchemy V2 has native support of UUID type - sa.Uuid") +class GUID(Uuid): # type: ignore + pass diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py index ac4aa42e05..3e9dbd932c 100644 --- a/tests/test_missing_type.py +++ b/tests/test_missing_type.py @@ -1,12 +1,11 @@ from typing import Optional import pytest -from pydantic import BaseModel from sqlmodel import Field, SQLModel def test_missing_sql_type(): - class CustomType(BaseModel): + class CustomType: @classmethod def __get_validators__(cls): yield cls.validate diff --git a/tests/test_python_types_to_sa_types_mapping.py b/tests/test_python_types_to_sa_types_mapping.py new file mode 100644 index 0000000000..0a5dbcee8e --- /dev/null +++ b/tests/test_python_types_to_sa_types_mapping.py @@ -0,0 +1,163 @@ +import copy +import datetime +import ipaddress +from decimal import Decimal +from pathlib import Path +from uuid import UUID + +import sqlalchemy as sa +import sqlmodel.main +from pydantic import BaseModel +from sqlalchemy.engine.mock import MockConnection +from sqlmodel import AutoString, Field, SQLModel +from sqlmodel.sql.sqltypes import PydanticJSONType + + +def get_engine(url: str) -> MockConnection: + engine = sa.create_mock_engine( + url, lambda sql, *args, **kwargs: print(sql.compile(dialect=engine.dialect)) + ) + + return engine + + +def test_string_length_constraint(clear_sqlmodel, capsys): + class StrTest(SQLModel, table=True): + id: str = Field(default=None, primary_key=True, max_length=10) + + SQLModel.metadata.create_all(get_engine("sqlite://")) + captured = capsys.readouterr() + + assert "id VARCHAR(10) NOT NULL," in captured.out + + +def test_native_uuid_column_mapping(clear_sqlmodel, capsys): + class UuidTest(SQLModel, table=True): + id: UUID = Field(default=None, primary_key=True) + + assert isinstance(UuidTest.id.type, sa.Uuid) + + SQLModel.metadata.create_all(get_engine("sqlite://")) + captured = capsys.readouterr() + + assert "id CHAR(32) NOT NULL" in captured.out + + SQLModel.metadata.create_all(get_engine("postgresql://")) + captured = capsys.readouterr() + + assert "id UUID NOT NULL" in captured.out + + SQLModel.metadata.create_all(get_engine("mssql://")) + captured = capsys.readouterr() + + assert "id UNIQUEIDENTIFIER NOT NULL" in captured.out + + SQLModel.metadata.create_all(get_engine("mysql://")) + captured = capsys.readouterr() + + assert "id CHAR(32) NOT NULL" in captured.out + + +def test_default_sa_types_to_python_mapping_is_correct(clear_sqlmodel, capsys): + class Parent(BaseModel): + name: str + sex: str + birth_date: datetime.date + + class Hero(SQLModel, table=True): + id: int = Field(default=None, primary_key=True) + global_hero_id: UUID + name: str + birth_date: datetime.date + birth_time: datetime.time + created_at: datetime.datetime + last_seen_delta: datetime.timedelta + salary: Decimal + speed: float + is_on_vacation: bool + dna_marker_data: bytes + headquarter_ip_v4: ipaddress.IPv4Address + headquarter_ip_v6: ipaddress.IPv6Address + shared_folder_path: Path + mother: Parent + father: Parent + + SQLModel.metadata.create_all(get_engine("sqlite://")) + captured = capsys.readouterr() + + assert isinstance(Hero.id.type, sa.Integer) + assert "id INTEGER NOT NULL" in captured.out + + assert isinstance(Hero.global_hero_id.type, sa.Uuid) + assert "global_hero_id CHAR(32) NOT NULL" in captured.out + + assert isinstance(Hero.name.type, AutoString) + assert "name VARCHAR NOT NULL" in captured.out + + assert isinstance(Hero.birth_date.type, sa.Date) + assert "birth_date DATE NOT NULL" in captured.out + + assert isinstance(Hero.birth_time.type, sa.Time) + assert "birth_time TIME NOT NULL" in captured.out + + assert isinstance(Hero.created_at.type, sa.DateTime) + assert "created_at DATETIME NOT NULL" in captured.out + + assert isinstance(Hero.last_seen_delta.type, sa.Interval) + assert "created_at DATETIME NOT NULL" in captured.out + + assert isinstance(Hero.salary.type, sa.Numeric) + assert "salary NUMERIC NOT NULL" in captured.out + + assert isinstance(Hero.speed.type, sa.Float) + assert "speed FLOAT NOT NULL" in captured.out + + assert isinstance(Hero.is_on_vacation.type, sa.Boolean) + assert "is_on_vacation BOOLEAN NOT NULL" in captured.out + + assert isinstance(Hero.dna_marker_data.type, sa.LargeBinary) + assert "dna_marker_data BLOB NOT NULL" in captured.out + + assert isinstance(Hero.headquarter_ip_v4.type, AutoString) + assert "headquarter_ip_v4 VARCHAR NOT NULL" in captured.out + + assert isinstance(Hero.headquarter_ip_v6.type, AutoString) + assert "headquarter_ip_v6 VARCHAR NOT NULL" in captured.out + + assert isinstance(Hero.shared_folder_path.type, AutoString) + assert "shared_folder_path VARCHAR NOT NULL" in captured.out + + assert isinstance(Hero.mother.type, PydanticJSONType) + assert "mother JSON NOT NULL" in captured.out + + assert isinstance(Hero.father.type, PydanticJSONType) + assert "father JSON NOT NULL" in captured.out + + +def test_default_sa_type_mapping_change(clear_sqlmodel, capsys): + base_map = copy.deepcopy(sqlmodel.main.sa_types_map) + sqlmodel.main.sa_types_map[str] = lambda type_, meta, annotation: sa.Unicode( + length=getattr(meta, "max_length", None) + ) + + class Hero(SQLModel, table=True): + id: int = Field(default=None, primary_key=True) + name: str = Field(max_length=255) + history: str + + assert str(Hero.name.type) == str(sa.Unicode(255)) + assert str(Hero.history.type) == str(sa.Unicode()) + + SQLModel.metadata.create_all(get_engine("mssql://")) + captured = capsys.readouterr() + + assert "name NVARCHAR(255) NOT NULL" in captured.out + assert "history NVARCHAR(max) NOT NULL" in captured.out + + SQLModel.metadata.create_all(get_engine("sqlite://")) + captured = capsys.readouterr() + + assert "name VARCHAR(255) NOT NULL" in captured.out + assert "history VARCHAR NOT NULL" in captured.out + + sqlmodel.main.sa_types_map = base_map diff --git a/tests/test_sqlmodel_instance_equality.py b/tests/test_sqlmodel_instance_equality.py new file mode 100644 index 0000000000..d145a8b8c1 --- /dev/null +++ b/tests/test_sqlmodel_instance_equality.py @@ -0,0 +1,46 @@ +from datetime import datetime +from typing import Optional + +from sqlalchemy import create_engine +from sqlmodel import Field, Session, SQLModel, select + + +def test_created_and_instantiated_from_db_instances_are_equal(clear_sqlmodel): + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + username: str + email: str = "test@test.com" + last_updated: datetime = Field(default_factory=datetime.now) + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + user_1 = User(username="test_user") + user_2 = User(username="test_user_2") + + assert user_1 != user_2 + + with Session(engine) as session: + session.add(user_1) + session.add(user_2) + session.commit() + + session.refresh(user_1) + session.refresh(user_2) + + assert user_1 != user_2 + + with Session(engine) as session: + session.merge(user_1) + session.merge(user_2) + + instantiated_user_1 = session.exec( + select(User).where(User.username == user_1.username) + ).one() + + instantiated_user_2 = session.exec( + select(User).where(User.username == user_2.username) + ).one() + + assert user_1 != user_2 + assert user_1 == instantiated_user_1 + assert user_2 == instantiated_user_2