From 2940ddbdcd7f8a3712f70e15646bfd10d7fdfee9 Mon Sep 17 00:00:00 2001 From: vikynoah Date: Tue, 12 Nov 2024 09:54:04 +0100 Subject: [PATCH 1/6] feat: Introduce ASYNC DB as Plug and Play (#16) Co-authored-by: vikbhas --- creyPY/fastapi/db/__init__.py | 1 + creyPY/fastapi/db/async_session.py | 25 ++++++++++ creyPY/fastapi/pagination.py | 77 ++++++++++++++++++++++++++++-- requirements.txt | 3 ++ 4 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 creyPY/fastapi/db/async_session.py diff --git a/creyPY/fastapi/db/__init__.py b/creyPY/fastapi/db/__init__.py index d13ef11..395efc9 100644 --- a/creyPY/fastapi/db/__init__.py +++ b/creyPY/fastapi/db/__init__.py @@ -1 +1,2 @@ from .session import * # noqa +from .async_session import * # noqa diff --git a/creyPY/fastapi/db/async_session.py b/creyPY/fastapi/db/async_session.py new file mode 100644 index 0000000..6856914 --- /dev/null +++ b/creyPY/fastapi/db/async_session.py @@ -0,0 +1,25 @@ +import os +from typing import AsyncGenerator +from dotenv import load_dotenv +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker + + +load_dotenv() + +host = os.getenv("POSTGRES_HOST", "localhost") +user = os.getenv("POSTGRES_USER", "postgres") +password = os.getenv("POSTGRES_PASSWORD", "root") +port = os.getenv("POSTGRES_PORT", "5432") +name = os.getenv("POSTGRES_DB", "fastapi") + +SQLALCHEMY_DATABASE_URL = f"postgresql+psycopg://{user}:{password}@{host}:{port}/" + + +async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL + name, pool_pre_ping=True) +AsyncSessionLocal = sessionmaker(bind=async_engine, class_=AsyncSession, + expire_on_commit=False, autoflush=False, autocommit=False) + +async def get_async_db() -> AsyncGenerator[AsyncSession, None]: + async with AsyncSessionLocal() as db: + yield db diff --git a/creyPY/fastapi/pagination.py b/creyPY/fastapi/pagination.py index f96f698..d747c2d 100644 --- a/creyPY/fastapi/pagination.py +++ b/creyPY/fastapi/pagination.py @@ -1,5 +1,6 @@ from math import ceil -from typing import Any, Generic, Optional, Self, Sequence, TypeVar, Union +from typing import Any, Generic, Optional, Self, Sequence, TypeVar, Union, overload +from contextlib import suppress from pydantic import BaseModel from fastapi_pagination import Params from fastapi_pagination.bases import AbstractPage, AbstractParams @@ -8,6 +9,8 @@ from fastapi_pagination.types import ( GreaterEqualZero, AdditionalData, SyncItemsTransformer, + AsyncItemsTransformer, + ItemsTransformer ) from fastapi_pagination.api import create_page, apply_items_transformer from fastapi_pagination.utils import verify_params @@ -17,7 +20,9 @@ from pydantic.json_schema import SkipJsonSchema from sqlalchemy.sql.selectable import Select from sqlalchemy.orm.session import Session from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session from fastapi import Query +from sqlalchemy.util import await_only, greenlet_spawn T = TypeVar("T") @@ -107,19 +112,59 @@ def unwrap_scalars( ) -> Union[Sequence[T], Sequence[Sequence[T]]]: return [item[0] if force_unwrap else item for item in items] +def _get_sync_conn_from_async(conn: Any) -> Session: # pragma: no cover + if isinstance(conn, async_scoped_session): + conn = conn() + with suppress(AttributeError): + return conn.sync_session # type: ignore + + with suppress(AttributeError): + return conn.sync_connection # type: ignore + + raise TypeError("conn must be an AsyncConnection or AsyncSession") + +@overload def paginate( connection: Session, query: Select, params: Optional[AbstractParams] = None, transformer: Optional[SyncItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, +) -> Any: + pass + + +@overload +async def paginate( + connection: AsyncSession, + query: Select, + params: Optional[AbstractParams] = None, + transformer: Optional[AsyncItemsTransformer] = None, + additional_data: Optional[AdditionalData] = None, +) -> Any: + pass + + +def _paginate( + connection: Session, + query: Select, + params: Optional[AbstractParams] = None, + transformer: Optional[ItemsTransformer] = None, + additional_data: Optional[AdditionalData] = None, + async_:bool = False ): + + if async_: + def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any: + return await_only(apply_items_transformer(*args, **kwargs, async_=True)) + else: + _apply_items_transformer = apply_items_transformer params, raw_params = verify_params(params, "limit-offset", "cursor") count_query = create_count_query(query) total = connection.scalar(count_query) - + if params.pagination is False and total > 0: params = Params(page=1, size=total) else: @@ -129,7 +174,7 @@ def paginate( items = connection.execute(query).all() items = unwrap_scalars(items) - t_items = apply_items_transformer(items, transformer) + t_items = _apply_items_transformer(items, transformer) return create_page( t_items, @@ -137,3 +182,29 @@ def paginate( total=total, **(additional_data or {}), ) + +def paginate( + connection: Session, + query: Select, + params: Optional[AbstractParams] = None, + transformer: Optional[ItemsTransformer] = None, + additional_data: Optional[AdditionalData] = None, +): + if isinstance(connection,AsyncSession): + connection = _get_sync_conn_from_async(connection) + return greenlet_spawn(_paginate, + connection, + query, + params, + transformer, + additional_data, + async_=True) + + return _paginate( + connection, + query, + params, + transformer, + additional_data, + async_=False + ) diff --git a/requirements.txt b/requirements.txt index f8f5ec2..a03bdfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,6 @@ psycopg-pool>=3.2.2 # PostgreSQL h11>=0.14.0 # Testing httpcore>=1.0.5 # Testing httpx>=0.27.0 # Testing + +asyncpg>=0.30.0 #SQLAlchemy +greenlet>=3.1.1 #Async From 50031556f994633640e4777af735d0af05c170b7 Mon Sep 17 00:00:00 2001 From: creyD Date: Tue, 12 Nov 2024 08:54:34 +0000 Subject: [PATCH 2/6] Adjusted files for isort & autopep --- creyPY/fastapi/db/async_session.py | 10 +++++++-- creyPY/fastapi/pagination.py | 34 ++++++++++++------------------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/creyPY/fastapi/db/async_session.py b/creyPY/fastapi/db/async_session.py index 6856914..07a0e61 100644 --- a/creyPY/fastapi/db/async_session.py +++ b/creyPY/fastapi/db/async_session.py @@ -17,8 +17,14 @@ SQLALCHEMY_DATABASE_URL = f"postgresql+psycopg://{user}:{password}@{host}:{port} async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL + name, pool_pre_ping=True) -AsyncSessionLocal = sessionmaker(bind=async_engine, class_=AsyncSession, - expire_on_commit=False, autoflush=False, autocommit=False) +AsyncSessionLocal = sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, +) + async def get_async_db() -> AsyncGenerator[AsyncSession, None]: async with AsyncSessionLocal() as db: diff --git a/creyPY/fastapi/pagination.py b/creyPY/fastapi/pagination.py index d747c2d..0aca060 100644 --- a/creyPY/fastapi/pagination.py +++ b/creyPY/fastapi/pagination.py @@ -10,7 +10,7 @@ from fastapi_pagination.types import ( AdditionalData, SyncItemsTransformer, AsyncItemsTransformer, - ItemsTransformer + ItemsTransformer, ) from fastapi_pagination.api import create_page, apply_items_transformer from fastapi_pagination.utils import verify_params @@ -112,6 +112,7 @@ def unwrap_scalars( ) -> Union[Sequence[T], Sequence[Sequence[T]]]: return [item[0] if force_unwrap else item for item in items] + def _get_sync_conn_from_async(conn: Any) -> Session: # pragma: no cover if isinstance(conn, async_scoped_session): conn = conn() @@ -124,6 +125,7 @@ def _get_sync_conn_from_async(conn: Any) -> Session: # pragma: no cover raise TypeError("conn must be an AsyncConnection or AsyncSession") + @overload def paginate( connection: Session, @@ -152,19 +154,21 @@ def _paginate( params: Optional[AbstractParams] = None, transformer: Optional[ItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, - async_:bool = False + async_: bool = False, ): - + if async_: + def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any: return await_only(apply_items_transformer(*args, **kwargs, async_=True)) + else: _apply_items_transformer = apply_items_transformer params, raw_params = verify_params(params, "limit-offset", "cursor") count_query = create_count_query(query) total = connection.scalar(count_query) - + if params.pagination is False and total > 0: params = Params(page=1, size=total) else: @@ -183,6 +187,7 @@ def _paginate( **(additional_data or {}), ) + def paginate( connection: Session, query: Select, @@ -190,21 +195,10 @@ def paginate( transformer: Optional[ItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, ): - if isinstance(connection,AsyncSession): + if isinstance(connection, AsyncSession): connection = _get_sync_conn_from_async(connection) - return greenlet_spawn(_paginate, - connection, - query, - params, - transformer, - additional_data, - async_=True) - - return _paginate( - connection, - query, - params, - transformer, - additional_data, - async_=False + return greenlet_spawn( + _paginate, connection, query, params, transformer, additional_data, async_=True ) + + return _paginate(connection, query, params, transformer, additional_data, async_=False) From 9bba5b0a4ead6bf13109f8f77c16731f109c1d2b Mon Sep 17 00:00:00 2001 From: vikynoah Date: Fri, 15 Nov 2024 12:39:30 +0100 Subject: [PATCH 3/6] fix: N 271 async db (#17) --- creyPY/fastapi/crud.py | 217 +++++++++++++++++++++++++++++++++++------ 1 file changed, 187 insertions(+), 30 deletions(-) diff --git a/creyPY/fastapi/crud.py b/creyPY/fastapi/crud.py index 2a68978..609dd1e 100644 --- a/creyPY/fastapi/crud.py +++ b/creyPY/fastapi/crud.py @@ -1,63 +1,220 @@ -from typing import Type, TypeVar +from typing import Type, TypeVar, overload from uuid import UUID from fastapi import HTTPException from pydantic import BaseModel from sqlalchemy.orm import Session - +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +import asyncio from .models.base import Base T = TypeVar("T", bound=Base) +@overload +async def get_object_or_404( + db_class: Type[T], + id: UUID | str, + db: AsyncSession, + expunge: bool = False, + lookup_column: str = "id" +) -> T: + pass + +@overload +def get_object_or_404( + db_class: Type[T], + id: UUID | str, + db: Session, + expunge: bool = False, + lookup_column: str = "id" +) -> T: + pass def get_object_or_404( - db_class: Type[T], id: UUID | str, db: Session, expunge: bool = False, lookup_column: str = "id" + db_class: Type[T], id: UUID | str, db: Session | AsyncSession, expunge: bool = False, lookup_column: str = "id" ) -> T: - obj = db.query(db_class).filter(getattr(db_class, lookup_column) == id).one_or_none() - if obj is None: - raise HTTPException(status_code=404, detail="The object does not exist.") - if expunge: - db.expunge(obj) - return obj + + async def _get_async_object() -> T: + query = select(db_class).filter(getattr(db_class, lookup_column) == id) + result = await db.execute(query) + obj = result.scalar_one_or_none() + if obj is None: + raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore + if expunge: + await db.expunge(obj) + return obj + + def _get_sync_object() -> T: + obj = db.query(db_class).filter(getattr(db_class, lookup_column) == id).one_or_none() + if obj is None: + raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore + if expunge: + db.expunge(obj) + return obj + + if isinstance(db, AsyncSession): + return asyncio.ensure_future(_get_async_object()) # type: ignore + elif isinstance(db, Session): + return _get_sync_object() + else: + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + # TODO: Add testing +@overload +async def create_obj_from_data( + data: BaseModel, + model: Type[T], + db: AsyncSession, + additional_data: dict = {}, + exclude: dict = {} +) -> T: + pass + +@overload def create_obj_from_data( - data: BaseModel, model: Type[T], db: Session, additional_data={}, exclude={} + data: BaseModel, + model: Type[T], + db: Session, + additional_data: dict = {}, + exclude: dict = {} ) -> T: - obj = model(**data.model_dump(exclude=exclude) | additional_data) - db.add(obj) - db.commit() - db.refresh(obj) - return obj + pass + +def create_obj_from_data( + data: BaseModel, + model: Type[T], + db: Session | AsyncSession, + additional_data={}, + exclude={} +) -> T: + obj_data = data.model_dump(exclude=exclude) | additional_data + obj = model(**obj_data) + + async def _create_async_obj(): + db.add(obj) + await db.commit() + await db.refresh(obj) + return obj + + def _create_sync_obj(): + db.add(obj) + db.commit() + db.refresh(obj) + return obj + + if isinstance(db, AsyncSession): + return asyncio.ensure_future(_create_async_obj()) # type: ignore + elif isinstance(db, Session): + return _create_sync_obj() + else: + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore # TODO: Add testing +@overload +async def update_obj_from_data( + data: BaseModel, + model: Type[T], + id: UUID | str, + db: AsyncSession, + partial: bool = True, + ignore_fields: list = [], + additional_data: dict = {}, + exclude: dict = {} +) -> T: + pass + +@overload def update_obj_from_data( data: BaseModel, model: Type[T], id: UUID | str, db: Session, partial: bool = True, + ignore_fields: list = [], + additional_data: dict = {}, + exclude: dict = {} +) -> T: + pass + +def update_obj_from_data( + data: BaseModel, + model: Type[T], + id: UUID | str, + db: Session | AsyncSession, + partial: bool = True, ignore_fields=[], additional_data={}, exclude={}, ) -> T: - obj = get_object_or_404(model, id, db) - data_dict = data.model_dump(exclude_unset=partial, exclude=exclude) - data_dict.update(additional_data) # merge additional_data into data_dict - for field in data_dict: - if field not in ignore_fields: - setattr(obj, field, data_dict[field]) - db.commit() - db.refresh(obj) - return obj + def _update_fields(obj: T): + data_dict = data.model_dump(exclude_unset=partial, exclude=exclude) + data_dict.update(additional_data) + + for field in data_dict: + if field not in ignore_fields: + setattr(obj, field, data_dict[field]) + + async def _update_async_obj() -> T: + obj = await get_object_or_404(model, id, db) + _update_fields(obj) + await db.commit() + await db.refresh(obj) + return obj + def _update_sync_obj() -> T: + obj = get_object_or_404(model, id, db) + _update_fields(obj) + db.commit() + db.refresh(obj) + return obj + + if isinstance(db, AsyncSession): + return asyncio.ensure_future(_update_async_obj()) # type: ignore + elif isinstance(db, Session): + return _update_sync_obj() + else: + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore # TODO: Add testing -def delete_object(db_class: Type[T], id: UUID | str, db: Session) -> None: - obj = db.query(db_class).filter(db_class.id == id).one_or_none() - if obj is None: - raise HTTPException(status_code=404, detail="The object does not exist.") - db.delete(obj) - db.commit() +@overload +async def delete_object( + db_class: Type[T], id: UUID | str, db: AsyncSession +) -> None: + pass + +@overload +def delete_object( + db_class: Type[T], id: UUID | str, db: Session +) -> None: + pass + + +def delete_object( + db_class: Type[T], id: UUID | str, db: Session | AsyncSession +) -> None: + async def _delete_async_obj() -> None: + query = select(db_class).filter(db_class.id == id) + result = await db.execute(query) + obj = result.scalar_one_or_none() + if obj is None: + raise HTTPException(status_code=404, detail="The object does not exist.") + await db.delete(obj) + await db.commit() + + def _delete_sync_obj() -> None: + obj = db.query(db_class).filter(db_class.id == id).one_or_none() + if obj is None: + raise HTTPException(status_code=404, detail="The object does not exist.") + db.delete(obj) + db.commit() + + if isinstance(db, AsyncSession): + return asyncio.ensure_future(_delete_async_obj()) # type: ignore + elif isinstance(db, Session): + return _delete_sync_obj() + else: + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore From 6f09c2ef4cfbe0853258aefd00f3e8d199f331e3 Mon Sep 17 00:00:00 2001 From: creyD Date: Fri, 15 Nov 2024 11:39:59 +0000 Subject: [PATCH 4/6] Adjusted files for isort & autopep --- creyPY/fastapi/crud.py | 88 ++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 47 deletions(-) diff --git a/creyPY/fastapi/crud.py b/creyPY/fastapi/crud.py index 609dd1e..d8eb0df 100644 --- a/creyPY/fastapi/crud.py +++ b/creyPY/fastapi/crud.py @@ -11,28 +11,31 @@ from .models.base import Base T = TypeVar("T", bound=Base) + @overload async def get_object_or_404( - db_class: Type[T], - id: UUID | str, - db: AsyncSession, - expunge: bool = False, - lookup_column: str = "id" + db_class: Type[T], + id: UUID | str, + db: AsyncSession, + expunge: bool = False, + lookup_column: str = "id", ) -> T: pass + @overload def get_object_or_404( - db_class: Type[T], - id: UUID | str, - db: Session, - expunge: bool = False, - lookup_column: str = "id" + db_class: Type[T], id: UUID | str, db: Session, expunge: bool = False, lookup_column: str = "id" ) -> T: pass + def get_object_or_404( - db_class: Type[T], id: UUID | str, db: Session | AsyncSession, expunge: bool = False, lookup_column: str = "id" + db_class: Type[T], + id: UUID | str, + db: Session | AsyncSession, + expunge: bool = False, + lookup_column: str = "id", ) -> T: async def _get_async_object() -> T: @@ -40,7 +43,7 @@ def get_object_or_404( result = await db.execute(query) obj = result.scalar_one_or_none() if obj is None: - raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore + raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore if expunge: await db.expunge(obj) return obj @@ -48,18 +51,17 @@ def get_object_or_404( def _get_sync_object() -> T: obj = db.query(db_class).filter(getattr(db_class, lookup_column) == id).one_or_none() if obj is None: - raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore + raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore if expunge: db.expunge(obj) return obj if isinstance(db, AsyncSession): - return asyncio.ensure_future(_get_async_object()) # type: ignore + return asyncio.ensure_future(_get_async_object()) # type: ignore elif isinstance(db, Session): return _get_sync_object() else: - raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore - + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore # TODO: Add testing @@ -69,26 +71,20 @@ async def create_obj_from_data( model: Type[T], db: AsyncSession, additional_data: dict = {}, - exclude: dict = {} + exclude: dict = {}, ) -> T: pass + @overload def create_obj_from_data( - data: BaseModel, - model: Type[T], - db: Session, - additional_data: dict = {}, - exclude: dict = {} + data: BaseModel, model: Type[T], db: Session, additional_data: dict = {}, exclude: dict = {} ) -> T: pass + def create_obj_from_data( - data: BaseModel, - model: Type[T], - db: Session | AsyncSession, - additional_data={}, - exclude={} + data: BaseModel, model: Type[T], db: Session | AsyncSession, additional_data={}, exclude={} ) -> T: obj_data = data.model_dump(exclude=exclude) | additional_data obj = model(**obj_data) @@ -106,11 +102,11 @@ def create_obj_from_data( return obj if isinstance(db, AsyncSession): - return asyncio.ensure_future(_create_async_obj()) # type: ignore + return asyncio.ensure_future(_create_async_obj()) # type: ignore elif isinstance(db, Session): - return _create_sync_obj() + return _create_sync_obj() else: - raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore # TODO: Add testing @@ -123,10 +119,11 @@ async def update_obj_from_data( partial: bool = True, ignore_fields: list = [], additional_data: dict = {}, - exclude: dict = {} + exclude: dict = {}, ) -> T: pass + @overload def update_obj_from_data( data: BaseModel, @@ -136,10 +133,11 @@ def update_obj_from_data( partial: bool = True, ignore_fields: list = [], additional_data: dict = {}, - exclude: dict = {} + exclude: dict = {}, ) -> T: pass + def update_obj_from_data( data: BaseModel, model: Type[T], @@ -153,11 +151,11 @@ def update_obj_from_data( def _update_fields(obj: T): data_dict = data.model_dump(exclude_unset=partial, exclude=exclude) data_dict.update(additional_data) - + for field in data_dict: if field not in ignore_fields: setattr(obj, field, data_dict[field]) - + async def _update_async_obj() -> T: obj = await get_object_or_404(model, id, db) _update_fields(obj) @@ -173,29 +171,25 @@ def update_obj_from_data( return obj if isinstance(db, AsyncSession): - return asyncio.ensure_future(_update_async_obj()) # type: ignore + return asyncio.ensure_future(_update_async_obj()) # type: ignore elif isinstance(db, Session): return _update_sync_obj() else: - raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + # TODO: Add testing @overload -async def delete_object( - db_class: Type[T], id: UUID | str, db: AsyncSession -) -> None: +async def delete_object(db_class: Type[T], id: UUID | str, db: AsyncSession) -> None: pass + @overload -def delete_object( - db_class: Type[T], id: UUID | str, db: Session -) -> None: +def delete_object(db_class: Type[T], id: UUID | str, db: Session) -> None: pass -def delete_object( - db_class: Type[T], id: UUID | str, db: Session | AsyncSession -) -> None: +def delete_object(db_class: Type[T], id: UUID | str, db: Session | AsyncSession) -> None: async def _delete_async_obj() -> None: query = select(db_class).filter(db_class.id == id) result = await db.execute(query) @@ -213,8 +207,8 @@ def delete_object( db.commit() if isinstance(db, AsyncSession): - return asyncio.ensure_future(_delete_async_obj()) # type: ignore + return asyncio.ensure_future(_delete_async_obj()) # type: ignore elif isinstance(db, Session): return _delete_sync_obj() else: - raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore + raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore From 523241ac4bdaf00f3228822ffe4d7831116b266b Mon Sep 17 00:00:00 2001 From: vikynoah Date: Fri, 22 Nov 2024 12:56:45 +0100 Subject: [PATCH 5/6] feat: N-271 async db (#18) --- creyPY/fastapi/testing_async.py | 137 ++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 creyPY/fastapi/testing_async.py diff --git a/creyPY/fastapi/testing_async.py b/creyPY/fastapi/testing_async.py new file mode 100644 index 0000000..e9da993 --- /dev/null +++ b/creyPY/fastapi/testing_async.py @@ -0,0 +1,137 @@ +import json +from httpx import AsyncClient + + +class AsyncGenericClient: + def __init__(self, app): + self.c = AsyncClient(app=app, base_url="http://testserver", follow_redirects=True) + self.default_headers = {} + + async def get(self, url: str, r_code: int = 200, parse_json=True): + re = await self.c.get(url, headers=self.default_headers) + if re.status_code != r_code: + print(re.content) + assert r_code == re.status_code + return re.json() if parse_json else re.content + + async def delete(self, url: str, r_code: int = 204): + re = await self.c.delete(url, headers=self.default_headers) + if re.status_code != r_code: + print(re.content) + assert r_code == re.status_code + return re.json() if r_code != 204 else None + + async def post( + self, url: str, obj: dict | str = {}, r_code: int = 201, raw_response=False, *args, **kwargs + ): + re = await self.c.post( + url, + data=json.dumps(obj) if isinstance(obj, dict) else obj, + headers=self.default_headers | {"Content-Type": "application/json"}, + *args, + **kwargs, + ) + if re.status_code != r_code: + print(re.content) + assert r_code == re.status_code + return re.json() if not raw_response else re + + async def post_file(self, url: str, file, r_code: int = 201, raw_response=False, *args, **kwargs): + re = await self.c.post( + url, + files={"file": file}, + headers=self.default_headers | {"Content-Type": "application/json"}, + *args, + **kwargs, + ) + if re.status_code != r_code: + print(re.content) + assert r_code == re.status_code + return re.json() if not raw_response else re + + async def patch( + self, url: str, obj: dict | str = {}, r_code: int = 200, raw_response=False, *args, **kwargs + ): + re = await self.c.patch( + url, + data=json.dumps(obj) if isinstance(obj, dict) else obj, + headers=self.default_headers | {"Content-Type": "application/json"}, + *args, + **kwargs, + ) + if re.status_code != r_code: + print(re.content) + assert r_code == re.status_code + return re.json() if not raw_response else re + + async def put( + self, url: str, obj: dict | str = {}, r_code: int = 200, raw_response=False, *args, **kwargs + ): + re = await self.c.put( + url, + data=json.dumps(obj) if isinstance(obj, dict) else obj, + headers=self.default_headers + | { + "Content-Type": "application/json", + "accept": "application/json", + }, + *args, + **kwargs, + ) + if re.status_code != r_code: + print(re.content) + assert r_code == re.status_code + return re.json() if not raw_response else re + + async def obj_lifecycle( + self, + input_obj: dict, + url: str, + pagination: bool = True, + id_field: str = "id", + created_at_check: bool = True, + ): + # GET LIST + re = await self.get(url) + if pagination: + assert re["total"] == 0 + assert len(re["results"]) == 0 + else: + assert len(re) == 0 + + # CREATE + re = await self.post(url, obj=input_obj) + assert id_field in re + assert re[id_field] is not None + + if created_at_check: + assert "created_at" in re + assert re["created_at"] is not None + + obj_id = str(re[id_field]) + + # GET + re = await self.get(f"{url}{obj_id}/") + assert re[id_field] == obj_id + + # GET LIST + re = await self.get(url) + if pagination: + assert re["total"] == 1 + assert len(re["results"]) == 1 + else: + assert len(re) == 1 + + # DELETE + await self.delete(f"{url}{obj_id}") + + # GET LIST + re = await self.get(url) + if pagination: + assert re["total"] == 0 + assert len(re["results"]) == 0 + else: + assert len(re) == 0 + + # GET + await self.get(f"{url}{obj_id}", parse_json=False, r_code=404) From 17f96c920d2a51369f9ed046e64160017eec1d7d Mon Sep 17 00:00:00 2001 From: creyD Date: Fri, 22 Nov 2024 11:58:05 +0000 Subject: [PATCH 6/6] Adjusted files for isort & autopep --- creyPY/fastapi/testing_async.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/creyPY/fastapi/testing_async.py b/creyPY/fastapi/testing_async.py index e9da993..91c835e 100644 --- a/creyPY/fastapi/testing_async.py +++ b/creyPY/fastapi/testing_async.py @@ -36,7 +36,9 @@ class AsyncGenericClient: assert r_code == re.status_code return re.json() if not raw_response else re - async def post_file(self, url: str, file, r_code: int = 201, raw_response=False, *args, **kwargs): + async def post_file( + self, url: str, file, r_code: int = 201, raw_response=False, *args, **kwargs + ): re = await self.c.post( url, files={"file": file},