Rework API
This commit is contained in:
@@ -14,13 +14,19 @@ from pluralkit.bot import commands, proxy, channel_logger, embeds
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||
|
||||
class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])):
|
||||
class Config:
|
||||
required_fields = ["database_uri", "token"]
|
||||
fields = ["database_uri", "token", "log_channel"]
|
||||
|
||||
database_uri: str
|
||||
token: str
|
||||
log_channel: str
|
||||
|
||||
def __init__(self, database_uri: str, token: str, log_channel: str = None):
|
||||
self.database_uri = database_uri
|
||||
self.token = token
|
||||
self.log_channel = log_channel
|
||||
|
||||
@staticmethod
|
||||
def from_file_and_env(filename: str) -> "Config":
|
||||
try:
|
||||
@@ -36,7 +42,7 @@ class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])):
|
||||
raise e
|
||||
|
||||
# Override with environment variables
|
||||
for f in Config._fields:
|
||||
for f in Config.fields:
|
||||
if f.upper() in os.environ:
|
||||
config[f] = os.environ[f.upper()]
|
||||
|
||||
|
||||
@@ -38,6 +38,11 @@ class Member(namedtuple("Member",
|
||||
"suffix": self.suffix
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def get_member_by_id(conn, member_id: int) -> Optional["Member"]:
|
||||
"""Fetch a member with the given internal member ID from the database."""
|
||||
return await db.get_member(conn, member_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_member_by_name(conn, system_id: int, member_name: str) -> "Optional[Member]":
|
||||
"""Fetch a member by the given name in the given system from the database."""
|
||||
|
||||
@@ -21,8 +21,8 @@ class Switch(namedtuple("Switch", ["id", "system", "timestamp", "members"])):
|
||||
async def move(self, conn, new_timestamp):
|
||||
await db.move_switch(conn, self.system, self.id, new_timestamp)
|
||||
|
||||
async def to_json(self, conn):
|
||||
async def to_json(self, hid_getter):
|
||||
return {
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"members": [member.hid for member in await self.fetch_members(conn)]
|
||||
"members": [await hid_getter(m) for m in self.members]
|
||||
}
|
||||
|
||||
@@ -38,6 +38,10 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
||||
@staticmethod
|
||||
async def get_by_token(conn, token: str) -> Optional["System"]:
|
||||
return await db.get_system_by_token(conn, token)
|
||||
|
||||
@staticmethod
|
||||
async def get_by_hid(conn, hid: str) -> Optional["System"]:
|
||||
return await db.get_system_by_hid(conn, hid)
|
||||
|
||||
@staticmethod
|
||||
async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System":
|
||||
@@ -234,7 +238,11 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
||||
:returns: The `pytz.tzinfo` instance of the newly set time zone.
|
||||
"""
|
||||
|
||||
tz = pytz.timezone(tz_name or "UTC")
|
||||
try:
|
||||
tz = pytz.timezone(tz_name or "UTC")
|
||||
except pytz.UnknownTimeZoneError:
|
||||
raise errors.InvalidTimeZoneError(tz_name)
|
||||
|
||||
await db.update_system_field(conn, self.id, "ui_tz", tz.zone)
|
||||
return tz
|
||||
|
||||
@@ -304,5 +312,6 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"tag": self.tag,
|
||||
"avatar_url": self.avatar_url
|
||||
"avatar_url": self.avatar_url,
|
||||
"tz": self.ui_tz
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user