Fix: Verify string lengths when creating model objects

This commit is contained in:
Marcus Lindvall 2019-11-20 00:04:37 +01:00
parent 144fdbefb1
commit 1a7ada9d56
2 changed files with 28 additions and 20 deletions

View file

@ -11,13 +11,14 @@ from sqlalchemy.exc import OperationalError
from sqlalchemy.schema import MetaData, Column from sqlalchemy.schema import MetaData, Column
from sqlalchemy.types import Integer from sqlalchemy.types import Integer
from sqlalchemy.orm.collections import InstrumentedList from sqlalchemy.orm.collections import InstrumentedList
from sqlalchemy import event
from sqlservice import ModelBase, as_declarative from sqlservice import ModelBase, as_declarative
from pyjeeves import logging from pyjeeves import logging
from . import db from . import db
from .ext import InstallValidatorListeners from .ext import install_validator_listner
logger = logging.getLogger("PyJeeves." + __name__) logger = logging.getLogger("PyJeeves." + __name__)
@ -38,7 +39,6 @@ except OperationalError as e:
class RawBaseModel(ModelBase): class RawBaseModel(ModelBase):
""" Generalize __init__, __repr__ and to_json """ Generalize __init__, __repr__ and to_json
Based on the models columns , ForetagKod=1""" Based on the models columns , ForetagKod=1"""
__sa_instrumentation_manager__ = InstallValidatorListeners
__to_dict_filter__ = [] __to_dict_filter__ = []
__to_dict_only__ = () __to_dict_only__ = ()
@ -75,9 +75,10 @@ class RawBaseModel(ModelBase):
filters filters
) )
def _map_columns(self, key): @classmethod
if key in self.__column_map__: def _map_columns(cls, key):
return self.__column_map__[key] if key in cls.__column_map__:
return cls.__column_map__[key]
return key return key
def _map_keys(self, data={}): def _map_keys(self, data={}):
@ -185,3 +186,11 @@ class RawBaseModel(ModelBase):
def delete(self): def delete(self):
db.raw_session.delete(self) db.raw_session.delete(self)
db.raw_session.commit() db.raw_session.commit()
# Apply validators for all string attributes in subclasses of RawBaseModel
@event.listens_for(RawBaseModel, 'attribute_instrument')
def receive_attribute_instrument(cls, key, inst):
"listen for the 'attribute_instrument' event"
install_validator_listner(cls, key, inst)

View file

@ -1,5 +1,4 @@
from sqlalchemy.ext.instrumentation import InstrumentationManager
# from sqlalchemy.orm.interfaces import AttributeExtension
from sqlalchemy.orm import ColumnProperty from sqlalchemy.orm import ColumnProperty
from sqlalchemy.types import String from sqlalchemy.types import String
from sqlalchemy import event from sqlalchemy import event
@ -28,18 +27,17 @@ class JeevesDBError(Error):
self.message = message self.message = message
class InstallValidatorListeners(InstrumentationManager): def install_validator_listner(class_, key, inst):
def post_configure_attribute(self, class_, key, inst): """Add validators for any attributes that can be validated."""
"""Add validators for any attributes that can be validated.""" prop = inst.prop
prop = inst.prop # Only interested in simple columns, not relations
# Only interested in simple columns, not relations if isinstance(prop, ColumnProperty) and len(prop.columns) == 1:
if isinstance(prop, ColumnProperty) and len(prop.columns) == 1: col = prop.columns[0]
col = prop.columns[0] # if we have string column with a length, create a length validator listner
# if we have string column with a length, create a length validator listner if isinstance(col.type, String) and col.type.length:
if isinstance(col.type, String) and col.type.length: event.listen(
event.listen( getattr(class_, key), 'set', LengthValidator(
getattr(class_, key), 'set', LengthValidator( col.name, col.type.length), retval=True)
col.name, col.type.length), retval=True)
class LengthValidator(): class LengthValidator():
@ -51,5 +49,6 @@ class LengthValidator():
if len(value) > self.max_length: if len(value) > self.max_length:
raise ValidationError( raise ValidationError(
"%s.%s: Length %d exceeds allowed %d" % ( "%s.%s: Length %d exceeds allowed %d" % (
state.__class__.__name__, self.col_name, len(value), self.max_length)) state.__class__.__name__, state.__class__._map_columns(self.col_name),
len(value), self.max_length))
return value return value