Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Annotated

from sqlmodel import Field, Relationship, Session, SQLModel, create_engine


class Team(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]
headquarters: str

heroes: Annotated[list["Hero"] | None, Relationship(back_populates="team")] = None


class Hero(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]
secret_name: str
age: Annotated[int | None, Field(index=True)] = None

team_id: Annotated[int | None, Field(foreign_key="team.id")] = None
team: Annotated[Team | None, Relationship(back_populates="heroes")] = None


sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"

engine = create_engine(sqlite_url, echo=True)


def create_db_and_tables():
SQLModel.metadata.create_all(engine)


def create_heroes():
with Session(engine) as session:
team_preventers = Team(name="Preventers", headquarters="Sharp Tower")
team_z_force = Team(name="Z-Force", headquarters="Sister Margaret's Bar")

hero_deadpond = Hero(
name="Deadpond", secret_name="Dive Wilson", team=team_z_force
)
hero_rusty_man = Hero(
name="Rusty-Man", secret_name="Tommy Sharp", age=48, team=team_preventers
)
hero_spider_boy = Hero(name="Spider-Boy", secret_name="Pedro Parqueador")
session.add(hero_deadpond)
session.add(hero_rusty_man)
session.add(hero_spider_boy)
session.commit()

session.refresh(hero_deadpond)
session.refresh(hero_rusty_man)
session.refresh(hero_spider_boy)

print("Created hero:", hero_deadpond)
print("Created hero:", hero_rusty_man)
print("Created hero:", hero_spider_boy)

hero_spider_boy.team = team_preventers
session.add(hero_spider_boy)
session.commit()


def main():
create_db_and_tables()
create_heroes()


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def get_relationship_to(
elif origin is list:
use_annotation = get_args(annotation)[0]

elif origin is Annotated:
use_annotation = get_args(annotation)[0]

return get_relationship_to(name=name, rel_info=rel_info, annotation=use_annotation)


Expand Down
40 changes: 35 additions & 5 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Literal,
TypeAlias,
TypeVar,
Union,
cast,
get_args,
get_origin,
overload,
)
Expand Down Expand Up @@ -517,6 +519,16 @@ def Relationship(
return relationship_info


def get_annotated_relationshipinfo(t: Any) -> RelationshipInfo | None:
"""Get the first RelationshipInfo from Annotated or None if not Annotated with RelationshipInfo."""
if get_origin(t) is not Annotated:
return None
for a in get_args(t):
if isinstance(a, RelationshipInfo):
return a
return None


@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: dict[str, RelationshipInfo]
Expand Down Expand Up @@ -549,16 +561,29 @@ def __new__(
original_annotations = get_annotations(class_dict)
pydantic_annotations = {}
relationship_annotations = {}
for k, v in class_dict.items():

# find relationship info in both annotations and class dict
for k in {**original_annotations, **class_dict}:
v = class_dict.get(k)
if isinstance(v, RelationshipInfo):
relationships[k] = v
else:
continue
r = get_annotated_relationshipinfo(original_annotations.get(k))
if r is not None:
relationships[k] = r

# populate dict passed to pydantic
for k, v in class_dict.items():
if k not in relationships:
dict_for_pydantic[k] = v
for k, v in original_annotations.items():

# split out pydantic annotations
for k, a in original_annotations.items():
if k in relationships:
relationship_annotations[k] = v
relationship_annotations[k] = a
else:
pydantic_annotations[k] = v
pydantic_annotations[k] = a

dict_used = {
**dict_for_pydantic,
"__weakref__": None,
Expand Down Expand Up @@ -643,6 +668,11 @@ def __init__(
origin: Any = get_origin(raw_ann)
if origin is Mapped:
ann = raw_ann.__args__[0]
elif origin is Annotated:
ann = get_args(raw_ann)[0]
if get_origin(ann) is Mapped:
ann = ann.__args__[0]
cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type]
else:
ann = raw_ann
# Plain forward references, for models not yet defined, are not
Expand Down
85 changes: 85 additions & 0 deletions tests/test_annotated_relationship.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Annotated

from sqlalchemy.orm import Mapped
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select


def test_annotated_relationship_with_default() -> None:
class Team(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]

heroes: Annotated[list["Hero"], Relationship(back_populates="team")] = [] # noqa: RUF012

class Hero(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]
team_id: Annotated[int | None, Field(foreign_key="team.id")] = None
team: Annotated[Team | None, Relationship(back_populates="heroes")] = None

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
team = Team(name="Preventers")
hero = Hero(name="Deadpond", team=team)
session.add(hero)
session.commit()
session.refresh(hero)
assert hero.team is not None
assert hero.team.name == "Preventers"
team_db = session.exec(select(Team)).one()
assert [h.name for h in team_db.heroes] == ["Deadpond"]


def test_annotated_relationship_without_default() -> None:
class Team(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]

heroes: Annotated[list["Hero"], Relationship(back_populates="team")]

class Hero(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]
team_id: Annotated[int | None, Field(foreign_key="team.id")] = None
team: Annotated[Team | None, Relationship(back_populates="heroes")]

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
team = Team(name="Z-Force")
hero = Hero(name="Spider-Boy", team=team)
session.add(hero)
session.commit()
session.refresh(hero)
assert hero.team is not None
assert hero.team.name == "Z-Force"


def test_annotated_mapped_relationship() -> None:
class Team(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]

heroes: Annotated[
Mapped[list["Hero"]], Relationship(back_populates="team")
] = [] # noqa: RUF012

class Hero(SQLModel, table=True):
id: Annotated[int | None, Field(primary_key=True)] = None
name: Annotated[str, Field(index=True)]
team_id: Annotated[int | None, Field(foreign_key="team.id")] = None
team: Annotated[Mapped[Team | None], Relationship(back_populates="heroes")] = (
None
)

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
team = Team(name="Avengers")
hero = Hero(name="Iron Man", team=team)
session.add(hero)
session.commit()
session.refresh(hero)
assert hero.team is not None
assert hero.team.name == "Avengers"
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
name="mod",
params=[
pytest.param("tutorial001_py310"),
pytest.param("tutorial001_an_py310"),
],
)
def get_module(request: pytest.FixtureRequest) -> ModuleType:
Expand Down
Loading