# Source code for app.models

"""
app.models
==========

Application's database models.
"""

# pylint: disable=too-few-public-methods
from __future__ import annotations

import hashlib
import json
import typing as t
from datetime import datetime
from time import time

from flask import abort
from flask_login import UserMixin, current_user
from flask_sqlalchemy.query import Query
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy_continuum import make_versioned
from sqlalchemy_continuum.model_builder import ModelBuilder
from sqlalchemy_utils import generic_repr
from werkzeug.security import check_password_hash, generate_password_hash

from app.extensions import db

# Enable sqlalchemy_continuum versioning for models that declare
# ``__versioned__`` (see ``Post``).
# NOTE(review): per the sqlalchemy_continuum docs this must run before any
# versioned model is declared; ``user_cls=None`` presumably means version
# transactions are not linked to a user model -- confirm.
make_versioned(user_cls=None)

# Self-referential many-to-many association table backing ``User.followed``
# (and its ``followers`` backref): each row records that ``follower_id``
# follows ``followed_id``. Both columns reference ``user.id``.
followers = db.Table(
    "followers",
    db.Column("follower_id", db.Integer, db.ForeignKey("user.id")),
    db.Column("followed_id", db.Integer, db.ForeignKey("user.id")),
)


@generic_repr("id")
class BaseModel(db.Model):  # type: ignore
    """Abstract parent class shared by every model in this application.

    Being ``__abstract__``, it maps no table of its own; concrete
    subclasses inherit the ``id``-based repr and :meth:`export`.
    """

    __abstract__ = True

    def export(self) -> dict[str, str]:
        """Stringify this instance's public attributes.

        Every entry of the instance ``__dict__`` whose name does not begin
        with an underscore is included, with its value passed through
        ``str``.

        :return: Mapping of attribute name to stringified value.
        """
        exported: dict[str, str] = {}
        for name, value in self.__dict__.items():
            # Underscore-prefixed entries are private/internal bookkeeping,
            # not data to export.
            if name.startswith("_"):
                continue
            exported[name] = str(value)
        return exported
