Response filter async (#45)

* fix: get_object alter for async response filter

* fix: Alter async response
This commit is contained in:
vikynoah
2025-04-04 17:54:47 +02:00
committed by GitHub
parent c903266ec4
commit badf2b157f
2 changed files with 14 additions and 10 deletions

View File

@@ -1,12 +1,13 @@
from typing import Type, TypeVar, overload, List
import asyncio
from typing import List, Type, TypeVar, overload
from uuid import UUID
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
import asyncio
from sqlalchemy.orm import Session
from .models.base import Base
T = TypeVar("T", bound=Base)
@@ -53,22 +54,22 @@ def get_object_or_404(
query = select(*selected_columns).where(getattr(db_class, lookup_column) == id)
result = await db.execute(query)
row = result.first()
if row is None:
raise HTTPException(status_code=404, detail="The object does not exist.")
if hasattr(row, "_mapping"):
obj_dict = dict(row._mapping)
else:
obj_dict = {column.key: getattr(row, column.key) for column in selected_columns}
obj_dict = {column.key: getattr(row, column.key)
for column in selected_columns}
else:
query = select(db_class).where(getattr(db_class, lookup_column) == id)
result = await db.execute(query)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(status_code=404, detail="The object does not exist.")
obj_dict = {k: v for k, v in row.__dict__.items() if not k.startswith("_")}
obj_dict = row
if expunge:
await db.expunge(obj_dict)
return obj_dict

View File

@@ -1,6 +1,7 @@
from typing import List, Optional, Type
from pydantic import BaseModel, create_model
from fastapi import Query
from pydantic import BaseModel, create_model
class ResponseModelDependency:
@@ -8,8 +9,10 @@ class ResponseModelDependency:
self.model_class = model_class
def __call__(self, response_fields: Optional[List[str]] = Query(None)) -> Type[BaseModel]:
def process_result(result, fields=None):
def process_result(result, fields=None, async_session=False):
if not fields:
if async_session:
return {k: v for k, v in result.__dict__.items() if not k.startswith('_')}
return result
if hasattr(result, "_fields"):