init: Initial commit
This commit is contained in:
@@ -1,16 +1,38 @@
|
||||
from models.team import Team
|
||||
from models.webhook import Webhook
|
||||
from models.service import Service
|
||||
from models.channel import Channel
|
||||
from models.base import Base
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.engine import URL
|
||||
from urllib.parse import quote_plus
|
||||
from uuid import UUID
|
||||
|
||||
class DatabaseConnectionString:
|
||||
def __init__(self, db_name: str, db_user: str, db_password: str, db_host: str, db_port: str):
|
||||
self.connection_string = f"postgresql://{db_user}:{quote_plus(db_password)}@{db_host}:{db_port}/{db_name}"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.connection_string
|
||||
|
||||
class PostgresConnectionString(DatabaseConnectionString):
|
||||
def __init__(self, db_name: str, db_user: str, db_password: str, db_host: str, db_port: str):
|
||||
super().__init__(db_name, db_user, db_password, db_host, db_port)
|
||||
|
||||
class MySQLConnectionString(DatabaseConnectionString):
|
||||
def __init__(self, db_name: str, db_user: str, db_password: str, db_host: str, db_port: str):
|
||||
super().__init__(db_name, db_user, db_password, db_host, db_port)
|
||||
|
||||
class SQLiteConnectionString(DatabaseConnectionString):
|
||||
def __init__(self, db_path: str):
|
||||
super().__init__("", "", "", "", "")
|
||||
self.connection_string = f"sqlite:///{db_path}"
|
||||
|
||||
class DatabaseManager:
|
||||
def __init__(self, db_name: str, db_user: str, db_password: str, db_host: str, db_port: str):
|
||||
# URL encode the password to handle special characters
|
||||
encoded_password = quote_plus(db_password)
|
||||
self.engine = create_engine(f"postgresql://{db_user}:{encoded_password}@{db_host}:{db_port}/{db_name}")
|
||||
def __init__(self, connection_string: DatabaseConnectionString):
|
||||
self.engine = create_engine(str(connection_string))
|
||||
self.session = sessionmaker(bind=self.engine)
|
||||
self.__session = self.session()
|
||||
|
||||
@@ -37,8 +59,49 @@ class DatabaseManager:
|
||||
def get_webhooks(self) -> list[Webhook]:
|
||||
return self.__session.query(Webhook).all()
|
||||
|
||||
def count_webhooks(self) -> int:
|
||||
return len(self.get_webhooks())
|
||||
|
||||
def get_webhook_by_id(self, id: str) -> Webhook:
|
||||
webhook = self.__session.query(Webhook).filter(Webhook.id == id).first()
|
||||
if webhook is None:
|
||||
raise ValueError(f"Webhook with id {id} not found")
|
||||
return webhook
|
||||
return webhook
|
||||
|
||||
def get_services(self) -> list[Service]:
|
||||
return self.__session.query(Service).all()
|
||||
|
||||
def get_service_by_id(self, id: str) -> Service:
|
||||
service = self.__session.query(Service).filter(Service.id == id).first()
|
||||
if service is None:
|
||||
raise ValueError(f"Service with id {id} not found")
|
||||
return service
|
||||
|
||||
def get_channels(self) -> list[Channel]:
|
||||
return self.__session.query(Channel).all()
|
||||
|
||||
def get_channel_by_id(self, id: str) -> Channel:
|
||||
channel = self.__session.query(Channel).filter(Channel.id == id).first()
|
||||
if channel is None:
|
||||
raise ValueError(f"Channel with id {id} not found")
|
||||
return channel
|
||||
|
||||
def get_channel_by_microsoft_channel_id(self, microsoft_channel_id: str) -> Channel | None:
|
||||
"""Get channel by Microsoft Teams channel ID. Returns None if not found."""
|
||||
try:
|
||||
# Convert string to UUID if needed
|
||||
channel_uuid = UUID(microsoft_channel_id) if isinstance(microsoft_channel_id, str) else microsoft_channel_id
|
||||
channel = self.__session.query(Channel).filter(Channel.microsoft_channel_id == channel_uuid).first()
|
||||
return channel
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def count_webhooks_by_channel_id(self, channel_id: str | UUID) -> int:
|
||||
"""Count webhooks by channel ID. Accepts UUID string or UUID object."""
|
||||
try:
|
||||
# Convert string to UUID if needed
|
||||
channel_uuid = UUID(channel_id) if isinstance(channel_id, str) else channel_id
|
||||
return len(self.__session.query(Webhook).filter(Webhook.channel_id == channel_uuid).all())
|
||||
except (ValueError, TypeError):
|
||||
# If channel_id is not a valid UUID, return 0
|
||||
return 0
|
||||
29
modules/template.py
Normal file
29
modules/template.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
class TemplateEngine:
|
||||
def __init__(self, template_name: str, template_path: str = "templates/") -> None:
|
||||
self.__name = template_name
|
||||
self.__path = template_path
|
||||
self.__template_path = f"{self.__path}{self.__name}.json"
|
||||
self.__template = self.__load_template()
|
||||
|
||||
def __load_template(self) -> dict[str, Any]:
|
||||
with open(self.__template_path, "r", encoding="utf-8") as f:
|
||||
return json.loads(f.read())
|
||||
|
||||
def __replace_placeholders(self, obj: Any, data: dict) -> Any:
|
||||
"""Recursively replace placeholders in dict/list/string values."""
|
||||
if isinstance(obj, dict):
|
||||
return {k: self.__replace_placeholders(v, data) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [self.__replace_placeholders(item, data) for item in obj]
|
||||
elif isinstance(obj, str):
|
||||
return obj.format(**data)
|
||||
else:
|
||||
return obj
|
||||
|
||||
def generate(self, data: dict) -> dict[str, Any]:
|
||||
template = self.__load_template()
|
||||
return self.__replace_placeholders(template, data)
|
||||
Reference in New Issue
Block a user