Compare commits

...

15 Commits

Author SHA1 Message Date
creyD
53ed939451 Adjusted files for isort & autopep 2025-02-24 17:27:48 +00:00
c56d14c2fd Merge pull request #35 from vikynoah/invite_bug
fix: add company id to invite user
2025-02-24 18:27:17 +01:00
vikynoah
1e9bcb92b6 fix: add company id to invite user 2025-02-24 11:49:17 +01:00
5e16bd5cbc fix: fixed issue that creyPY couldn't be used without PSQL 2025-02-19 10:27:51 +01:00
creyD
50b444be89 Adjusted files for isort & autopep 2025-02-14 09:10:55 +00:00
e12c86e352 Merge pull request #34 from vikynoah/obj_lifecycle_patch
feat : Add Patch to obj lifecycle
2025-02-14 10:10:21 +01:00
vikynoah
0708a48301 feat : Add Patch to obj lifecycle 2025-02-13 02:05:15 +01:00
34595d52f2 Merge pull request #33 from creyD/renovate/stripe-11.x
feat(deps): update dependency stripe to v11.5.0
2025-02-05 09:38:28 +01:00
renovate[bot]
421725ad10 feat(deps): update dependency stripe to v11.5.0 2025-01-27 22:02:52 +00:00
31c4cbb055 fix: fixed multiple bugs in database handling 2025-01-27 16:26:26 +01:00
410ae12f8e feat: added ssl option to test database 2025-01-27 13:16:55 +01:00
1f224c44bc feat: added sslmode flag 2025-01-27 13:09:16 +01:00
5b0cc0d87d fix: fixed tests 2025-01-24 19:10:36 +01:00
ecfc0fc167 fix: fixed issue with new mixin 2025-01-24 19:04:07 +01:00
eb62c87679 feat: added experimental init and annotation mixins 2025-01-24 18:58:39 +01:00
11 changed files with 78 additions and 21 deletions

View File

@@ -1,3 +1,8 @@
from .async_session import * # noqa try:
from .helpers import * # noqa import sqlalchemy
from .session import * # noqa
from .async_session import *
from .helpers import *
from .session import *
except ImportError:
print("SQLAlchemy not installed. Database functionality will be disabled.")

View File

