diff --git a/creyPY/fastapi/pagination.py b/creyPY/fastapi/pagination.py index 2973e23..98513ff 100644 --- a/creyPY/fastapi/pagination.py +++ b/creyPY/fastapi/pagination.py @@ -1,6 +1,6 @@ from math import ceil from typing import Any, Generic, Optional, Self, Sequence, TypeVar, Union - +from pydantic import BaseModel from fastapi_pagination import Params from fastapi_pagination.bases import AbstractPage, AbstractParams from fastapi_pagination.types import ( @@ -12,13 +12,29 @@ from fastapi_pagination.types import ( from fastapi_pagination.api import create_page, apply_items_transformer from fastapi_pagination.utils import verify_params from fastapi_pagination.ext.sqlalchemy import create_paginate_query +from fastapi_pagination.bases import AbstractParams, RawParams from pydantic.json_schema import SkipJsonSchema from sqlalchemy.sql.selectable import Select from sqlalchemy.orm.session import Session from sqlalchemy import select, func +from fastapi import Query T = TypeVar("T") +class PaginationParams(BaseModel, AbstractParams): + page: int = Query(1, ge=1, description="Page number") + size: int = Query(50, ge=1, le=100, description="Page size") + pagination: bool = Query(True, description="Toggle pagination") + + def to_raw_params(self) -> RawParams: + if not self.pagination: + return RawParams(limit=None, offset=None) + + return RawParams( + limit=self.size, + offset=(self.page - 1) * self.size + ) + # TODO: Add complete fastapi-pagination proxy here # TODO: Add pagination off functionality @@ -32,7 +48,7 @@ class Page(AbstractPage[T], Generic[T]): has_next: bool | SkipJsonSchema[None] = None has_prev: bool | SkipJsonSchema[None] = None - __params_type__ = Params + __params_type__ = PaginationParams @classmethod def create( @@ -97,18 +113,19 @@ def unwrap_scalars( def paginate( connection: Session, query: Select, - paginationFlag: bool = True, params: Optional[AbstractParams] = None, transformer: Optional[SyncItemsTransformer] = None, additional_data: Optional[AdditionalData] = None, ): - params, _ = verify_params(params, "limit-offset", "cursor") + params, raw_params = verify_params(params, "limit-offset", "cursor") count_query = create_count_query(query) total = connection.scalar(count_query) - if paginationFlag is False and total > 0: + if params.pagination is False and total > 0: params = Params(page=1, size=total) + else: + params = Params(page=params.page, size=params.size) query = create_paginate_query(query, params) items = connection.execute(query).all()