class User(UserMixin, BaseModel):
    """Database schema for users."""

    #: ID of this user.
    id = db.Column(db.Integer, primary_key=True)

    #: Username of this user.
    username = db.Column(db.String(64), index=True, unique=True)

    #: Email of this user.
    email = db.Column(db.String(120), index=True, unique=True)

    #: Password hash of this user's password.
    # NOTE(review): recent Werkzeug releases default to scrypt, whose hashes
    # can exceed 128 characters; consider widening this column together with
    # a migration -- confirm against the Werkzeug version in use.
    password_hash = db.Column(db.String(128))

    #: Date that the user was created.
    created = db.Column(db.DateTime, index=True, default=datetime.utcnow)

    #: Whether this user is an admin or not.
    admin = db.Column(db.Boolean, default=False)

    #: Posts that have been added by this user.
    posts = db.relationship("Post", backref="author", lazy="dynamic")

    #: Whether this user is confirmed or not.
    confirmed = db.Column(db.Boolean, default=False)

    #: Date that this user was confirmed on.
    confirmed_on = db.Column(db.DateTime)

    #: About me page for this user.
    about_me = db.Column(db.String(140))

    #: Date that this user last logged in on.
    last_seen = db.Column(db.DateTime, default=datetime.utcnow)

    #: Whether this user is authorized to make posts or not.
    authorized = db.Column(db.Boolean, default=False)

    #: Users that this user is following.
    followed: RelationshipProperty = db.relationship(
        "User",
        secondary=followers,
        primaryjoin=(followers.c.follower_id == id),
        secondaryjoin=(followers.c.followed_id == id),
        backref=db.backref("followers", lazy="dynamic"),
        lazy="dynamic",
    )

    #: Messages that this user has sent.
    messages_sent = db.relationship(
        "Message",
        foreign_keys="Message.sender_id",
        backref="author",
        lazy="dynamic",
    )

    #: Messages this user has received.
    messages_received = db.relationship(
        "Message",
        foreign_keys="Message.recipient_id",
        backref="recipient",
        lazy="dynamic",
    )

    #: The last time the user visited the messages page.
    last_message_read_time = db.Column(db.DateTime)

    #: Notifications pending for this user.
    notifications = db.relationship(
        "Notification", backref="user", lazy="dynamic"
    )

    #: Tasks associated with this user.
    tasks = db.relationship("Task", backref="user", lazy="dynamic")

    def set_password(self, password: str) -> None:
        """Hash and store a new user password.

        :param password: User's choice in password.
        """
        self.password_hash = generate_password_hash(password)

    def check_password(self, password: str) -> bool:
        """Match entered password against existing password hash.

        :param password: User's password attempt.
        :return: Attempt matches hash: True or False. Always False when no
            password has been set for this user.
        """
        if self.password_hash is None:
            # No stored hash means no attempt can match; passing ``None`` on
            # to werkzeug would raise instead of returning False.
            return False
        return check_password_hash(self.password_hash, password)

    def avatar(self, size: int) -> str:
        """Generate unique avatar for user derived from email hash.

        :param size: Size of the avatar.
        :return: URL leading to avatar for img link.
        """
        # MD5 is only used as the Gravatar lookup key, not for security.
        digest = hashlib.new(  # type: ignore
            "md5", self.email.lower().encode(), usedforsecurity=False
        ).hexdigest()
        return f"https://gravatar.com/avatar/{digest}?d=identicon&s={size}"

    def follow(self, user: User) -> None:
        """Follow another user if not already following.

        :param user: User model object of user to follow.
        """
        if not self.is_following(user):
            self.followed.append(user)

    def unfollow(self, user: User) -> None:
        """Unfollow another user if already following.

        :param user: User model object of user to unfollow.
        """
        if self.is_following(user):
            self.followed.remove(user)

    def is_following(self, user: User) -> bool:
        """Check whether following another user.

        :param user: User model object of user to check if following.
        :return: Following the user? True or False.
        """
        return (
            self.followed.filter(followers.c.followed_id == user.id).count()
            > 0
        )

    def followed_posts(self) -> Query:
        """Get posts by followed users together with this user's own posts.

        :return: Query (not yet executed) of those posts, newest first.
        """
        followed = Post.query.join(
            followers, (followers.c.followed_id == Post.user_id)
        ).filter(followers.c.follower_id == self.id)
        own = Post.query.filter_by(user_id=self.id)
        # noinspection PyUnresolvedReferences
        return followed.union(own).order_by(Post.created.desc())

    def new_messages(self) -> int:
        """Get the number of new (unread) messages delivered to user.

        :return: Number of new messages delivered to user.
        """
        # A user who never opened the messages page has every message unread.
        last_read_time = self.last_message_read_time or datetime(1000, 1, 1)
        return (
            Message.query.filter_by(recipient=self)
            .filter(Message.created > last_read_time)
            .count()
        )

    def add_notifications(
        self, name: str, data: dict[str, object]
    ) -> Notification:
        """Replace this user's notification of the given name and commit.

        Any existing notifications with the same ``name`` are deleted first.

        :param name: Name of the notification.
        :param data: Data saved (as JSON) with the notification.
        :return: The new ``Notification`` database model.
        """
        self.notifications.filter_by(name=name).delete()
        # `user` is a backref for `Notification` and not defined as a column
        notification = Notification(name=name, user=self)  # type: ignore
        notification.set_mapping(data)
        # Add explicitly rather than relying on backref cascade, which
        # SQLAlchemy 2.0 no longer performs.
        db.session.add(notification)
        db.session.commit()
        return notification

    def get_tasks_in_progress(self) -> list[Task]:
        """Get the currently running tasks triggered by user.

        :return: List of this user's ``Task`` rows not marked complete.
        """
        return Task.query.filter_by(user=self, complete=False).all()

    @classmethod
    def resolve_all_names(cls, username: str) -> User:
        """Manage retrieval of ``User`` object by their username.

        The current username is tried first; otherwise the most recent
        matching entry in the ``Usernames`` history is used.

        :param username: Username to search for user under.
        :raise HTTPException: ``404: Not Found`` if the name is not
            resolved, or the historical entry points at a missing user.
        :return: User object.
        """
        user = cls.query.filter_by(username=username).first()
        if user is None:
            # noinspection PyUnresolvedReferences
            usernames = (
                Usernames.query.filter_by(username=username)
                .order_by(Usernames.id.desc())
                .first_or_404()
            )
            # ``Query.get`` is legacy API; this also honours the documented
            # 404 contract instead of silently returning None.
            return db.get_or_404(cls, usernames.user_id)
        return user
