224 lines
7.2 KiB
Python
224 lines
7.2 KiB
Python
"""
|
|
Define the abstract base class (RawBaseModel) shared by the raw Jeeves models
|
|
"""
|
|
from decimal import Decimal
|
|
from datetime import datetime
|
|
|
|
from sqlalchemy.sql.expression import and_
|
|
from sqlalchemy.ext.hybrid import hybrid_property
|
|
from sqlalchemy.exc import OperationalError
|
|
|
|
from sqlalchemy.schema import MetaData, Column
|
|
from sqlalchemy.types import Integer
|
|
from sqlalchemy.orm.collections import InstrumentedList
|
|
from sqlalchemy import event, orm
|
|
from sqlalchemy.orm import Session
|
|
|
|
from sqlservice import ModelBase, as_declarative
|
|
|
|
from pyjeeves import logging
|
|
|
|
from . import db
|
|
from .ext import install_validator_listner
|
|
|
|
logger = logging.getLogger("PyJeeves." + __name__)

logger.info("Reading Jeeves DB structure")

meta = MetaData()

# Jeeves tables whose structure is reflected into ``meta`` at import time.
# TODO: Split raw.py and reflect tables on separate module loads?
_REFLECTED_TABLES = [
    'ar', 'ars', 'arsh', 'arean', 'xae', 'xare', 'fr', 'kus', 'x1k',
    'oh', 'orp', 'lp', 'vg', 'xp', 'xm', 'prh', 'prl',
    'kp', 'kpw', 'cr', 'X4', 'xw', 'X1', 'jfbs', 'lrfb',
    'JAPP_EWMS_Item_Replenishment_Levels',
]

try:
    # Opening the connection is inside the try so that a connection failure
    # is logged the same way as a reflection failure.
    meta.reflect(bind=db.raw_session.connection(), only=_REFLECTED_TABLES)
except OperationalError:
    # logger.exception also records the traceback; the bare ``raise``
    # re-raises the original exception unchanged.
    logger.exception("Failed to read Jeeves DB structure")
    raise
|
|
|
|
|
|
|
|
@event.listens_for(Session, "do_orm_execute")
def _add_filtering_criteria(execute_state):
    """Restrict ORM statements to one company via ``ForetagKod``.

    Registered on the ``Session`` class, so it runs for every ORM execution
    on any session. Unless the execution is a column load or a relationship
    load, a ``with_loader_criteria`` option is attached so that RawBaseModel
    and its subclasses (aliases included) are limited to rows where
    ``ForetagKod == company_code``. This applies a global WHERE criteria to
    SELECTs, or ON-clause criteria for join targets.

    The company code comes from the ``company_code`` execution option and
    falls back to 1 when the option is not set.
    """
    # with_loader_criteria applies itself to relationship loads (lazy loads
    # included), so for those we assume the option was already set up on the
    # top-level query. Column loads are skipped as well.
    if execute_state.is_column_load or execute_state.is_relationship_load:
        return

    # TODO: Make configurable if repo made pub
    company_code = execute_state.execution_options.get("company_code", 1)

    company_criteria = orm.with_loader_criteria(
        RawBaseModel,
        lambda cls: cls.ForetagKod == company_code,
        include_aliases=True,
    )
    execute_state.statement = execute_state.statement.options(company_criteria)
