Source code for acore_df.dataset
# -*- coding: utf-8 -*-
import typing as T
import dataclasses
from pathlib import Path
from urllib.request import urlopen
from functools import cached_property
import polars as pl
import sqlalchemy as sa
from ._version import __version__
from .paths import path_sqlite
T_BASE_ORM_MODEL = T.TypeVar("T_BASE_ORM_MODEL")
T_BASE_DATA_CLASS = T.TypeVar("T_BASE_DATA_CLASS")
[docs]@dataclasses.dataclass
class BaseDataset(T.Generic[T_BASE_ORM_MODEL, T_BASE_DATA_CLASS]):
"""
Base class for a dataset that is backed by a SQL table.
:param name: the dataset name
:param id_col: the primary key column name
:param orm_model: sqlalchemy orm model
:param orm_table: sqlalchemy table
:param data_class: dataclasses data class
:param engine: sqlalchemy engine
"""
name: str = dataclasses.field()
id_col: str = dataclasses.field()
orm_model: T.Type[T_BASE_ORM_MODEL] = dataclasses.field()
orm_table: sa.Table = dataclasses.field()
data_class: T.Type[T_BASE_DATA_CLASS] = dataclasses.field()
engine: sa.Engine = dataclasses.field()
@cached_property
def df(self) -> pl.DataFrame:
"""
Read the entire dataset into a polars DataFrame. This will be cached.
"""
with self.engine.connect() as conn:
return pl.read_database(
query=f"SELECT * FROM {self.orm_table.name}",
connection=conn,
)
@cached_property
def row_map(self) -> T.Dict[T.Union[int, str], T_BASE_DATA_CLASS]:
"""
Create a dictionary mapping the primary key to the data class instance.
"""
dct = dict()
id_col = self.id_col
for row in self.df.to_dicts():
dct[row[id_col]] = self.data_class(**row)
return dct
[docs] def get(self, id: T.Union[int, str]) -> T.Optional[T_BASE_DATA_CLASS]:
"""
Get a data class instance by id value.
.. note::
If you have concern about mutability, you can do the following:
.. code-block:: python
dataset = BaseDataset(...)
data = dataset.data_class(**dataclasses.asdict(dataset.get(id=1)))
:param id: the primary key value.
"""
return self.row_map.get(id)
[docs] def get_by_kvs(self, kvs: T.Dict[str, T.Any]) -> T.List[T_BASE_DATA_CLASS]:
"""
Get a list of data class instances by key-value pairs.
:param kvs: key-value pairs, key is the column name, value is the value to match.
"""
stmt = sa.select(self.orm_model)
conditions = list()
for k, v in kvs.items():
conditions.append(getattr(self.orm_model, k) == v)
if conditions:
stmt = stmt.where(sa.and_(*conditions))
with self.engine.connect() as conn:
results = list()
for row in conn.execute(stmt):
data = self.data_class(**row._asdict())
results.append(data)
return results
[docs]def download_sqlite(path_sqlite: Path = path_sqlite):
"""
Download the sqlite database file from the GitHub release page.
"""
url = f"https://github.com/MacHu-GWU/acore_df-project/releases/download/{__version__}/acore_df.sqlite"
path_sqlite.write_bytes(urlopen(url).read())