diff --git a/findmy/util/session.py b/findmy/util/session.py new file mode 100644 index 0000000..a545292 --- /dev/null +++ b/findmy/util/session.py @@ -0,0 +1,144 @@ +"""Logic related to serializable classes.""" + +from __future__ import annotations + +import random +from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar, Union + +from findmy.util.abc import Closable, Serializable + +if TYPE_CHECKING: + from pathlib import Path + from types import TracebackType + +_S = TypeVar("_S", bound=Serializable) +_SC = TypeVar("_SC", bound=Union[Serializable, Closable]) + + +class _BaseSessionManager(Generic[_SC]): + """Base class for session managers.""" + + def __init__(self) -> None: + self._sessions: dict[_SC, str | Path | None] = {} + + def _add(self, obj: _SC, path: str | Path | None) -> None: + self._sessions[obj] = path + + def remove(self, obj: _SC) -> None: + self._sessions.pop(obj, None) + + def save(self) -> None: + for obj, path in self._sessions.items(): + if isinstance(obj, Serializable): + obj.to_json(path) + + async def close(self) -> None: + for obj in self._sessions: + if isinstance(obj, Closable): + await obj.close() + + async def save_and_close(self) -> None: + for obj, path in self._sessions.items(): + if isinstance(obj, Serializable): + obj.to_json(path) + if isinstance(obj, Closable): + await obj.close() + + def get_random(self) -> _SC: + if not self._sessions: + msg = "No objects in the session manager." + raise ValueError(msg) + return random.choice(list(self._sessions.keys())) # noqa: S311 + + def __len__(self) -> int: + return len(self._sessions) + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + _exc_type: type[BaseException] | None, + _exc_val: BaseException | None, + _exc_tb: TracebackType | None, + ) -> None: + self.save() + + +class MixedSessionManager(_BaseSessionManager[Union[Serializable, Closable]]): + """Allows any Serializable or Closable object.""" + + def new( + self, + c_type: type[_SC], + path: str | Path | None = None, + /, + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> _SC: + """Add an object to the manager by instantiating it using its constructor.""" + obj = c_type(*args, **kwargs) + if isinstance(obj, Serializable) and path is not None: + obj.to_json(path) + self._add(obj, path) + return obj + + def add_from_json( + self, + c_type: type[_S], + path: str | Path, + /, + **kwargs: Any, # noqa: ANN401 + ) -> _S: + """Add an object to the manager by deserializing it from its JSON representation.""" + obj = c_type.from_json(path, **kwargs) + self._add(obj, path) + return obj + + def add(self, obj: Serializable | Closable, path: str | Path | None = None, /) -> None: + """Add an object to the session manager.""" + self._add(obj, path) + + +class UniformSessionManager(Generic[_SC], _BaseSessionManager[_SC]): + """Only allows a single type of Serializable object.""" + + def __init__(self, obj_type: type[_SC]) -> None: + """Create a new session manager.""" + super().__init__() + self._obj_type = obj_type + + def new( + self, + path: str | Path | None = None, + /, + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> _SC: + """Add an object to the manager by instantiating it using its constructor.""" + obj = self._obj_type(*args, **kwargs) + if isinstance(obj, Serializable) and path is not None: + obj.to_json(path) + self._add(obj, path) + return obj + + def add_from_json( + self, + path: str | Path, + /, + **kwargs: Any, # noqa: ANN401 + ) -> _SC: + """Add an object to the manager by deserializing it from its JSON representation.""" + if not issubclass(self._obj_type, Serializable): + msg = "Can only add objects of type Serializable." + raise TypeError(msg) + obj = self._obj_type.from_json(path, **kwargs) + self._add(obj, path) + return obj + + def add(self, obj: _SC, path: str | Path | None = None, /) -> None: + """Add an object to the session manager.""" + if not isinstance(obj, self._obj_type): + msg = f"Object must be of type {self._obj_type.__name__}" + raise TypeError(msg) + self._add(obj, path)