Compare commits

...

5 Commits

Author SHA1 Message Date
b86b58f3e4 Merge pull request #19 from creyD/dev 2024-11-22 13:20:27 +01:00
creyD
17f96c920d Adjusted files for isort & autopep 2024-11-22 11:58:05 +00:00
vikynoah
523241ac4b feat: N-271 async db (#18) 2024-11-22 12:56:45 +01:00
creyD
6f09c2ef4c Adjusted files for isort & autopep 2024-11-15 11:39:59 +00:00
vikynoah
9bba5b0a4e fix: N 271 async db (#17) 2024-11-15 12:39:30 +01:00
2 changed files with 320 additions and 30 deletions

View File

@@ -1,63 +1,214 @@
from typing import Type, TypeVar
from typing import Type, TypeVar, overload
from uuid import UUID
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
import asyncio
from .models.base import Base
T = TypeVar("T", bound=Base)
@overload
async def get_object_or_404(
db_class: Type[T],
id: UUID | str,
db: AsyncSession,
expunge: bool = False,
lookup_column: str = "id",
) -> T:
pass
@overload
def get_object_or_404(
db_class: Type[T], id: UUID | str, db: Session, expunge: bool = False, lookup_column: str = "id"
) -> T:
obj = db.query(db_class).filter(getattr(db_class, lookup_column) == id).one_or_none()
if obj is None:
raise HTTPException(status_code=404, detail="The object does not exist.")
if expunge:
db.expunge(obj)
return obj
pass
# TODO: Add testing
def create_obj_from_data(
data: BaseModel, model: Type[T], db: Session, additional_data={}, exclude={}
def get_object_or_404(
db_class: Type[T],
id: UUID | str,
db: Session | AsyncSession,
expunge: bool = False,
lookup_column: str = "id",
) -> T:
obj = model(**data.model_dump(exclude=exclude) | additional_data)
db.add(obj)
db.commit()
db.refresh(obj)
return obj
async def _get_async_object() -> T:
query = select(db_class).filter(getattr(db_class, lookup_column) == id)
result = await db.execute(query)
obj = result.scalar_one_or_none()
if obj is None:
raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore
if expunge:
await db.expunge(obj)
return obj
def _get_sync_object() -> T:
obj = db.query(db_class).filter(getattr(db_class, lookup_column) == id).one_or_none()
if obj is None:
raise HTTPException(status_code=404, detail="The object does not exist.") # type: ignore
if expunge:
db.expunge(obj)
return obj
if isinstance(db, AsyncSession):
return asyncio.ensure_future(_get_async_object()) # type: ignore
elif isinstance(db, Session):
return _get_sync_object()
else:
raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore
# TODO: Add testing
@overload
async def create_obj_from_data(
data: BaseModel,
model: Type[T],
db: AsyncSession,
additional_data: dict = {},
exclude: dict = {},
) -> T:
pass
@overload
def create_obj_from_data(
data: BaseModel, model: Type[T], db: Session, additional_data: dict = {}, exclude: dict = {}
) -> T:
pass
def create_obj_from_data(
data: BaseModel, model: Type[T], db: Session | AsyncSession, additional_data={}, exclude={}
) -> T:
obj_data = data.model_dump(exclude=exclude) | additional_data
obj = model(**obj_data)
async def _create_async_obj():
db.add(obj)
await db.commit()
await db.refresh(obj)
return obj
def _create_sync_obj():
db.add(obj)
db.commit()
db.refresh(obj)
return obj
if isinstance(db, AsyncSession):
return asyncio.ensure_future(_create_async_obj()) # type: ignore
elif isinstance(db, Session):
return _create_sync_obj()
else:
raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore
# TODO: Add testing
@overload
async def update_obj_from_data(
data: BaseModel,
model: Type[T],
id: UUID | str,
db: AsyncSession,
partial: bool = True,
ignore_fields: list = [],
additional_data: dict = {},
exclude: dict = {},
) -> T:
pass
@overload
def update_obj_from_data(
data: BaseModel,
model: Type[T],
id: UUID | str,
db: Session,
partial: bool = True,
ignore_fields: list = [],
additional_data: dict = {},
exclude: dict = {},
) -> T:
pass
def update_obj_from_data(
data: BaseModel,
model: Type[T],
id: UUID | str,
db: Session | AsyncSession,
partial: bool = True,
ignore_fields=[],
additional_data={},
exclude={},
) -> T:
obj = get_object_or_404(model, id, db)
data_dict = data.model_dump(exclude_unset=partial, exclude=exclude)
data_dict.update(additional_data) # merge additional_data into data_dict
for field in data_dict:
if field not in ignore_fields:
setattr(obj, field, data_dict[field])
db.commit()
db.refresh(obj)
return obj
def _update_fields(obj: T):
data_dict = data.model_dump(exclude_unset=partial, exclude=exclude)
data_dict.update(additional_data)
for field in data_dict:
if field not in ignore_fields:
setattr(obj, field, data_dict[field])
async def _update_async_obj() -> T:
obj = await get_object_or_404(model, id, db)
_update_fields(obj)
await db.commit()
await db.refresh(obj)
return obj
def _update_sync_obj() -> T:
obj = get_object_or_404(model, id, db)
_update_fields(obj)
db.commit()
db.refresh(obj)
return obj
if isinstance(db, AsyncSession):
return asyncio.ensure_future(_update_async_obj()) # type: ignore
elif isinstance(db, Session):
return _update_sync_obj()
else:
raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore
# TODO: Add testing
@overload
async def delete_object(db_class: Type[T], id: UUID | str, db: AsyncSession) -> None:
pass
@overload
def delete_object(db_class: Type[T], id: UUID | str, db: Session) -> None:
obj = db.query(db_class).filter(db_class.id == id).one_or_none()
if obj is None:
raise HTTPException(status_code=404, detail="The object does not exist.")
db.delete(obj)
db.commit()
pass
def delete_object(db_class: Type[T], id: UUID | str, db: Session | AsyncSession) -> None:
async def _delete_async_obj() -> None:
query = select(db_class).filter(db_class.id == id)
result = await db.execute(query)
obj = result.scalar_one_or_none()
if obj is None:
raise HTTPException(status_code=404, detail="The object does not exist.")
await db.delete(obj)
await db.commit()
def _delete_sync_obj() -> None:
obj = db.query(db_class).filter(db_class.id == id).one_or_none()
if obj is None:
raise HTTPException(status_code=404, detail="The object does not exist.")
db.delete(obj)
db.commit()
if isinstance(db, AsyncSession):
return asyncio.ensure_future(_delete_async_obj()) # type: ignore
elif isinstance(db, Session):
return _delete_sync_obj()
else:
raise HTTPException(status_code=404, detail="Invalid session type. Expected Session or AsyncSession.") # type: ignore

View File

@@ -0,0 +1,139 @@
import json
from httpx import AsyncClient
class AsyncGenericClient:
def __init__(self, app):
self.c = AsyncClient(app=app, base_url="http://testserver", follow_redirects=True)
self.default_headers = {}
async def get(self, url: str, r_code: int = 200, parse_json=True):
re = await self.c.get(url, headers=self.default_headers)
if re.status_code != r_code:
print(re.content)
assert r_code == re.status_code
return re.json() if parse_json else re.content
async def delete(self, url: str, r_code: int = 204):
re = await self.c.delete(url, headers=self.default_headers)
if re.status_code != r_code:
print(re.content)
assert r_code == re.status_code
return re.json() if r_code != 204 else None
async def post(
self, url: str, obj: dict | str = {}, r_code: int = 201, raw_response=False, *args, **kwargs
):
re = await self.c.post(
url,
data=json.dumps(obj) if isinstance(obj, dict) else obj,
headers=self.default_headers | {"Content-Type": "application/json"},
*args,
**kwargs,
)
if re.status_code != r_code:
print(re.content)
assert r_code == re.status_code
return re.json() if not raw_response else re
async def post_file(
self, url: str, file, r_code: int = 201, raw_response=False, *args, **kwargs
):
re = await self.c.post(
url,
files={"file": file},
headers=self.default_headers | {"Content-Type": "application/json"},
*args,
**kwargs,
)
if re.status_code != r_code:
print(re.content)
assert r_code == re.status_code
return re.json() if not raw_response else re
async def patch(
self, url: str, obj: dict | str = {}, r_code: int = 200, raw_response=False, *args, **kwargs
):
re = await self.c.patch(
url,
data=json.dumps(obj) if isinstance(obj, dict) else obj,
headers=self.default_headers | {"Content-Type": "application/json"},
*args,
**kwargs,
)
if re.status_code != r_code:
print(re.content)
assert r_code == re.status_code
return re.json() if not raw_response else re
async def put(
self, url: str, obj: dict | str = {}, r_code: int = 200, raw_response=False, *args, **kwargs
):
re = await self.c.put(
url,
data=json.dumps(obj) if isinstance(obj, dict) else obj,
headers=self.default_headers
| {
"Content-Type": "application/json",
"accept": "application/json",
},
*args,
**kwargs,
)
if re.status_code != r_code:
print(re.content)
assert r_code == re.status_code
return re.json() if not raw_response else re
async def obj_lifecycle(
self,
input_obj: dict,
url: str,
pagination: bool = True,
id_field: str = "id",
created_at_check: bool = True,
):
# GET LIST
re = await self.get(url)
if pagination:
assert re["total"] == 0
assert len(re["results"]) == 0
else:
assert len(re) == 0
# CREATE
re = await self.post(url, obj=input_obj)
assert id_field in re
assert re[id_field] is not None
if created_at_check:
assert "created_at" in re
assert re["created_at"] is not None
obj_id = str(re[id_field])
# GET
re = await self.get(f"{url}{obj_id}/")
assert re[id_field] == obj_id
# GET LIST
re = await self.get(url)
if pagination:
assert re["total"] == 1
assert len(re["results"]) == 1
else:
assert len(re) == 1
# DELETE
await self.delete(f"{url}{obj_id}")
# GET LIST
re = await self.get(url)
if pagination:
assert re["total"] == 0
assert len(re["results"]) == 0
else:
assert len(re) == 0
# GET
await self.get(f"{url}{obj_id}", parse_json=False, r_code=404)