"""Table"""
# pylint: disable=too-many-arguments,too-many-public-methods
from sqlite3 import Connection, OperationalError
from typing import (
Any,
Generator,
Iterable,
Literal,
NamedTuple,
Optional,
Type,
overload,
)
import weakref
from sqlite_database.functions import ParsedFn, Function
from .utils import crunch
from ._utils import check_iter, check_one, Row
from .column import BuilderColumn, Column
from .errors import TableRemovedError, UnexpectedResultError
from .locals import SQLITEPYTYPES, PLUGINS_PATH
from .query_builder import (
Condition,
extract_single_column,
fetch_columns,
build_select,
build_insert,
build_delete,
build_update,
)
from .signature import op
from .typings import (
Data,
Orders,
Queries,
Query,
TypicalNamedTuple,
_MasterQuery,
OnlyColumn,
SquashedSqueries,
JustAColumn,
)
# Let's add a little bit of 'black' magic here.
_null = Function("__NULL__")()
@classmethod
def get_table(cls): # pylint: disable=missing-function-docstring
return getattr(cls, "_table", None)
[docs]
class Table:
"""Table. Make sure you remember how the table goes."""
_ns: dict[str, Type[NamedTuple]] = {}
def __init__(
self,
parent, # type: ignore
table: str,
__columns: Optional[Iterable[Column]] = None, # type: ignore
) -> None:
if parent.closed:
raise ConnectionError("Connection to database is already closed.")
self._parent_repr = repr(parent)
self._sql: Connection = parent.sql
# pylint: disable-next=protected-access
self._sql_path = parent._path
self._deleted = False
self._force_dirty = False
self._dirty = False
self._table = check_one(table)
self._columns: Optional[list[Column]] = list(__columns) if __columns else None
weakref.finalize(self, self._finalize)
if self._columns is None and table != "sqlite_master":
self._fetch_columns()
def _finalize(self):
pass
def _delete_hook(self):
try:
self.select()
except OperationalError:
self._deleted = True
def _fetch_columns(self):
table = self._table
try:
query, data = build_select(
"sqlite_master", {"type": op == "table", "name": op == table}
)
cursor = self._sql.cursor()
cursor.execute(query, data)
tabl = cursor.fetchone()
if tabl is None:
raise ValueError("What the hell?")
cols = fetch_columns(_MasterQuery(**tabl))
self._columns = cols
return 0
except Exception: # pylint: disable=broad-except
return 1
# def _raw_exec(self, query: str, data: dict[str, Any]):
# """No thread safe :("""
# cursor = self._sql.cursor()
# cursor.execute(query, data)
# return cursor
def _exec(
self,
query: str,
data: dict[str, Any] | list[dict[str, Any]],
which: Literal["execute", "executemany"] = "execute",
):
"""Execute a sql query"""
cursor = self._sql.cursor()
fn = cursor.execute if which == "execute" else cursor.executemany
try:
fn(query, data)
except OperationalError as exc:
exc.add_note(f"SQL query: {query}")
exc.add_note(
f"There's about {1 if isinstance(data, dict) else len(data)} value(s) inserted"
)
raise exc
return cursor
def _control(self):
if self._deleted:
raise TableRemovedError(f"{self._table} is already removed")
def _query_control(self):
if self._dirty and self._force_dirty is False:
self._sql.commit()
self._dirty = False
[docs]
def force_nodelete(self):
"""Force "undelete" table. Used if table was mistakenly assigned as
deleted."""
self._deleted = True
[docs]
def delete(
self,
condition: Condition = None,
limit: int = 0,
order: Optional[Orders] = None,
commit: bool = True,
):
"""Delete row or rows
Args:
condition (Condition, optional): Condition to determine deletion
See `Signature` class about conditional stuff. Defaults to None.
limit (int, optional): Limit deletion by integer. Defaults to 0.
order (Optional[Orders], optional): Order of deletion. Defaults to None.
commit (bool, optional): Commit changes to database (default is true)
Returns:
int: Rows affected
"""
query, data = build_delete(self._table, condition, limit, order) # type: ignore
self._control()
cursor = self._exec(query, data)
rcount = cursor.rowcount
if commit:
self._sql.commit()
else:
self._dirty = True
return rcount
[docs]
def delete_one(self, condition: Condition = None, order: Optional[Orders] = None):
"""Delete a row
Args:
condition (Condition, optional): Conditional to determine deletion.
Defaults to None.
order (Optional[Orders], optional): Order of deletion. Defaults to None.
"""
return self.delete(condition, 1, order)
[docs]
def insert(self, data: Data, commit: bool = True):
"""Insert data to current table
Args:
data (Data): Data to insert. Make sure it's compatible with the table.
commit (bool, optional): Commit data to database.
Returns:
int: Last rowid
"""
query, _ = build_insert(self._table, data) # type: ignore
self._control()
cursor = self._exec(query, data)
rlastrowid = cursor.lastrowid
self._sql.commit()
if commit:
self._sql.commit()
else:
self._dirty = True
return rlastrowid
[docs]
def insert_multiple(self, datas: list[Data], commit: bool = True):
"""Insert multiple values
Args:
datas (Iterable[Data]): Data to be inserted.
commit (bool, optional): Commit data to database
"""
self._control()
query, _ = build_insert(self._table, datas[0]) # type: ignore
self._exec(query, datas, "executemany")
if commit:
self._sql.commit()
else:
self._dirty = True
[docs]
def insert_many(self, datas: list[Data]):
"""Alias to `insert_multiple`"""
return self.insert_multiple(datas)
[docs]
def update(
self,
condition: Condition | None = None,
data: Data | None = None,
limit: int = 0,
order: Optional[Orders] = None,
commit: bool = True,
):
"""Update rows of current table
Args:
data (Data): New data to update
condition (Condition, optional): Condition dictionary.
See `Signature` about how condition works. Defaults to None.
limit (int, optional): Limit updates. Defaults to 0.
order (Optional[Orders], optional): Order of change. Defaults to None.
commit (bool, optional): Commit data to database
Returns:
int: Rows affected
"""
if data is None:
raise ValueError("data parameter must not be None")
query, data = build_update(
self._table, data, condition, limit, order
) # type: ignore
self._control()
cursor = self._exec(query, data)
rcount = cursor.rowcount
if commit:
self._sql.commit()
else:
self._dirty = True
return rcount
[docs]
def update_one(
self,
condition: Condition | None = None,
new_data: Data | None = None,
order: Orders | None = None,
) -> int:
"""Update 1 data only"""
return self.update(condition, new_data, 1, order)
@overload
def select(
self,
condition: Condition = None,
only: OnlyColumn = "*",
limit: int = 0,
offset: int = 0,
order: Optional[Orders] = None,
squash: Literal[False] = False,
) -> Queries: # type: ignore
pass
@overload
def select(
self,
condition: Condition = None,
only: OnlyColumn = "*",
limit: int = 0,
offset: int = 0,
order: Optional[Orders] = None,
squash: Literal[True] = True,
) -> SquashedSqueries:
pass
@overload
def select(
self,
condition: Condition = None,
only: ParsedFn = _null,
limit: int = 0,
offset: int = 0,
order: Optional[Orders] = None,
squash: Literal[False] = False,
) -> Any:
pass
[docs]
def select(
self, # pylint: disable=too-many-arguments
condition: Condition = None,
only: OnlyColumn | ParsedFn = "*",
limit: int = 0,
offset: int = 0,
order: Optional[Orders] = None,
squash: bool = False,
):
"""Select data in current table. Bare .select() returns all data.
Args:
condition (Condition, optional): Conditions to used. Defaults to None.
only: (OnlyColumn, ParsedFn, optional): Select what you want. Default to None.
limit (int, optional): Limit of select. Defaults to 0.
offset (int, optional): Offset. Defaults to 0
order (Optional[Orders], optional): Selection order. Defaults to None.
squash (bool): Is it squashed?
Returns:
Queries: Selected data
"""
self._control()
self._query_control()
query, data = build_select(
self._table, condition, only, limit, offset, order
) # type: ignore
with self._sql:
cursor = self._exec(query, data)
data = cursor.fetchall()
if squash:
return crunch(data)
if isinstance(only, ParsedFn):
return data[0][only.parse_sql()[0]]
return data
@overload
def paginate_select(
self,
condition: Condition = None,
only: OnlyColumn = "*",
page: int = 0,
length: int = 10,
order: Optional[Orders] = None,
squash: Literal[False] = False,
) -> Generator[Queries, None, None]: # type: ignore
pass
@overload
def paginate_select(
self,
condition: Condition = None,
only: OnlyColumn = "*",
page: int = 0,
length: int = 10,
order: Optional[Orders] = None,
squash: Literal[True] = True,
) -> Generator[SquashedSqueries, None, None]:
pass
[docs]
def paginate_select(
self,
condition: Condition = None,
only: OnlyColumn = "*",
page: int = 0,
length: int = 10,
order: Optional[Orders] = None,
squash: bool = False,
):
"""Paginate select
Args:
condition (Condition, optional): Confitions to use. Defaults to None.
only (OnlyColumn, optional): Select what you want. Default to None.
page (int): Which page number be returned first
length (int, optional): Pagination length. Defaults to 10.
order (Optional[Orders], optional): Order. Defaults to None.
Yields:
Generator[Queries, None, None]: Step-by-step paginated result.
"""
if page < 0:
page = 0
order = "desc" if order in ("asc", None) else "asc" # type: ignore
self._control()
self._query_control()
start = page * length
while True:
query, data = build_select(
self._table, condition, only, length, start, order
) # type: ignore
crunched = squash
with self._sql:
cursor = self._exec(query, data)
fetched = cursor.fetchmany(length)
if len(fetched) == 0:
return
if crunched:
fetched = crunch(fetched)
if len(fetched) != length:
yield fetched
return
yield fetched
start += length
@overload
def select_one(
self,
condition: Condition = None,
only: ParsedFn = _null,
order: Optional[Orders] = None,
) -> Any:
pass
@overload
def select_one(
self,
condition: Condition = None,
only: OnlyColumn = "*",
order: Optional[Orders] = None,
) -> Query:
pass
@overload
def select_one(
self,
condition: Condition = None,
only: JustAColumn = "_COLUMN",
order: Optional[Orders] = None,
) -> Any:
pass
[docs]
def select_one(
self,
condition: Condition = None,
only: OnlyColumn | JustAColumn | ParsedFn = "*",
order: Optional[Orders] = None,
):
"""Select one data
Args:
condition (Condition, optional): Condition to use. Defaults to None.
only: (OnlyColumn, optional): Select what you want. Default to None.
order (Optional[Orders], optional): Order of selection. Defaults to None.
Returns:
Any: Selected data
"""
self._control()
self._query_control()
query, data = build_select(
self._table, condition, only, 1, 0, order
) # type: ignore
with self._sql:
cursor = self._exec(query, data)
data = cursor.fetchone()
if isinstance(only, ParsedFn):
return data[only.parse_sql()[0]]
if not data:
return Row()
if isinstance(only, str) and only != "*":
return data[only]
return data
[docs]
def exists(self, condition: Condition = None):
"""Check if data is exists or not.
Args:
condition (Condition, optional): Condition to use. Defaults to None.
"""
data = self.select_one(condition)
if data is None:
return False
return True
[docs]
def get_namespace(self) -> Type[TypicalNamedTuple]:
"""Generate or return pre-existed namespace/table."""
if self._sql_path in PLUGINS_PATH:
plugin = self._sql_path[2:]
raise ValueError(f"Redefining get_namespace required for plugin {plugin}")
if self._ns.get(self._table, None):
return self._ns[self._table]
self._control()
if self._columns:
datatypes = {col.name: SQLITEPYTYPES[col.type] for col in self._columns}
namespace_name = self._table.title() + "Table"
namedtupled = NamedTuple(namespace_name, **datatypes)
setattr(namedtupled, "_table", self)
self._ns[self._table] = namedtupled
return namedtupled
self._fetch_columns()
if self._columns is None:
raise ExceptionGroup(
f"Column misbehave. Table {self._table}",
[
ValueError("Mismatched columns"),
UnexpectedResultError("._fetch_columns() does not change columns."),
],
)
datatypes = {}
for column in self._columns:
datatypes[column.name] = SQLITEPYTYPES[column.type]
namedtupled = NamedTuple(self._table.title() + "Table", **datatypes)
self._ns[self._table] = namedtupled
return namedtupled
[docs]
def columns(self):
"""Table columns"""
if self._columns is None:
raise AttributeError("columns is undefined.")
return tuple(self._columns)
@property
def deleted(self):
"""Is table deleted"""
return self._deleted
@property
def name(self):
"""Table name"""
return self._table
[docs]
def add_column(self, column: Column | BuilderColumn):
"""Add column to table"""
sql = self._sql
column = column.to_column() if isinstance(column, BuilderColumn) else column
if column.primary or column.unique:
raise OperationalError(
"New column cannot have primary or unique constraint"
)
if column.nullable is False and column.default is None:
raise OperationalError(
"New column cannot be not null while default value is \
set to null"
)
if column.default is not None and column.foreign:
raise OperationalError(
"New column must accept null default value if foreign \
constraint is enabled."
)
query = f"alter table {self._table} add column {extract_single_column(column)}"
if self._columns is not None:
self._columns.append(column)
sql.execute(query)
[docs]
def rename_column(self, old_column: str, new_column: str):
"""Rename existing column to new column"""
check_iter((old_column, new_column))
query = f"alter table {self._table} rename column {old_column} to {new_column}"
self._sql.execute(query)
[docs]
def allow_dirty(self):
"""Allow dirty queries"""
self._force_dirty = True
[docs]
def disallow_dirty(self):
"""Disallow dirty queries"""
self._force_dirty = False
[docs]
def commit(self):
"""Commit changes"""
self._sql.commit()
[docs]
def rollback(self):
"""Rollback"""
self._sql.rollback()
self._dirty = False
def __repr__(self) -> str:
return f"<Table({self._table}) -> {self._parent_repr}>"
__all__ = ["Table"]