maubot-sixtyfour/sixtyfour/db.py
2021-03-10 13:23:54 -07:00

65 lines
2.2 KiB
Python

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from contextlib import closing
from typing import List, NamedTuple, Optional

from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, String, Text, or_
from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, relationship
from mautrix.types import UserID
class AuthInfo(NamedTuple):
    """Server plus the API token used to authenticate against it."""

    server: str
    api_token: str


# Defined with the class syntax because the keyword-argument functional form
# (NamedTuple('X', field=type)) is deprecated since Python 3.13.
class AliasInfo(NamedTuple):
    """Server plus an alias registered on it."""

    server: str
    alias: str
# Declarative base for this module's ORM models; Database.__init__ creates the
# tables registered on it via Base.metadata.create_all().
Base = declarative_base()
from pprint import pprint
class GCMToken(Base):
    """ORM row pairing a Matrix user ID with its GCM API token.

    Stored in the ``gcmtoken`` table. ``user_id`` is the primary key, so there
    is at most one row per user.
    """

    __tablename__ = "gcmtoken"

    # Matrix user ID (String(255)); the table's primary key.
    user_id: UserID = Column(String(255), nullable=False, primary_key=True)
    # The user's GCM API token, stored as Text; required.
    api_token = Column(Text, nullable=False)
class Database:
    """Persistence layer for per-user GCM API tokens.

    Each public method opens its own short-lived session from ``self.Session``
    and closes it before returning, so no session (and no connection or open
    transaction) outlives a call.
    """

    db: Engine

    def __init__(self, db: Engine) -> None:
        """Bind to *db*, create missing tables, and build a session factory.

        :param db: SQLAlchemy engine holding the tables.
        """
        self.db = db
        # Create the tables of every model registered on Base (e.g. GCMToken)
        # that do not exist yet.
        Base.metadata.create_all(db)
        # Factory for short-lived sessions; each method below makes its own.
        self.Session = sessionmaker(bind=self.db)

    def add_gcm(self, mxid: UserID, token: str) -> None:
        """Insert *token* as the GCM API token for *mxid*.

        This inserts rather than replaces: ``user_id`` is the primary key, so
        adding a token for a user that already has one is rejected by the
        database when the session commits, and that error propagates.
        """
        # closing() guarantees Session.close(), which also rolls back a failed
        # commit instead of leaving the transaction dangling.
        with closing(self.Session()) as s:
            s.add(GCMToken(user_id=mxid, api_token=token))
            s.commit()

    def rm_gcm(self, mxid: UserID) -> None:
        """Delete the GCM token stored for *mxid*; a no-op if none is stored."""
        with closing(self.Session()) as s:
            token = s.query(GCMToken).filter(GCMToken.user_id == mxid).one_or_none()
            if token is None:
                # Session.delete(None) would raise UnmappedInstanceError.
                return
            s.delete(token)
            s.commit()

    def has_gcm(self, user_id: UserID) -> bool:
        """Return whether a GCM token is stored for *user_id*."""
        with closing(self.Session()) as s:
            return s.query(GCMToken).filter(GCMToken.user_id == user_id).count() > 0

    def get_gcm(self, user_id: UserID) -> Optional[str]:
        """Return the GCM API token stored for *user_id*, or ``None`` if absent."""
        with closing(self.Session()) as s:
            row = s.query(GCMToken).filter(GCMToken.user_id == user_id).one_or_none()
            # Read the attribute while the session is still open.
            return row.api_token if row is not None else None