Source code for sqlite_database.database

"""SQLite Database"""

from weakref import finalize, WeakValueDictionary
from sqlite3 import OperationalError, connect
from typing import Iterable, Optional, Mapping

from .locals import PLUGINS_PATH
from ._utils import (
    WithCursor,
    check_iter,
    check_one,
    dict_factory,
    sqlite_multithread_check,
)
from .column import BuilderColumn, Column
from .query_builder import extract_table_creations
from .table import Table
from .errors import DatabaseExistsError, DatabaseMissingError

Columns = Iterable[Column] | Iterable[BuilderColumn]

__all__ = ["Database"]

IGNORE_TABLE_CHECKS = ("sqlite_master", "sqlite_temp_schema", "sqlite_temp_master")


[docs] class Database: """Sqlite3 database, this provide basic integration.""" _active: Mapping[str, "Database"] = WeakValueDictionary() def __new__(cls, path: str, **kwargs): # pylint: disable=unused-argument if path in cls._active: return cls._active[path] self = object.__new__(cls) if path != ":memory:" and cls == Database: cls._active[str(path)] = self # type: ignore return self def __init__(self, path: str, **kwargs) -> None: kwargs["check_same_thread"] = sqlite_multithread_check() != 3 self._path = path if not path in PLUGINS_PATH: self._database = connect(path, **kwargs) self._database.row_factory = dict_factory else: pass self._config = None self._closed = False self._table_instances: dict[str, Table] = {} if not self._closed or self.__dict__.get("_initiated", False) is False: self._finalizer_fn = finalize(self, self.close) self._initiated = True self._kwargs = kwargs def _finalizer(self): self.close()
[docs] def cursor(self) -> WithCursor: """Create cursor""" return self._database.cursor(WithCursor) # type: ignore
[docs] def create_table(self, table: str, columns: Columns): """Create table Args: table (str): Table name columns (Iterable[Column]): Columns for table Returns: Table: Newly created table """ columns = ( column.to_column() if isinstance(column, BuilderColumn) else column for column in columns ) tbquery = extract_table_creations(columns) query = f"create table {table} ({tbquery})" try: cursor = self._database.cursor() cursor.execute(query) self._database.commit() except OperationalError as error: dberror = DatabaseExistsError(f"table {table} already exists.") dberror.add_note(f"{type(error).__name__}: {error!s}") raise dberror from error table_ = self.table(table, columns) table_._deleted = False # pylint: disable=protected-access self._table_instances[table] = table_ return table_
[docs] def delete_table(self, table: str): """Delete an existing table Args: table (str): table name """ check_one(table) table_ = self.table(table) self._database.cursor().execute(f"drop table {table}") # pylint: disable-next=protected-access del self._table_instances[table] table_._delete_hook() # pylint: disable=protected-access
[docs] def table(self, table: str, __columns: Optional[Iterable[Column]] = None): # type: ignore """fetch table""" if self._table_instances.get(table, None) is not None: return self._table_instances[table] try: this_table = Table(self, table, __columns) except OperationalError as exc: dberror = DatabaseMissingError(f"table {table} does not exists") dberror.add_note(f"{type(exc).__name__}: {exc!s}") raise dberror from None self._table_instances[table] = this_table return this_table
[docs] def reset_table(self, table: str, columns: Columns) -> Table: """Reset existing table with new, this rewrote entire table than altering it.""" try: self.delete_table(table) except OperationalError: pass return self.create_table(table, columns)
[docs] def rename_table(self, old_table: str, new_table: str) -> Table: """Rename existing table to a new one.""" check_iter((old_table, new_table)) cursor = self.sql.cursor() cursor.execute(f"alter table {old_table} rename to {new_table}") self.sql.commit() return self.table(new_table)
[docs] def check_table(self, table: str): """Check if table is exists or not.""" if self._path in PLUGINS_PATH: plugin = self._path[2:] raise ValueError(f"Plugin {plugin} must redefine check_table.") check_one(table) if table in IGNORE_TABLE_CHECKS: return True # Let's return true. cursor = self.sql.cursor() cursor.execute( "select name from sqlite_master where type='table' and name=?", (table,) ) if cursor.fetchone(): return True return False
def __repr__(self) -> str: return f"<{type(self).__name__} {id(self)}>"
[docs] def close(self): """Close database""" if self._closed: return self._database.close() for table in self._table_instances.copy(): del self._table_instances[table] if self.path == ":memory:": self._closed = True return if self.path in self._active: del type(self)._active[self.path] # type: ignore self._closed = True
[docs] def tables(self) -> tuple[Table, ...]: """Return tuple containing all table except internal tables""" master = self.table("sqlite_master") listed = [] for table in master.select(): if table.type == "table": listed.append(self.table(table.name)) return tuple(listed)
[docs] def commit(self): """Commit changes to database""" self._database.commit()
[docs] def rollback(self): """Rollback changes""" self._database.rollback()
@property def closed(self): """Is database closed?""" return self._closed @closed.setter def closed(self, __o: bool): """Is database closed?""" if __o: self.close() return raise ValueError("Expected non-false/non-null value") @property def path(self): """Path to SQL Connection""" return self._path or ":memory:" @property def sql(self): """SQL Connection""" return self._database