Source code for coaster.sqlalchemy.comparators
"""
Enhanced query and custom comparators
-------------------------------------
"""
import uuid as uuid_
from flask_sqlalchemy import BaseQuery
from sqlalchemy.ext.hybrid import Comparator
from flask import abort
from ..utils import uuid_from_base58, uuid_from_base64
# Names exported by a star-import of this module.
__all__ = [
    'Query',
    'SplitIndexComparator',
    'SqlSplitIdComparator',
    'SqlUuidHexComparator',
    'SqlUuidB64Comparator',
    'SqlUuidB58Comparator',
]

# Module-private sentinel: a unique object that is distinct from every other
# value, including ``None``.
_marker = object()
class Query(BaseQuery):
    """
    Extends flask_sqlalchemy.BaseQuery to add additional helper methods.
    """

    def notempty(self):
        """
        Returns the equivalent of ``bool(query.count())`` but using an efficient
        SQL EXISTS function, so the database stops counting after the first result
        is found.
        """
        return self.session.query(self.exists()).scalar()

    def isempty(self):
        """
        Returns the equivalent of ``not bool(query.count())`` but using an efficient
        SQL EXISTS function, so the database stops counting after the first result
        is found.
        """
        return not self.session.query(self.exists()).scalar()

    def one_or_404(self):
        """
        Extends :meth:`~sqlalchemy.orm.query.Query.one_or_none` to raise a 404
        if no result is found. This method offers a safety net over
        :meth:`~flask_sqlalchemy.BaseQuery.first_or_404` as it helps identify
        poorly specified queries that could have returned more than one result.

        :returns: The single result of the query
        :raises: Whatever ``abort(404)`` raises when the query has no result;
            errors from ``one_or_none`` (such as more than one row matching)
            are not intercepted here
        """
        result = self.one_or_none()
        # ``one_or_none`` reports "no row" as ``None``. Compare by identity so a
        # result that merely evaluates falsy (for example an object defining
        # ``__bool__`` or ``__len__``) is returned instead of becoming a 404.
        if result is None:
            abort(404)
        return result
class SplitIndexComparator(Comparator):
    """
    Base class for comparators that support splitting a string and
    comparing with one of the split values.

    Subclasses implement :meth:`_decode` to convert an incoming value into the
    form that is compared against the column. A value that cannot be decoded
    (``_decode`` raises :exc:`ValueError`, :exc:`TypeError` or
    :exc:`IndexError`) is treated as matching nothing instead of propagating
    the error.

    :param expression: Expression passed through to the parent ``Comparator``
    :param splitindex: Stored as ``self.splitindex`` for use by ``_decode``
        implementations; ``None`` (the default) is stored unchanged
    """

    # Exceptions from ``_decode`` that mean "this value cannot match".
    # ``IndexError`` covers a ``splitindex`` that lies beyond the pieces
    # produced by splitting the incoming string.
    _decode_errors = (ValueError, TypeError, IndexError)

    def __init__(self, expression, splitindex=None):
        super().__init__(expression)
        self.splitindex = splitindex

    def _decode(self, other):
        """Convert ``other`` for comparison. Must be provided by subclasses."""
        raise NotImplementedError

    def __eq__(self, other):
        """
        Compare the column for equality with the decoded form of ``other``.

        Returns plain ``False`` when ``other`` cannot be decoded.
        """
        try:
            other = self._decode(other)
        except self._decode_errors:
            return False
        return self.__clause_element__() == other

    def __ne__(self, other):
        """
        Compare the column for inequality with the decoded form of ``other``.

        Returns plain ``True`` when ``other`` cannot be decoded.
        """
        try:
            other = self._decode(other)
        except self._decode_errors:
            return True
        return self.__clause_element__() != other

    def in_(self, other):
        """
        Test whether the column matches any decodable value in ``other``.

        :param other: Iterable of values; each is passed through ``_decode``
            and those that fail to decode are silently dropped
        """
        decoded = []
        for value in other:
            try:
                decoded.append(self._decode(value))
            except self._decode_errors:
                # Undecodable entries can never match, so leave them out.
                continue
        # Hand over a concrete list rather than a one-shot generator so the
        # values form a real sequence for the underlying ``in_``.
        return self.__clause_element__().in_(decoded)
class SqlSplitIdComparator(SplitIndexComparator):
    """
    Comparator for id columns. Its main use is the ``splitindex`` feature:
    an incoming string is split on the ``-`` character and one of the
    resulting pieces is used for the comparison.
    """

    def _decode(self, other):
        """
        Return the value to compare against the column.

        ``None`` yields ``None``. When ``splitindex`` is set and ``other`` is a
        string, the piece at ``splitindex`` (after splitting on ``-``) is
        converted with ``int``. Any other value is returned unchanged.
        """
        if other is None:
            return None
        should_split = self.splitindex is not None and isinstance(other, str)
        if not should_split:
            return other
        pieces = other.split('-')
        return int(pieces[self.splitindex])
class SqlUuidHexComparator(SplitIndexComparator):
    """
    Comparator for UUID fields that accepts the hex representation of a UUID.
    """

    def _decode(self, other):
        """
        Return ``other`` as a :class:`uuid.UUID`.

        ``None`` and existing ``UUID`` instances are returned as they are.
        Anything else is optionally reduced to the piece at ``splitindex``
        (splitting on ``-``) and then passed to the ``UUID`` constructor.
        """
        # Nothing to convert for a missing value or a ready-made UUID.
        if other is None or isinstance(other, uuid_.UUID):
            return other
        text = other
        if self.splitindex is not None:
            # NOTE(review): a hyphenated UUID string is itself split here.
            text = text.split('-')[self.splitindex]
        return uuid_.UUID(text)
class SqlUuidB64Comparator(SplitIndexComparator):
    """
    Comparator for UUID fields that accepts the URL-safe Base64 (BUID)
    representation of a UUID.
    """

    def _decode(self, other):
        """
        Return ``other`` converted via ``uuid_from_base64``.

        ``None`` and existing ``UUID`` instances are returned as they are.
        Anything else is optionally reduced to the piece at ``splitindex``
        (splitting on ``-``) before conversion.
        """
        # Nothing to convert for a missing value or a ready-made UUID.
        if other is None or isinstance(other, uuid_.UUID):
            return other
        encoded = other
        if self.splitindex is not None:
            encoded = encoded.split('-')[self.splitindex]
        return uuid_from_base64(encoded)
class SqlUuidB58Comparator(SplitIndexComparator):
    """
    Comparator for UUID fields that accepts the Base58 representation of a
    UUID.
    """

    def _decode(self, other):
        """
        Return ``other`` converted via ``uuid_from_base58``.

        ``None`` and existing ``UUID`` instances are returned as they are.
        Anything else is optionally reduced to the piece at ``splitindex``
        (splitting on ``-``) before conversion.
        """
        # Nothing to convert for a missing value or a ready-made UUID.
        if other is None or isinstance(other, uuid_.UUID):
            return other
        encoded = other
        if self.splitindex is not None:
            encoded = encoded.split('-')[self.splitindex]
        return uuid_from_base58(encoded)