|
|
|
|
|
|
@as_declarative(metadata=meta)
class RawBaseModel(ModelBase):
    """Declarative base for the raw Jeeves models.

    Generalizes ``__init__``, ``__repr__`` and ``to_json`` based on the
    model's columns. Every model carries a ``ForetagKod`` primary-key column
    (the company code, 1 by default).
    """

    # Keys removed from the output of ``to_dict``.
    __to_dict_filter__ = []
    # When non-empty, ``to_dict`` returns only these keys, renamed through
    # ``__column_map__``.
    __to_dict_only__ = ()
    # Maps model attribute names to external (serialized) key names.
    __column_map__ = {}
    # Inverse of ``__column_map__``: external key name -> attribute name.
    __reversed_column_map__ = lambda self: {v: k for k, v in self.__column_map__.items()}  # noqa

    __table_args__ = {
        'extend_existing': True
    }

    # Value adapters used when serializing: datetimes become
    # 'YYYY-MM-DD HH:MM' strings and Decimals become floats.
    __dict_args__ = {
        'adapters': {
            datetime: lambda value, col, *_: value.strftime('%Y-%m-%d %H:%M'),
            Decimal: lambda value, col, *_: float(value)
        }
    }

    ForetagKod = Column(Integer, primary_key=True)

    def __init__(self, data=None, **kargs):
        """Create a model from an optional ``data`` mapping and keyword args.

        Args:
            data: Mapping whose keys may be attribute names or the external
                names from ``__column_map__``. It is translated by
                :meth:`_map_keys` before being applied.
            **kargs: Attribute values. These override values taken from
                ``data``.
        """
        mapped = self._map_keys(data) if data else {}
        # Apply the translated ``data`` together with the keyword arguments.
        # Keyword arguments win on conflicts.
        self.set(**{**mapped, **kargs})

    @classmethod
    def _map_columns(cls, key):
        """Return the external name for attribute ``key``, or ``key`` itself."""
        return cls.__column_map__.get(key, key)

    def _map_keys(self, data=None):
        """Translate an incoming mapping into model attribute names.

        Keys present in the reversed column map are renamed to their
        attribute names. Every key that is an attribute of the model is also
        copied; relationship values are converted by
        :meth:`_map_relationship_keys`. Keys matching neither are dropped.

        Args:
            data: Mapping to translate. ``None`` is treated as empty.

        Returns:
            dict: Values keyed by model attribute name.
        """
        if data is None:
            data = {}

        rv = {}
        for external_key, attr_name in self.__reversed_column_map__().items():
            if external_key in data:
                rv[attr_name] = data[external_key]

        relationship_keys = self.relationships().keys()
        for key, value in data.items():
            if not hasattr(self, key):
                continue
            if key in relationship_keys:
                rv[key] = self._map_relationship_keys(key, value)
            else:
                rv[key] = value
        return rv

    def _map_relationship_keys(self, field, value):
        """Convert a relationship field's raw value into model instances.

        Almost a copy from SQLService ModelBase.

        For list relationships, a single value is wrapped in a list, and each
        item that is not already an instance of the related class is passed
        to that class's constructor. For scalar relationships, a non-empty
        dict becomes a related instance, and an empty dict becomes ``None``
        so the relationship is cleared. Anything else is returned unchanged.
        """
        relation_attr = getattr(self.__class__, field)
        uselist = relation_attr.property.uselist
        relation_class = relation_attr.property.mapper.class_

        if uselist:
            if not isinstance(value, (list, tuple)):  # pragma: no cover
                value = [value]

            # Convert each value instance to relationship class.
            value = [relation_class(val) if not isinstance(val, relation_class)
                     else val
                     for val in value]
        elif value and isinstance(value, dict):
            # Convert single value object to relationship class.
            value = relation_class(value)
        elif not value and isinstance(value, dict):
            # If value is {} and we're trying to update a relationship
            # attribute, then we need to set to None to nullify relationship
            # value.
            value = None

        return value

    def descriptors_to_dict(self):
        """Return a ``dict`` that maps data loaded in :attr:`__dict__` to this
        model's descriptors. The data contained in :attr:`__dict__` represents
        the model's state that has been loaded from the database. Accessing
        values in :attr:`__dict__` will prevent SQLAlchemy from issuing
        database queries for any ORM data that hasn't been loaded from the
        database already.

        Hybrid properties are also included. They are evaluated with
        ``getattr``, so they are computed on access.

        Note:
            The ``dict`` returned will contain model instances for any
            relationship data that is loaded. To get a ``dict`` containing all
            non-ORM objects, use :meth:`to_dict`.

        Returns:
            dict
        """
        descriptors = self.descriptors()

        return {  # Expose hybrid_property extension
            **{key: getattr(self, key) for key in descriptors.keys()
               if isinstance(descriptors.get(key), hybrid_property)},
            # and return all items included in descriptors
            **{key: value for key, value in self.__dict__.items()
               if key in descriptors}}

    def to_dict(self):
        """Return the model as a ``dict``, applying the class-level filters.

        If ``__to_dict_only__`` is non-empty, only those keys are kept, and
        they are renamed through ``__column_map__``. Otherwise, every key
        listed in ``__to_dict_filter__`` is removed from the result. Listed
        keys that are not present in the result are ignored.
        """
        rv = super().to_dict()

        if self.__to_dict_only__:
            return {
                self._map_columns(key): value
                for key, value in rv.items()
                if key in self.__to_dict_only__
            }

        for _filter in self.__to_dict_filter__:
            # A filtered key may be absent from ``rv``; skip it rather than
            # raising KeyError.
            rv.pop(_filter, None)

        return rv

    def from_dict(self, data=None):
        """Update the model in place from ``data`` and return ``self``.

        Values under external key names from ``__column_map__`` are first
        written to their attributes. Then every key that is an attribute of
        the model is set, except attributes currently holding an
        ``InstrumentedList``, which are left untouched.

        Args:
            data: Mapping of values. ``None`` is treated as empty.
        """
        if data is None:
            data = {}

        for external_key, attr_name in self.__reversed_column_map__().items():
            if external_key in data:
                self[attr_name] = data[external_key]

        for key, value in data.items():
            # List-valued relationships are not replaced from plain data.
            if hasattr(self, key) and not isinstance(self[key], InstrumentedList):
                self[key] = value
        return self

    def merge(self):
        """Merge this instance into ``db.raw_session`` and return ``self``.

        NOTE(review): ``Session.merge`` returns the session-attached instance,
        which is discarded here. Callers receive ``self``, which may not be
        attached to the session. Confirm this is intended.
        """
        db.raw_session.merge(self)
        return self

    def commit(self):
        """Commit the current transaction of ``db.raw_session``."""
        db.raw_session.commit()

    def save(self):
        """Add this instance to ``db.raw_session``, commit, and return it."""
        db.raw_session.add(self)
        db.raw_session.commit()
        return self

    def delete(self):
        """Delete this instance via ``db.raw_session`` and commit."""
        db.raw_session.delete(self)
        db.raw_session.commit()
|
|
|
|
|
# Hook validator installation into attribute instrumentation for
# RawBaseModel (intended to cover the string attributes of its subclasses).
@event.listens_for(RawBaseModel, 'attribute_instrument')
def receive_attribute_instrument(cls, key, inst):
    """Handle SQLAlchemy's ``attribute_instrument`` event.

    Receives the owning class, the attribute name and the instrumented
    attribute, and forwards all three unchanged to
    ``install_validator_listner``.
    """
    install_validator_listner(cls, key, inst)
|
|
|
|
|
|
# Hand RawBaseModel to the ``db`` module, presumably so it is used as the model
# base for the raw session. TODO: confirm against ``db.set_model_class``.
db.set_model_class(RawBaseModel)
|