Skip to content

Commit

Permalink
feat: migrations (#4)
Browse files Browse the repository at this point in the history
* feat: migrations

* fix: mypy

* fix: migrations exclude __init__ && new migration starts with migration_

* version: 1.0.0
  • Loading branch information
roman-right authored Nov 14, 2022
1 parent 7bf74eb commit 6ba0d7d
Show file tree
Hide file tree
Showing 58 changed files with 2,218 additions and 56 deletions.
7 changes: 6 additions & 1 deletion bunnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from bunnet.migrations.controllers.free_fall import free_fall_migration
from bunnet.migrations.controllers.iterative import iterative_migration
from bunnet.odm.actions import (
before_event,
after_event,
Expand All @@ -24,7 +26,7 @@
from bunnet.odm.views import View
from bunnet.odm.union_doc import UnionDoc

__version__ = "0.1.1"
__version__ = "1.0.0"
__all__ = [
# ODM
"Document",
Expand Down Expand Up @@ -52,4 +54,7 @@
"Link",
"WriteRules",
"DeleteRules",
# Migrations
"iterative_migration",
"free_fall_migration",
]
Empty file added bunnet/executors/__init__.py
Empty file.
172 changes: 172 additions & 0 deletions bunnet/executors/migrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, Any

import click
import toml
from pydantic import BaseSettings

from bunnet.migrations import template
from bunnet.migrations.database import DBHandler
from bunnet.migrations.models import RunningMode, RunningDirections
from bunnet.migrations.runner import MigrationNode

logging.basicConfig(format="%(message)s", level=logging.INFO)


def toml_config_settings_source(settings: BaseSettings) -> Dict[str, Any]:
path = Path("pyproject.toml")
if path.is_file():
return (
toml.load(path)
.get("tool", {})
.get("bunnet", {})
.get("migrations", {})
)
return {}


class MigrationSettings(BaseSettings):
direction: RunningDirections = RunningDirections.FORWARD
distance: int = 0
connection_uri: str
database_name: str
path: Path
allow_index_dropping: bool = False

class Config:
env_prefix = "bunnet_"
fields = {
"connection_uri": {
"env": [
"uri",
"connection_uri",
"connection_string",
"mongodb_dsn",
"mongodb_uri",
]
},
"db": {"env": ["db", "db_name", "database_name"]},
}

@classmethod
def customise_sources(
cls,
init_settings,
env_settings,
file_secret_settings,
):
return (
init_settings,
toml_config_settings_source,
env_settings,
file_secret_settings,
)


@click.group()
def migrations():
pass


def run_migrate(settings: MigrationSettings):
DBHandler.set_db(settings.connection_uri, settings.database_name)
root = MigrationNode.build(settings.path)
mode = RunningMode(
direction=settings.direction, distance=settings.distance
)
root.run(mode=mode, allow_index_dropping=settings.allow_index_dropping)


@migrations.command()
@click.option(
"--forward",
"direction",
required=False,
flag_value="FORWARD",
help="Roll the migrations forward. This is default",
)
@click.option(
"--backward",
"direction",
required=False,
flag_value="BACKWARD",
help="Roll the migrations backward",
)
@click.option(
"-d",
"--distance",
required=False,
help="How many migrations should be done since the current? "
"0 - all the migrations. Default is 0",
)
@click.option(
"-uri",
"--connection-uri",
required=False,
type=str,
help="MongoDB connection URI",
)
@click.option(
"-db", "--database_name", required=False, type=str, help="DataBase name"
)
@click.option(
"-p",
"--path",
required=False,
type=str,
help="Path to the migrations directory",
)
@click.option(
"--allow-index-dropping/--forbid-index-dropping",
required=False,
default=False,
help="if allow-index-dropping is set, Beanie will drop indexes from your collection",
)
def migrate(
direction,
distance,
connection_uri,
database_name,
path,
allow_index_dropping,
):
settings_kwargs = {}
if direction:
settings_kwargs["direction"] = direction
if distance:
settings_kwargs["distance"] = distance
if connection_uri:
settings_kwargs["connection_uri"] = connection_uri
if database_name:
settings_kwargs["database_name"] = database_name
if path:
settings_kwargs["path"] = path
if allow_index_dropping:
settings_kwargs["allow_index_dropping"] = allow_index_dropping
settings = MigrationSettings(**settings_kwargs)

run_migrate(settings)


@migrations.command()
@click.option("-n", "--name", required=True, type=str, help="Migration name")
@click.option(
"-p",
"--path",
required=True,
type=str,
help="Path to the migrations directory",
)
def new_migration(name, path):
path = Path(path)
ts_string = datetime.now().strftime("%Y%m%d%H%M%S")
file_name = f"migration_{ts_string}_{name}.py"

shutil.copy(template.__file__, path / file_name)


if __name__ == "__main__":
migrations()
Empty file added bunnet/migrations/__init__.py
Empty file.
Empty file.
15 changes: 15 additions & 0 deletions bunnet/migrations/controllers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from typing import List, Type

from bunnet.odm.documents import Document


class BaseMigrationController(ABC):
@abstractmethod
def run(self, session):
pass

@property
@abstractmethod
def models(self) -> List[Type[Document]]:
pass
28 changes: 28 additions & 0 deletions bunnet/migrations/controllers/free_fall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from inspect import signature
from typing import Type, List

from bunnet.odm.documents import Document
from bunnet.migrations.controllers.base import BaseMigrationController


def free_fall_migration(document_models: List[Type[Document]]):
class FreeFallMigrationController(BaseMigrationController):
def __init__(self, function):
self.function = function
self.function_signature = signature(function)
self.document_models = document_models

def __call__(self, *args, **kwargs):
pass

@property
def models(self) -> List[Type[Document]]:
return self.document_models

def run(self, session):
function_kwargs = {"session": session}
if "self" in self.function_signature.parameters:
function_kwargs["self"] = None
self.function(**function_kwargs)

return FreeFallMigrationController
120 changes: 120 additions & 0 deletions bunnet/migrations/controllers/iterative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from inspect import signature, isclass
from typing import Type, Optional, Union, List

from bunnet.migrations.utils import update_dict
from bunnet.migrations.controllers.base import BaseMigrationController
from bunnet.odm.documents import Document


class DummyOutput:
def __init__(self):
super(DummyOutput, self).__setattr__("_internal_structure_dict", {})

def __setattr__(self, key, value):
self._internal_structure_dict[key] = value

def __getattr__(self, item):
try:
return self._internal_structure_dict[item]
except KeyError:
self._internal_structure_dict[item] = DummyOutput()
return self._internal_structure_dict[item]

def dict(self, to_parse: Optional[Union[dict, "DummyOutput"]] = None):
if to_parse is None:
to_parse = self
input_dict = (
to_parse._internal_structure_dict
if isinstance(to_parse, DummyOutput)
else to_parse
)
result_dict = {}
for key, value in input_dict.items():
if isinstance(value, (DummyOutput, dict)):
result_dict[key] = self.dict(to_parse=value)
else:
result_dict[key] = value
return result_dict


def iterative_migration(
document_models: Optional[List[Type[Document]]] = None,
batch_size: int = 10000,
):
class IterativeMigration(BaseMigrationController):
def __init__(self, function):
self.function = function
self.function_signature = signature(function)
input_signature = self.function_signature.parameters.get(
"input_document"
)
if input_signature is None:
raise RuntimeError("input_signature must not be None")
self.input_document_model: Type[
Document
] = input_signature.annotation
output_signature = self.function_signature.parameters.get(
"output_document"
)
if output_signature is None:
raise RuntimeError("output_signature must not be None")
self.output_document_model: Type[
Document
] = output_signature.annotation

if (
not isclass(self.input_document_model)
or not issubclass(self.input_document_model, Document)
or not isclass(self.output_document_model)
or not issubclass(self.output_document_model, Document)
):
raise TypeError(
"input_document and output_document "
"must have annotation of Document subclass"
)

self.batch_size = batch_size

def __call__(self, *args, **kwargs):
pass

@property
def models(self) -> List[Type[Document]]:
preset_models = document_models
if preset_models is None:
preset_models = []
return preset_models + [
self.input_document_model,
self.output_document_model,
]

def run(self, session):
output_documents = []
for input_document in self.input_document_model.find_all():
output = DummyOutput()
function_kwargs = {
"input_document": input_document,
"output_document": output,
}
if "self" in self.function_signature.parameters:
function_kwargs["self"] = None
self.function(**function_kwargs)
output_dict = input_document.dict()
update_dict(output_dict, output.dict())
output_document = self.output_document_model.parse_obj(
output_dict
)
output_documents.append(output_document)

if len(output_documents) == self.batch_size:
self.output_document_model.replace_many(
documents=output_documents
)
output_documents = []

if output_documents:
self.output_document_model.replace_many(
documents=output_documents
)

return IterativeMigration
16 changes: 16 additions & 0 deletions bunnet/migrations/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pymongo import MongoClient


class DBHandler:
@classmethod
def set_db(cls, uri, db_name):
cls.client = MongoClient(uri)
cls.database = cls.client[db_name]

@classmethod
def get_cli(cls):
return cls.client if hasattr(cls, "client") else None

@classmethod
def get_db(cls):
return cls.database if hasattr(cls, "database") else None
Loading

0 comments on commit 6ba0d7d

Please sign in to comment.