class Post(BaseModel):
    """Database schema for posts."""

    # Declaring ``__versioned__`` opts this model in to sqlalchemy_continuum
    # versioning.
    # NOTE(review): this is presumably what provides the ``versions``
    # relationship read by :meth:`get_version` -- confirm.
    __versioned__: t.Dict[object, object] = {}

    #: ID of this post.
    id = db.Column(db.Integer, primary_key=True)

    #: Title of this post.
    title = db.Column(db.String, nullable=False)

    #: Body of this post.
    body = db.Column(db.String)

    #: Date that the post was created.
    created = db.Column(db.DateTime, index=True, default=datetime.utcnow)

    #: ID of the user that wrote this post.
    user_id = db.Column(db.Integer, db.ForeignKey("user.id"))

    #: Date that this post was last edited.
    edited = db.Column(db.DateTime, default=None)

    def get_version(self, index: int) -> t.Any:
        """Get version of post by index.

        If no version exists at ``index`` a ``404: Not Found`` error aborts
        the request instead of raising an ``IndexError``; this method never
        returns ``None``.

        :param index: Index of version beginning with 0.
        :raise HTTPException: ``404: Not Found`` if ``index`` is out of range.
        :return: The version object stored at ``index``.
        """
        try:
            return self.versions[index]
        except IndexError:
            # ``abort`` raises, so nothing is returned on this path.
            abort(404)

    @classmethod
    def get_post(
        cls,
        id: int,  # pylint: disable=redefined-builtin
        version: int | None = None,
        checkauthor: bool = True,
    ) -> Post:
        """Get post by post's ID or abort with ``404: Not Found``.

        Standard behaviour would be to return None, so do not bypass
        silently.

        If a version number is provided, find the version by index. This is
        a different search to getting a post by ID, as the database starts
        at 1 but the index always starts at 0. If no version can be returned
        a ``404: Not Found`` error aborts instead of raising ``IndexError``.

        The version's title and body are copied onto the returned ``Post``.
        The version object itself cannot be returned for a restore, as it is
        a different object and will not save when committing database
        changes.

        The ``checkauthor`` argument allows fetching a post without checking
        the author, e.g. for a view that only displays a post.

        If the current user is not logged in, or does not own the post,
        abort with ``403: Forbidden``.

        :param id: The post's ID.
        :param version: If provided, populate the post with this version.
        :param checkauthor: Rule whether to check for author ID.
        :raise HTTPException: ``404`` if the post or version is missing,
            ``403`` if the author check fails.
        :return: The ``Post`` object.
        """
        post = cls.query.filter_by(id=id).first_or_404()
        if version is not None:
            # ``get_version`` aborts on a bad index, so a value is guaranteed.
            post_version = post.get_version(version)
            post.title = post_version.title
            post.body = post_version.body

        # Anonymous users have no ``id``; treat them as non-owners (403)
        # instead of failing with an AttributeError (500).
        if checkauthor and (
            not current_user.is_authenticated
            or post.user_id != current_user.id
        ):
            abort(403)

        return post
class Message(BaseModel):
    """Database schema for user messages.

    The ``author`` and ``recipient`` relationships are created as backrefs
    by ``User.messages_sent`` and ``User.messages_received``, which join on
    ``sender_id`` and ``recipient_id`` respectively.
    """

    #: Message ID.
    id = db.Column(db.Integer, primary_key=True)

    #: ID of the sender of the message.
    sender_id = db.Column(db.Integer, db.ForeignKey("user.id"))

    #: ID of the recipient of the message.
    recipient_id = db.Column(db.Integer, db.ForeignKey("user.id"))

    #: Message body (at most 140 characters).
    body = db.Column(db.String(140))

    #: Date that the message was created and sent.
    created = db.Column(db.DateTime, index=True, default=datetime.utcnow)
class Notification(BaseModel):
    """Database schema for notifications.

    The ``user`` relationship is created as a backref by
    ``User.notifications``.
    """

    #: ID of this notification.
    id = db.Column(db.Integer, primary_key=True)

    #: Name of this notification.
    name = db.Column(db.String(128), index=True)

    #: ID of the user who this notification belongs to.
    user_id = db.Column(db.Integer, db.ForeignKey("user.id"))

    #: Date this notification was created, as a ``time.time()`` float.
    timestamp = db.Column(db.Float, index=True, default=time)

    #: Notification data serialized as a JSON string.
    mapping = db.Column(db.Text)

    def set_mapping(self, mapping: dict[str, object]) -> None:
        """Store ``mapping`` as a JSON string.

        :param mapping: JSON-serializable data to store.
        """
        self.mapping = json.dumps(mapping)

    def get_mapping(self) -> dict[str, object]:
        """Get dict representation of the stored JSON data.

        :return: Dict object of notification data; empty if no data has
            been set yet.
        """
        if self.mapping is None:
            # Nothing stored yet; ``json.loads(None)`` would raise TypeError.
            return {}
        return json.loads(self.mapping)
class Task(BaseModel):
    """Database schema for background tasks.

    The ``user`` relationship is created as a backref by ``User.tasks``;
    ``User.get_tasks_in_progress`` selects rows whose ``complete`` is False.
    """

    #: ID of the task.
    # NOTE(review): 36 characters matches a canonical UUID string; presumably
    # the ID of the background job -- confirm.
    id = db.Column(db.String(36), primary_key=True)

    #: Name of the task.
    name = db.Column(db.String(128), index=True)

    #: Description of the task.
    description = db.Column(db.String(128))

    #: User that this task belongs to.
    user_id = db.Column(db.Integer, db.ForeignKey("user.id"))

    #: Whether this task is complete or not.
    complete = db.Column(db.Boolean, default=False)
class Usernames(BaseModel):
    """Database schema for username changes.

    ``username`` is not unique, so the same name may appear more than once;
    ``User.resolve_all_names`` picks the row with the highest ``id``.
    """

    #: ID of the username.
    id = db.Column(db.Integer, primary_key=True)

    #: Single username amongst the usernames.
    username = db.Column(db.String(64), index=True)

    #: ID of the user that this username belongs to.
    user_id = db.Column(db.Integer, db.ForeignKey("user.id"))
# Configure all mappers now that every model above has been declared.
# NOTE(review): sqlalchemy_continuum presumably generates its version classes
# (backing ``Post.versions``) during mapper configuration, which is why this
# runs at import time after the last model -- confirm.
db.configure_mappers()