diff --git a/creyPY/fastapi/db/async_session.py b/creyPY/fastapi/db/async_session.py index 6856914..07a0e61 100644 --- a/creyPY/fastapi/db/async_session.py +++ b/creyPY/fastapi/db/async_session.py @@ -17,8 +17,14 @@ SQLALCHEMY_DATABASE_URL = f"postgresql+psycopg://{user}:{password}@{host}:{port} async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL + name, pool_pre_ping=True) -AsyncSessionLocal = sessionmaker(bind=async_engine, class_=AsyncSession, - expire_on_commit=False, autoflush=False, autocommit=False) +AsyncSessionLocal = sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, +) + async def get_async_db() -> AsyncGenerator[AsyncSession, None]: async with AsyncSessionLocal() as db: diff --git a/creyPY/fastapi/pagination.py b/creyPY/fastapi/pagination.py index d747c2d..0aca060 100644 --- a/creyPY/fastapi/pagination.py +++ b/creyPY/fastapi/pagination.py @@ -10,7 +10,7 @@ from fastapi_pagination.types import ( AdditionalData, SyncItemsTransformer, AsyncItemsTransformer, - ItemsTransformer + ItemsTransformer, ) from fastapi_pagination.api import create_page, apply_items_transformer from fastapi_pagination.utils import verify_params @@ -112,6 +112,7 @@ def unwrap_scalars( ) -> Union[Sequence[T], Sequence[Sequence[T]]]: return [item[0] if force_unwrap else item for item in items] + def _get_sync_conn_from_async(conn: Any) -> Session: # pragma: no cover if isinstance(conn, async_scoped_session): conn = conn() @@ -124,6 +125,7 @@ def _get_sync_conn_from_async(conn: Any) -> Session: # pragma: no cover raise TypeError("conn must be an AsyncConnection or AsyncSession") + @overload def paginate( connection: Session, @@ -152,19 +154,21 @@ def _paginate( params: Optional[AbstractParams] = None, transformer: Optional[ItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, - async_:bool = False + async_: bool = False, ): - + if async_: + def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any: return await_only(apply_items_transformer(*args, **kwargs, async_=True)) + else: _apply_items_transformer = apply_items_transformer params, raw_params = verify_params(params, "limit-offset", "cursor") count_query = create_count_query(query) total = connection.scalar(count_query) - + if params.pagination is False and total > 0: params = Params(page=1, size=total) else: @@ -183,6 +187,7 @@ def _paginate( **(additional_data or {}), ) + def paginate( connection: Session, query: Select, @@ -190,21 +195,10 @@ def paginate( transformer: Optional[ItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, ): - if isinstance(connection,AsyncSession): + if isinstance(connection, AsyncSession): connection = _get_sync_conn_from_async(connection) - return greenlet_spawn(_paginate, - connection, - query, - params, - transformer, - additional_data, - async_=True) - - return _paginate( - connection, - query, - params, - transformer, - additional_data, - async_=False + return greenlet_spawn( + _paginate, connection, query, params, transformer, additional_data, async_=True ) + + return _paginate(connection, query, params, transformer, additional_data, async_=False)