@@ -1,15 +1,14 @@
from typing import AsyncGenerator from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from .common import SQLALCHEMY_DATABASE_URL, name from .common import SQLALCHEMY_DATABASE_URL, name, ssl_mode
async_engine = create_async_engine( async_engine = create_async_engine(
SQLALCHEMY_DATABASE_URL + name, pool_pre_ping=True, connect_args={"sslmode": "require"} SQLALCHEMY_DATABASE_URL + name, pool_pre_ping=True, connect_args={"sslmode": ssl_mode}
) )
AsyncSessionLocal = sessionmaker( AsyncSessionLocal = async_sessionmaker(
bind=async_engine, bind=async_engine,
class_=AsyncSession, class_=AsyncSession,
expire_on_commit=False, expire_on_commit=False,

View File

@@ -10,4 +10,6 @@ password = os.getenv("POSTGRES_PASSWORD", "root")
port = os.getenv("POSTGRES_PORT", "5432") port = os.getenv("POSTGRES_PORT", "5432")
name = os.getenv("POSTGRES_DB", "fastapi") name = os.getenv("POSTGRES_DB", "fastapi")
ssl_mode = os.getenv("SSL_MODE", "require")
SQLALCHEMY_DATABASE_URL = f"postgresql+psycopg://{user}:{password}@{host}:{port}/" SQLALCHEMY_DATABASE_URL = f"postgresql+psycopg://{user}:{password}@{host}:{port}/"

View File

@@ -4,10 +4,10 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from .common import SQLALCHEMY_DATABASE_URL, name from .common import SQLALCHEMY_DATABASE_URL, name, ssl_mode
engine = create_engine( engine = create_engine(
SQLALCHEMY_DATABASE_URL + name, pool_pre_ping=True, connect_args={"sslmode": "require"} SQLALCHEMY_DATABASE_URL + name, pool_pre_ping=True, connect_args={"sslmode": ssl_mode}
) )
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View File

@@ -1 +1,2 @@
from .base import * # noqa from .base import * # noqa
from .mixins import * # noqa

View File

@@ -7,9 +7,11 @@ from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import as_declarative from sqlalchemy.orm import as_declarative
from sqlalchemy.sql import func from sqlalchemy.sql import func
from .mixins import AutoAnnotateMixin, AutoInitMixin
@as_declarative() @as_declarative()
class Base: class Base(AutoAnnotateMixin, AutoInitMixin):
__abstract__ = True __abstract__ = True
# Primary key as uuid # Primary key as uuid
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

View File

@@ -0,0 +1,36 @@
from sqlalchemy import Column
from sqlalchemy.orm import Mapped
class AutoAnnotateMixin:
@classmethod
def __init_subclass__(cls) -> None:
super().__init_subclass__()
annotations = {}
for key, value in cls.__dict__.items():
if isinstance(value, Column):
annotations[key] = Mapped[value.type.python_type]
cls.__annotations__ = annotations
class AutoInitMixin:
@classmethod
def __init_subclass__(cls) -> None:
super().__init_subclass__()
init_params = []
for key, value in cls.__dict__.items():
if isinstance(value, Column):
if not value.nullable and not value.default and not value.server_default:
init_params.append((key, value.type.python_type))
def __init__(self, **kwargs):
super(cls, self).__init__()
for key, _ in init_params:
if key not in kwargs:
raise TypeError(f"Missing required argument: {key}")
setattr(self, key, kwargs[key])
for key, value in kwargs.items():
if key not in init_params and hasattr(self.__class__, key):
setattr(self, key, value)
cls.__init__ = __init__

View File

@@ -20,17 +20,21 @@ class AbstractTestAPI(unittest.IsolatedAsyncioTestCase):
transport=ASGITransport(app=app), base_url="http://testserver", follow_redirects=True transport=ASGITransport(app=app), base_url="http://testserver", follow_redirects=True
) )
cls.default_headers = headers cls.default_headers = headers
print("setting up abstract")
@classmethod @classmethod
def setup_database( def setup_database(
cls, sync_db_url: str, async_db_url: str, base: Type[Base], btree_gist: bool = False cls,
sync_db_url: str,
async_db_url: str,
base: Type[Base],
btree_gist: bool = False,
ssl_mode: str = "require",
): ):
cls.engine_s = create_engine( cls.engine_s = create_engine(
sync_db_url, sync_db_url,
echo=False, echo=False,
pool_pre_ping=True, pool_pre_ping=True,
connect_args={"sslmode": "require"}, connect_args={"sslmode": ssl_mode},
) )
if database_exists(cls.engine_s.url): if database_exists(cls.engine_s.url):
drop_database(cls.engine_s.url) drop_database(cls.engine_s.url)
@@ -47,7 +51,7 @@ class AbstractTestAPI(unittest.IsolatedAsyncioTestCase):
async_db_url, async_db_url,
echo=False, echo=False,
pool_pre_ping=True, pool_pre_ping=True,
connect_args={"sslmode": "require"}, connect_args={"sslmode": ssl_mode},
) )
async def get(self, url: str, r_code: int = 200, parse_json=True) -> dict | bytes: async def get(self, url: str, r_code: int = 200, parse_json=True) -> dict | bytes:
@@ -136,6 +140,7 @@ class AbstractTestAPI(unittest.IsolatedAsyncioTestCase):
pagination: bool = True, pagination: bool = True,
id_field: str = "id", id_field: str = "id",
created_at_check: bool = True, created_at_check: bool = True,
patch: dict | None = None,
): ):
# GET LIST # GET LIST
re = await self.get(url) re = await self.get(url)
@@ -160,6 +165,14 @@ class AbstractTestAPI(unittest.IsolatedAsyncioTestCase):
re = await self.get(f"{url}{obj_id}/") re = await self.get(f"{url}{obj_id}/")
self.assertEqual(re[id_field], obj_id) self.assertEqual(re[id_field], obj_id)
# PATCH
if patch:
for key, value in patch.items():
input_obj[key] = value
re = await self.patch(f"{url}{obj_id}/", obj=input_obj)
for key, value in patch.items():
self.assertEqual(re[key], value)
# GET LIST # GET LIST
re = await self.get(url) re = await self.get(url)
if pagination: if pagination:

View File

@@ -101,7 +101,7 @@ def request_verification_mail(sub: str) -> None:
return re.json() return re.json()
def create_user_invite(email: str) -> dict: def create_user_invite(email: str, company_id: str) -> dict:
re = requests.post( re = requests.post(
f"https://{AUTH0_DOMAIN}/api/v2/users", f"https://{AUTH0_DOMAIN}/api/v2/users",
headers={"Authorization": f"Bearer {get_management_token()}"}, headers={"Authorization": f"Bearer {get_management_token()}"},
@@ -111,6 +111,7 @@ def create_user_invite(email: str) -> dict:
"password": create_random_password(), "password": create_random_password(),
"verify_email": False, "verify_email": False,
"app_metadata": {"invitedToMyApp": True}, "app_metadata": {"invitedToMyApp": True},
"user_metadata": {"company_ids": [company_id]},
}, },
timeout=5, timeout=5,
) )

View File

@@ -1 +1 @@
stripe==11.4.1 # Stripe stripe==11.5.0 # Stripe

View File

@@ -7,9 +7,7 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from creyPY.fastapi.app import generate_unique_id from creyPY.fastapi.app import generate_unique_id
from creyPY.fastapi.crud import ( from creyPY.fastapi.crud import get_object_or_404
get_object_or_404,
)
from creyPY.fastapi.models.base import Base from creyPY.fastapi.models.base import Base
@@ -65,7 +63,7 @@ class TestMyFunction(unittest.TestCase):
def test_get_object_or_404_existing_object(self): def test_get_object_or_404_existing_object(self):
# Arrange # Arrange
obj_id = UUID("123e4567-e89b-12d3-a456-426614174000") obj_id = UUID("123e4567-e89b-12d3-a456-426614174000")
obj = MockDBClass(obj_id) obj = MockDBClass(id=obj_id)
self.db.add(obj) self.db.add(obj)
self.db.commit() self.db.commit()