diff --git a/creyPY/fastapi/crud.py b/creyPY/fastapi/crud.py index 609dd1e..d8eb0df 100644 --- a/creyPY/fastapi/crud.py +++ b/creyPY/fastapi/crud.py @@ -11,28 +11,31 @@ from .models.base import Base T = TypeVar("T", bound=Base) + @overload async def get_object_or_404( - db_class: Type[T], - id: UUID | str, - db: AsyncSession, - expunge: bool = False, - lookup_column: str = "id" + db_class: Type[T], + id: UUID | str, + db: AsyncSession, + expunge: bool = False, + lookup_column: str = "id", ) -> T: pass + @overload def get_object_or_404( - db_class: Type[T], - id: UUID | str, - db: Session, - expunge: bool = False, - lookup_column: str = "id" + db_class: Type[T], id: UUID | str, db: Session, expunge: bool = False, lookup_column: str = "id" ) -> T: pass + def get_object_or_404( - db_class: Type[T], id: UUID | str, db: Session | AsyncSession, expunge: bool = False, lookup_column: str = "id" + db_class: Type[T], + id: UUID | str, + db: Session | AsyncSession, + expunge: bool = False, + lookup_column: str = "id", ) -> T: async def _get_async_object() -> T: @@ -40,7 +43,7 @@ def get_object_or_404( result = await db.execute(query) obj = result.scalar_one_or_none() if obj is None: - raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore + raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore if expunge: await db.expunge(obj) return obj @@ -48,18 +51,17 @@ def get_object_or_404( def _get_sync_object() -> T: obj = db.query(db_class).filter(getattr(db_class, lookup_column) == id).one_or_none() if obj is None: - raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore + raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore if expunge: db.expunge(obj) return obj if isinstance(db, AsyncSession): - return asyncio.ensure_future(_get_async_object()) # type: ignore + return asyncio.ensure_future(_get_async_object()) # type: ignore elif isinstance(db, Session): return _get_sync_object() else: - raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore - + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore # TODO: Add testing @@ -69,26 +71,20 @@ async def create_obj_from_data( model: Type[T], db: AsyncSession, additional_data: dict = {}, - exclude: dict = {} + exclude: dict = {}, ) -> T: pass + @overload def create_obj_from_data( - data: BaseModel, - model: Type[T], - db: Session, - additional_data: dict = {}, - exclude: dict = {} + data: BaseModel, model: Type[T], db: Session, additional_data: dict = {}, exclude: dict = {} ) -> T: pass + def create_obj_from_data( - data: BaseModel, - model: Type[T], - db: Session | AsyncSession, - additional_data={}, - exclude={} + data: BaseModel, model: Type[T], db: Session | AsyncSession, additional_data={}, exclude={} ) -> T: obj_data = data.model_dump(exclude=exclude) | additional_data obj = model(**obj_data) @@ -106,11 +102,11 @@ def create_obj_from_data( return obj if isinstance(db, AsyncSession): - return asyncio.ensure_future(_create_async_obj()) # type: ignore + return asyncio.ensure_future(_create_async_obj()) # type: ignore elif isinstance(db, Session): - return _create_sync_obj() + return _create_sync_obj() else: - raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore # TODO: Add testing @@ -123,10 +119,11 @@ async def update_obj_from_data( partial: bool = True, ignore_fields: list = [], additional_data: dict = {}, - exclude: dict = {} + exclude: dict = {}, ) -> T: pass + @overload def update_obj_from_data( data: BaseModel, @@ -136,10 +133,11 @@ def update_obj_from_data( partial: bool = True, ignore_fields: list = [], additional_data: dict = {}, - exclude: dict = {} + exclude: dict = {}, ) -> T: pass + def update_obj_from_data( data: BaseModel, model: Type[T], @@ -153,11 +151,11 @@ def update_obj_from_data( def _update_fields(obj: T): data_dict = data.model_dump(exclude_unset=partial, exclude=exclude) data_dict.update(additional_data) - + for field in data_dict: if field not in ignore_fields: setattr(obj, field, data_dict[field]) - + async def _update_async_obj() -> T: obj = await get_object_or_404(model, id, db) _update_fields(obj) @@ -173,29 +171,25 @@ def update_obj_from_data( return obj if isinstance(db, AsyncSession): - return asyncio.ensure_future(_update_async_obj()) # type: ignore + return asyncio.ensure_future(_update_async_obj()) # type: ignore elif isinstance(db, Session): return _update_sync_obj() else: - raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + # TODO: Add testing @overload -async def delete_object( - db_class: Type[T], id: UUID | str, db: AsyncSession -) -> None: +async def delete_object(db_class: Type[T], id: UUID | str, db: AsyncSession) -> None: pass + @overload -def delete_object( - db_class: Type[T], id: UUID | str, db: Session -) -> None: +def delete_object(db_class: Type[T], id: UUID | str, db: Session) -> None: pass -def delete_object( - db_class: Type[T], id: UUID | str, db: Session | AsyncSession -) -> None: +def delete_object(db_class: Type[T], id: UUID | str, db: Session | AsyncSession) -> None: async def _delete_async_obj() -> None: query = select(db_class).filter(db_class.id == id) result = await db.execute(query) @@ -213,8 +207,8 @@ def delete_object( db.commit() if isinstance(db, AsyncSession): - return asyncio.ensure_future(_delete_async_obj()) # type: ignore + return asyncio.ensure_future(_delete_async_obj()) # type: ignore elif isinstance(db, Session): return _delete_sync_obj() else: - raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore