Compare commits

..

2 Commits

Author SHA1 Message Date
creyD
7afb8e2fd8 Adjusted files for isort & autopep 2025-04-04 15:55:18 +00:00
vikynoah
badf2b157f Response filter async (#45)
* fix: get_object alter for async response filter

* fix: Alter async response
2025-04-04 17:54:47 +02:00
2 changed files with 10 additions and 7 deletions

View File

@@ -1,12 +1,13 @@
from typing import Type, TypeVar, overload, List import asyncio
from typing import List, Type, TypeVar, overload
from uuid import UUID from uuid import UUID
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
import asyncio from sqlalchemy.orm import Session
from .models.base import Base from .models.base import Base
T = TypeVar("T", bound=Base) T = TypeVar("T", bound=Base)
@@ -64,11 +65,10 @@ def get_object_or_404(
query = select(db_class).where(getattr(db_class, lookup_column) == id) query = select(db_class).where(getattr(db_class, lookup_column) == id)
result = await db.execute(query) result = await db.execute(query)
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if row is None: if row is None:
raise HTTPException(status_code=404, detail="The object does not exist.") raise HTTPException(status_code=404, detail="The object does not exist.")
obj_dict = {k: v for k, v in row.__dict__.items() if not k.startswith("_")} obj_dict = row
if expunge: if expunge:
await db.expunge(obj_dict) await db.expunge(obj_dict)
return obj_dict return obj_dict

View File

@@ -1,6 +1,7 @@
from typing import List, Optional, Type from typing import List, Optional, Type
from pydantic import BaseModel, create_model
from fastapi import Query from fastapi import Query
from pydantic import BaseModel, create_model
class ResponseModelDependency: class ResponseModelDependency:
@@ -8,8 +9,10 @@ class ResponseModelDependency:
self.model_class = model_class self.model_class = model_class
def __call__(self, response_fields: Optional[List[str]] = Query(None)) -> Type[BaseModel]: def __call__(self, response_fields: Optional[List[str]] = Query(None)) -> Type[BaseModel]:
def process_result(result, fields=None): def process_result(result, fields=None, async_session=False):
if not fields: if not fields:
if async_session:
return {k: v for k, v in result.__dict__.items() if not k.startswith("_")}
return result return result
if hasattr(result, "_fields"): if hasattr(result, "_fields"):