diff --git a/docs_src/tutorial/relationship_attributes/define_relationship_attributes/tutorial001_an_py310.py b/docs_src/tutorial/relationship_attributes/define_relationship_attributes/tutorial001_an_py310.py new file mode 100644 index 0000000000..a4584836f7 --- /dev/null +++ b/docs_src/tutorial/relationship_attributes/define_relationship_attributes/tutorial001_an_py310.py @@ -0,0 +1,70 @@ +from typing import Annotated + +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine + + +class Team(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + headquarters: str + + heroes: Annotated[list["Hero"] | None, Relationship(back_populates="team")] = None + + +class Hero(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + secret_name: str + age: Annotated[int | None, Field(index=True)] = None + + team_id: Annotated[int | None, Field(foreign_key="team.id")] = None + team: Annotated[Team | None, Relationship(back_populates="heroes")] = None + + +sqlite_file_name = "database.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +engine = create_engine(sqlite_url, echo=True) + + +def create_db_and_tables(): + SQLModel.metadata.create_all(engine) + + +def create_heroes(): + with Session(engine) as session: + team_preventers = Team(name="Preventers", headquarters="Sharp Tower") + team_z_force = Team(name="Z-Force", headquarters="Sister Margaret's Bar") + + hero_deadpond = Hero( + name="Deadpond", secret_name="Dive Wilson", team=team_z_force + ) + hero_rusty_man = Hero( + name="Rusty-Man", secret_name="Tommy Sharp", age=48, team=team_preventers + ) + hero_spider_boy = Hero(name="Spider-Boy", secret_name="Pedro Parqueador") + session.add(hero_deadpond) + session.add(hero_rusty_man) + session.add(hero_spider_boy) + session.commit() + + session.refresh(hero_deadpond) + session.refresh(hero_rusty_man) + session.refresh(hero_spider_boy) + + print("Created hero:", hero_deadpond) + print("Created hero:", hero_rusty_man) + print("Created hero:", hero_spider_boy) + + hero_spider_boy.team = team_preventers + session.add(hero_spider_boy) + session.commit() + + +def main(): + create_db_and_tables() + create_heroes() + + +if __name__ == "__main__": + main() diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index a220b193f1..841fe5744c 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -151,6 +151,9 @@ def get_relationship_to( elif origin is list: use_annotation = get_args(annotation)[0] + elif origin is Annotated: + use_annotation = get_args(annotation)[0] + return get_relationship_to(name=name, rel_info=rel_info, annotation=use_annotation) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 9a1a676775..32c1d61881 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -11,6 +11,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Annotated, Any, ClassVar, Literal, @@ -18,6 +19,7 @@ TypeVar, Union, cast, + get_args, get_origin, overload, ) @@ -517,6 +519,16 @@ def Relationship( return relationship_info +def get_annotated_relationshipinfo(t: Any) -> RelationshipInfo | None: + """Get the first RelationshipInfo from Annotated or None if not Annotated with RelationshipInfo.""" + if get_origin(t) is not Annotated: + return None + for a in get_args(t): + if isinstance(a, RelationshipInfo): + return a + return None + + @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: dict[str, RelationshipInfo] @@ -549,16 +561,29 @@ def __new__( original_annotations = get_annotations(class_dict) pydantic_annotations = {} relationship_annotations = {} - for k, v in class_dict.items(): + + # find relationship info in both annotations and class dict + for k in {**original_annotations, **class_dict}: + v = class_dict.get(k) if isinstance(v, RelationshipInfo): relationships[k] = v - else: + continue + r = get_annotated_relationshipinfo(original_annotations.get(k)) + if r is not None: + relationships[k] = r + + # populate dict passed to pydantic + for k, v in class_dict.items(): + if k not in relationships: dict_for_pydantic[k] = v - for k, v in original_annotations.items(): + + # split out pydantic annotations + for k, a in original_annotations.items(): if k in relationships: - relationship_annotations[k] = v + relationship_annotations[k] = a else: - pydantic_annotations[k] = v + pydantic_annotations[k] = a + dict_used = { **dict_for_pydantic, "__weakref__": None, @@ -643,6 +668,11 @@ def __init__( origin: Any = get_origin(raw_ann) if origin is Mapped: ann = raw_ann.__args__[0] + elif origin is Annotated: + ann = get_args(raw_ann)[0] + if get_origin(ann) is Mapped: + ann = ann.__args__[0] + cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type] else: ann = raw_ann # Plain forward references, for models not yet defined, are not diff --git a/tests/test_annotated_relationship.py b/tests/test_annotated_relationship.py new file mode 100644 index 0000000000..c953803efe --- /dev/null +++ b/tests/test_annotated_relationship.py @@ -0,0 +1,85 @@ +from typing import Annotated + +from sqlalchemy.orm import Mapped +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select + + +def test_annotated_relationship_with_default() -> None: + class Team(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + + heroes: Annotated[list["Hero"], Relationship(back_populates="team")] = [] # noqa: RUF012 + + class Hero(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + team_id: Annotated[int | None, Field(foreign_key="team.id")] = None + team: Annotated[Team | None, Relationship(back_populates="heroes")] = None + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + team = Team(name="Preventers") + hero = Hero(name="Deadpond", team=team) + session.add(hero) + session.commit() + session.refresh(hero) + assert hero.team is not None + assert hero.team.name == "Preventers" + team_db = session.exec(select(Team)).one() + assert [h.name for h in team_db.heroes] == ["Deadpond"] + + +def test_annotated_relationship_without_default() -> None: + class Team(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + + heroes: Annotated[list["Hero"], Relationship(back_populates="team")] + + class Hero(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + team_id: Annotated[int | None, Field(foreign_key="team.id")] = None + team: Annotated[Team | None, Relationship(back_populates="heroes")] + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + team = Team(name="Z-Force") + hero = Hero(name="Spider-Boy", team=team) + session.add(hero) + session.commit() + session.refresh(hero) + assert hero.team is not None + assert hero.team.name == "Z-Force" + + +def test_annotated_mapped_relationship() -> None: + class Team(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + + heroes: Annotated[ + Mapped[list["Hero"]], Relationship(back_populates="team") + ] = [] # noqa: RUF012 + + class Hero(SQLModel, table=True): + id: Annotated[int | None, Field(primary_key=True)] = None + name: Annotated[str, Field(index=True)] + team_id: Annotated[int | None, Field(foreign_key="team.id")] = None + team: Annotated[Mapped[Team | None], Relationship(back_populates="heroes")] = ( + None + ) + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + team = Team(name="Avengers") + hero = Hero(name="Iron Man", team=team) + session.add(hero) + session.commit() + session.refresh(hero) + assert hero.team is not None + assert hero.team.name == "Avengers" diff --git a/tests/test_tutorial/test_relationship_attributes/test_define_relationship_attributes/test_tutorial001.py b/tests/test_tutorial/test_relationship_attributes/test_define_relationship_attributes/test_tutorial001.py index fdd1ce6443..74e0a754eb 100644 --- a/tests/test_tutorial/test_relationship_attributes/test_define_relationship_attributes/test_tutorial001.py +++ b/tests/test_tutorial/test_relationship_attributes/test_define_relationship_attributes/test_tutorial001.py @@ -11,6 +11,7 @@ name="mod", params=[ pytest.param("tutorial001_py310"), + pytest.param("tutorial001_an_py310"), ], ) def get_module(request: pytest.FixtureRequest) -> ModuleType: