From 9bba5b0a4ead6bf13109f8f77c16731f109c1d2b Mon Sep 17 00:00:00 2001 From: vikynoah Date: Fri, 15 Nov 2024 12:39:30 +0100 Subject: [PATCH] fix: N 271 async db (#17) --- creyPY/fastapi/crud.py | 217 +++++++++++++++++++++++++++++++++++------ 1 file changed, 187 insertions(+), 30 deletions(-) diff --git a/creyPY/fastapi/crud.py b/creyPY/fastapi/crud.py index 2a68978..609dd1e 100644 --- a/creyPY/fastapi/crud.py +++ b/creyPY/fastapi/crud.py @@ -1,63 +1,220 @@ -from typing import Type, TypeVar +from typing import Type, TypeVar, overload from uuid import UUID from fastapi import HTTPException from pydantic import BaseModel from sqlalchemy.orm import Session - +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +import asyncio from .models.base import Base T = TypeVar("T", bound=Base) +@overload +async def get_object_or_404( + db_class: Type[T], + id: UUID | str, + db: AsyncSession, + expunge: bool = False, + lookup_column: str = "id" +) -> T: + pass + +@overload +def get_object_or_404( + db_class: Type[T], + id: UUID | str, + db: Session, + expunge: bool = False, + lookup_column: str = "id" +) -> T: + pass def get_object_or_404( - db_class: Type[T], id: UUID | str, db: Session, expunge: bool = False, lookup_column: str = "id" + db_class: Type[T], id: UUID | str, db: Session | AsyncSession, expunge: bool = False, lookup_column: str = "id" ) -> T: - obj = db.query(db_class).filter(getattr(db_class, lookup_column) == id).one_or_none() - if obj is None: - raise HTTPException(status_code=404, detail="The object does not exist.") - if expunge: - db.expunge(obj) - return obj + + async def _get_async_object() -> T: + query = select(db_class).filter(getattr(db_class, lookup_column) == id) + result = await db.execute(query) + obj = result.scalar_one_or_none() + if obj is None: + raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore + if expunge: + await db.expunge(obj) + return obj + + def _get_sync_object() -> T: + obj = db.query(db_class).filter(getattr(db_class, lookup_column) == id).one_or_none() + if obj is None: + raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore + if expunge: + db.expunge(obj) + return obj + + if isinstance(db, AsyncSession): + return asyncio.ensure_future(_get_async_object()) # type: ignore + elif isinstance(db, Session): + return _get_sync_object() + else: + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + # TODO: Add testing +@overload +async def create_obj_from_data( + data: BaseModel, + model: Type[T], + db: AsyncSession, + additional_data: dict = {}, + exclude: dict = {} +) -> T: + pass + +@overload def create_obj_from_data( - data: BaseModel, model: Type[T], db: Session, additional_data={}, exclude={} + data: BaseModel, + model: Type[T], + db: Session, + additional_data: dict = {}, + exclude: dict = {} ) -> T: - obj = model(**data.model_dump(exclude=exclude) | additional_data) - db.add(obj) - db.commit() - db.refresh(obj) - return obj + pass + +def create_obj_from_data( + data: BaseModel, + model: Type[T], + db: Session | AsyncSession, + additional_data={}, + exclude={} +) -> T: + obj_data = data.model_dump(exclude=exclude) | additional_data + obj = model(**obj_data) + + async def _create_async_obj(): + db.add(obj) + await db.commit() + await db.refresh(obj) + return obj + + def _create_sync_obj(): + db.add(obj) + db.commit() + db.refresh(obj) + return obj + + if isinstance(db, AsyncSession): + return asyncio.ensure_future(_create_async_obj()) # type: ignore + elif isinstance(db, Session): + return _create_sync_obj() + else: + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore # TODO: Add testing +@overload +async def update_obj_from_data( + data: BaseModel, + model: Type[T], + id: UUID | str, + db: AsyncSession, + partial: bool = True, + ignore_fields: list = [], + additional_data: dict = {}, + exclude: dict = {} +) -> T: + pass + +@overload def update_obj_from_data( data: BaseModel, model: Type[T], id: UUID | str, db: Session, partial: bool = True, + ignore_fields: list = [], + additional_data: dict = {}, + exclude: dict = {} +) -> T: + pass + +def update_obj_from_data( + data: BaseModel, + model: Type[T], + id: UUID | str, + db: Session | AsyncSession, + partial: bool = True, ignore_fields=[], additional_data={}, exclude={}, ) -> T: - obj = get_object_or_404(model, id, db) - data_dict = data.model_dump(exclude_unset=partial, exclude=exclude) - data_dict.update(additional_data) # merge additional_data into data_dict - for field in data_dict: - if field not in ignore_fields: - setattr(obj, field, data_dict[field]) - db.commit() - db.refresh(obj) - return obj + def _update_fields(obj: T): + data_dict = data.model_dump(exclude_unset=partial, exclude=exclude) + data_dict.update(additional_data) + + for field in data_dict: + if field not in ignore_fields: + setattr(obj, field, data_dict[field]) + + async def _update_async_obj() -> T: + obj = await get_object_or_404(model, id, db) + _update_fields(obj) + await db.commit() + await db.refresh(obj) + return obj + def _update_sync_obj() -> T: + obj = get_object_or_404(model, id, db) + _update_fields(obj) + db.commit() + db.refresh(obj) + return obj + + if isinstance(db, AsyncSession): + return asyncio.ensure_future(_update_async_obj()) # type: ignore + elif isinstance(db, Session): + return _update_sync_obj() + else: + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore # TODO: Add testing -def delete_object(db_class: Type[T], id: UUID | str, db: Session) -> None: - obj = db.query(db_class).filter(db_class.id == id).one_or_none() - if obj is None: - raise HTTPException(status_code=404, detail="The object does not exist.") - db.delete(obj) - db.commit() +@overload +async def delete_object( + db_class: Type[T], id: UUID | str, db: AsyncSession +) -> None: + pass + +@overload +def delete_object( + db_class: Type[T], id: UUID | str, db: Session +) -> None: + pass + + +def delete_object( + db_class: Type[T], id: UUID | str, db: Session | AsyncSession +) -> None: + async def _delete_async_obj() -> None: + query = select(db_class).filter(db_class.id == id) + result = await db.execute(query) + obj = result.scalar_one_or_none() + if obj is None: + raise HTTPException(status_code=404, detail="The object does not exist.") + await db.delete(obj) + await db.commit() + + def _delete_sync_obj() -> None: + obj = db.query(db_class).filter(db_class.id == id).one_or_none() + if obj is None: + raise HTTPException(status_code=404, detail="The object does not exist.") + db.delete(obj) + db.commit() + + if isinstance(db, AsyncSession): + return asyncio.ensure_future(_delete_async_obj()) # type: ignore + elif isinstance(db, Session): + return _delete_sync_obj() + else: + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore