#29 Refactor configuration to use environ module

Merged
getty merged 5 commits from getty/gr/config into getty/master 4 years ago

+ 1 - 1
lament-configuration.py

@@ -4,5 +4,5 @@ from lc.app import app
 import lc.config
 import lc.config
 import lc.model
 import lc.model
 
 
-lc.config.db.init(os.getenv("DB_LOC", "test.db"))
+lc.config.app.init_db()
 lc.model.create_tables()
 lc.model.create_tables()

+ 1 - 1
lc/app.py

@@ -10,7 +10,7 @@ import lc.request as r
 import lc.view as v
 import lc.view as v
 from lc.web import Endpoint, endpoint, render
 from lc.web import Endpoint, endpoint, render
 
 
-app = c.app
+app = c.app.app
 
 
 
 
 @endpoint("/")
 @endpoint("/")

+ 55 - 6
lc/config.py

@@ -1,16 +1,65 @@
+from dataclasses import dataclass
 import os
 import os
 import sys
 import sys
+from typing import Any
 
 
+import environ
 import flask
 import flask
 import itsdangerous
 import itsdangerous
 import playhouse.sqlite_ext
 import playhouse.sqlite_ext
 
 
-db = playhouse.sqlite_ext.SqliteExtDatabase(None)
-per_page = 50
-serializer = itsdangerous.URLSafeTimedSerializer(os.getenv("SECRET_KEY", "TEMP KEY"))
-app_path = os.environ["APP_PATH"].strip()
-app = flask.Flask(__name__)
-app.secret_key = os.getenv("SECRET_KEY", "ARGLBARGL")
+
+@environ.config(prefix="LC")
+class Config:
+    secret_key = environ.var()
+    app_path = environ.var()
+    db_path = environ.var()
+    static_path = environ.var("static")
+
+
+@dataclass
+class App:
+    config: Config
+    app: flask.Flask
+    db: playhouse.sqlite_ext.SqliteExtDatabase
+    serializer: itsdangerous.URLSafeTimedSerializer
+    per_page: int = 50
+
+    @staticmethod
+    def from_env() -> "App":
+        config = environ.to_config(Config)
+        app = flask.Flask(
+            __name__, static_folder=os.path.join(os.getcwd(), config.static_path),
+        )
+        app.secret_key = config.secret_key
+        return App(
+            config=config,
+            db=playhouse.sqlite_ext.SqliteExtDatabase(None),
+            serializer=itsdangerous.URLSafeTimedSerializer(config.secret_key),
+            app=app,
+        )
+
+    def init_db(self):
+        self.db.init(self.config.db_path)
+
+    def in_memory_db(self):
+        try:
+            self.db.close()
+        except:
+            pass
+        self.db.init(":memory:")
+
+    def close_db(self):
+        self.db.close()
+
+    def serialize_token(self, obj: Any) -> str:
+        return self.serializer.dumps(obj)
+
+    def load_token(self, token: str) -> Any:
+        return self.serializer.loads(token)
+
+
+app = App.from_env()
 
 
 if sys.stderr.isatty():
 if sys.stderr.isatty():
 
 

+ 13 - 7
lc/model.py

@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from dataclasses import dataclass
 from dataclasses import dataclass
 import datetime
 import datetime
 import json
 import json
@@ -14,11 +15,16 @@ import lc.view as v
 
 
 class Model(peewee.Model):
 class Model(peewee.Model):
     class Meta:
     class Meta:
-        database = c.db
+        database = c.app.db
 
 
     def to_dict(self) -> dict:
     def to_dict(self) -> dict:
         return playhouse.shortcuts.model_to_dict(self)
         return playhouse.shortcuts.model_to_dict(self)
 
 
+    @contextmanager
+    def atomic(self):
+        with c.app.db.atomic():
+            yield
+
 
 
 class User(Model):
 class User(Model):
     """
     """
@@ -87,7 +93,7 @@ class User(Model):
         query = Link.select().where(
         query = Link.select().where(
             (Link.user == self) & ((self == as_user) | (Link.private == False))
             (Link.user == self) & ((self == as_user) | (Link.private == False))
         )
         )
-        links = query.order_by(-Link.created).paginate(page, c.per_page)
+        links = query.order_by(-Link.created).paginate(page, c.app.per_page)
         link_views = [l.to_view(as_user) for l in links]
         link_views = [l.to_view(as_user) for l in links]
         pagination = v.Pagination.from_total(page, query.count())
         pagination = v.Pagination.from_total(page, query.count())
         return link_views, pagination
         return link_views, pagination
@@ -138,7 +144,7 @@ class User(Model):
 
 
                 tags[t] = Tag.get_or_create_tag(self, t)
                 tags[t] = Tag.get_or_create_tag(self, t)
 
 
-        with c.db.atomic():
+        with self.atomic():
             for l in links:
             for l in links:
                 try:
                 try:
                     time = datetime.datetime.strptime(l["time"], "%Y-%m-%dT%H:%M:%SZ")
                     time = datetime.datetime.strptime(l["time"], "%Y-%m-%dT%H:%M:%SZ")
@@ -212,7 +218,7 @@ class Link(Model):
         return l
         return l
 
 
     def update_from_request(self, user: User, link: r.Link):
     def update_from_request(self, user: User, link: r.Link):
-        with c.db.atomic():
+        with self.atomic():
             req_tags = set(link.tags)
             req_tags = set(link.tags)
 
 
             for hastag in self.tags:  # type: ignore
             for hastag in self.tags:  # type: ignore
@@ -277,7 +283,7 @@ class Tag(Model):
         )
         )
         links = [
         links = [
             ht.link.to_view(as_user)
             ht.link.to_view(as_user)
-            for ht in query.order_by(-Link.created).paginate(page, c.per_page)
+            for ht in query.order_by(-Link.created).paginate(page, c.app.per_page)
         ]
         ]
         pagination = v.Pagination.from_total(page, query.count())
         pagination = v.Pagination.from_total(page, query.count())
         return links, pagination
         return links, pagination
@@ -351,7 +357,7 @@ class UserInvite(Model):
     @staticmethod
     @staticmethod
     def manufacture(creator: User) -> "UserInvite":
     def manufacture(creator: User) -> "UserInvite":
         now = datetime.datetime.now()
         now = datetime.datetime.now()
-        token = c.serializer.dumps(
+        token = c.app.serialize_token(
             {"created_at": now.timestamp(), "created_by": creator.name,}
             {"created_at": now.timestamp(), "created_by": creator.name,}
         )
         )
         return UserInvite.create(
         return UserInvite.create(
@@ -373,4 +379,4 @@ MODELS = [
 
 
 
 
 def create_tables():
 def create_tables():
-    c.db.create_tables(MODELS, safe=True)
+    c.app.db.create_tables(MODELS, safe=True)

+ 1 - 1
lc/request.py

@@ -35,7 +35,7 @@ class User(Request):
         return cls(name=form["username"], password=form["password"],)
         return cls(name=form["username"], password=form["password"],)
 
 
     def to_token(self) -> str:
     def to_token(self) -> str:
-        return c.serializer.dumps({"name": self.name})
+        return c.app.serialize_token({"name": self.name})
 
 
 
 
 @dataclass_json
 @dataclass_json

+ 2 - 2
lc/view.py

@@ -26,7 +26,7 @@ class Pagination(View):
 
 
     @classmethod
     @classmethod
     def from_total(cls, current, total) -> "Pagination":
     def from_total(cls, current, total) -> "Pagination":
-        return cls(current=current, last=((total - 1) // c.per_page) + 1,)
+        return cls(current=current, last=((total - 1) // c.app.per_page) + 1,)
 
 
 
 
 @dataclass
 @dataclass
@@ -55,7 +55,7 @@ class Config(View):
     def bookmarklet_link(self):
     def bookmarklet_link(self):
         return (
         return (
             "javascript:(function(){window.open(`"
             "javascript:(function(){window.open(`"
-            + c.app_path
+            + c.app.config.app_path
             + "/u/"
             + "/u/"
             + self.username
             + self.username
             + "/l?name=${document.title}&url=${document.URL}`);})();"
             + "/l?name=${document.title}&url=${document.URL}`);})();"

+ 6 - 5
lc/web.py

@@ -19,7 +19,8 @@ class ApiOK:
 
 
 
 
 class Endpoint:
 class Endpoint:
-    __slots__ = ('user',)
+    __slots__ = ("user",)
+
     def __init__(self):
     def __init__(self):
         self.user = None
         self.user = None
 
 
@@ -38,7 +39,7 @@ class Endpoint:
         # if that exists and we can deserialize it, then make sure
         # if that exists and we can deserialize it, then make sure
         # it contains a valid user password, too
         # it contains a valid user password, too
         try:
         try:
-            payload = c.serializer.loads(token)
+            payload = c.app.load_token(token)
         except:
         except:
             # TODO: be more specific about what errors we're catching
             # TODO: be more specific about what errors we're catching
             # here!
             # here!
@@ -179,7 +180,7 @@ def endpoint(route: str):
         func.__name__ = endpoint_class.__name__
         func.__name__ = endpoint_class.__name__
 
 
         # finally, use the Flask routing machinery to register our callback
         # finally, use the Flask routing machinery to register our callback
-        return c.app.route(route, methods=methods)(func)
+        return c.app.app.route(route, methods=methods)(func)
 
 
     return do_endpoint
     return do_endpoint
 
 
@@ -194,7 +195,7 @@ def render(name: str, data: Optional[v.View] = None) -> str:
     return renderer.render(template, data or {})
     return renderer.render(template, data or {})
 
 
 
 
-@c.app.errorhandler(404)
+@c.app.app.errorhandler(404)
 def handle_404(e):
 def handle_404(e):
     user = Endpoint.just_get_user()
     user = Endpoint.just_get_user()
     url = flask.request.path
     url = flask.request.path
@@ -203,7 +204,7 @@ def handle_404(e):
     return render("main", page)
     return render("main", page)
 
 
 
 
-@c.app.errorhandler(500)
+@c.app.app.errorhandler(500)
 def handle_500(e):
 def handle_500(e):
     user = Endpoint.just_get_user()
     user = Endpoint.just_get_user()
     c.log(f"Internal error: {e}")
     c.log(f"Internal error: {e}")

+ 22 - 2
poetry.lock

@@ -16,7 +16,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
 version = "1.3.0"
 version = "1.3.0"
 
 
 [[package]]
 [[package]]
-category = "dev"
+category = "main"
 description = "Classes Without Boilerplate"
 description = "Classes Without Boilerplate"
 name = "attrs"
 name = "attrs"
 optional = false
 optional = false
@@ -99,6 +99,22 @@ typing-inspect = ">=0.4.0"
 [package.extras]
 [package.extras]
 dev = ["pytest", "ipython", "mypy (>=0.710)", "hypothesis", "portray", "flake8", "simplejson"]
 dev = ["pytest", "ipython", "mypy (>=0.710)", "hypothesis", "portray", "flake8", "simplejson"]
 
 
+[[package]]
+category = "main"
+description = "Boilerplate-free configuration with env variables."
+name = "environ-config"
+optional = false
+python-versions = "*"
+version = "20.1.0"
+
+[package.dependencies]
+attrs = ">=17.4.0"
+
+[package.extras]
+dev = ["pytest", "coverage", "sphinx", "sphinx-rtd-theme", "pre-commit"]
+docs = ["sphinx", "sphinx-rtd-theme"]
+tests = ["pytest", "coverage"]
+
 [[package]]
 [[package]]
 category = "main"
 category = "main"
 description = "A simple framework for building complex web applications."
 description = "A simple framework for building complex web applications."
@@ -440,7 +456,7 @@ dev = ["pytest", "pytest-timeout", "coverage", "tox", "sphinx", "pallets-sphinx-
 watchdog = ["watchdog"]
 watchdog = ["watchdog"]
 
 
 [metadata]
 [metadata]
-content-hash = "199d426c3947fdf20a0945d853d9eaad79d800e6053dc4fd10228c28b21a9457"
+content-hash = "f40a6c27fd58db00cd96ef01da21d52516a783004dcd49d72787d734218dd8a9"
 python-versions = "^3.8"
 python-versions = "^3.8"
 
 
 [metadata.files]
 [metadata.files]
@@ -480,6 +496,10 @@ dataclasses-json = [
     {file = "dataclasses-json-0.4.2.tar.gz", hash = "sha256:65ac9ae2f7ec152ee01bf42c8c024736d4cd6f6fb761502dec92bd553931e3d9"},
     {file = "dataclasses-json-0.4.2.tar.gz", hash = "sha256:65ac9ae2f7ec152ee01bf42c8c024736d4cd6f6fb761502dec92bd553931e3d9"},
     {file = "dataclasses_json-0.4.2-py3-none-any.whl", hash = "sha256:dbb53ebbac30ef45f44f5f436b21bd5726a80a14e1a193958864229100271372"},
     {file = "dataclasses_json-0.4.2-py3-none-any.whl", hash = "sha256:dbb53ebbac30ef45f44f5f436b21bd5726a80a14e1a193958864229100271372"},
 ]
 ]
+environ-config = [
+    {file = "environ-config-20.1.0.tar.gz", hash = "sha256:3167feda073bd3cd3457a3e5fa7c2836b6574c046cd0dcd79385ce3284e837bd"},
+    {file = "environ_config-20.1.0-py2.py3-none-any.whl", hash = "sha256:0e300307520c1e6a5424018b7f70246a2c7f4cb5cc5cbbebccc6a982eb1767cb"},
+]
 flask = [
 flask = [
     {file = "Flask-1.1.2-py2.py3-none-any.whl", hash = "sha256:8a4fdd8936eba2512e9c85df320a37e694c93945b33ef33c89946a340a238557"},
     {file = "Flask-1.1.2-py2.py3-none-any.whl", hash = "sha256:8a4fdd8936eba2512e9c85df320a37e694c93945b33ef33c89946a340a238557"},
     {file = "Flask-1.1.2.tar.gz", hash = "sha256:4efa1ae2d7c9865af48986de8aeb8504bf32c7f3d6fdc9353d34b21f4b127060"},
     {file = "Flask-1.1.2.tar.gz", hash = "sha256:4efa1ae2d7c9865af48986de8aeb8504bf32c7f3d6fdc9353d34b21f4b127060"},

+ 1 - 0
pyproject.toml

@@ -12,6 +12,7 @@ pystache = "^0.5.4"
 dataclasses-json = "^0.4.2"
 dataclasses-json = "^0.4.2"
 passlib = "^1.7.2"
 passlib = "^1.7.2"
 itsdangerous = "^1.1.0"
 itsdangerous = "^1.1.0"
+environ-config = "^20.1.0"
 
 
 [tool.poetry.dev-dependencies]
 [tool.poetry.dev-dependencies]
 pytest = "^5.4.1"
 pytest = "^5.4.1"

+ 0 - 1
scripts/populate.py

@@ -10,7 +10,6 @@ import lc.request as r
 
 
 
 
 def main():
 def main():
-    c.db.init("test.db")
     m.create_tables()
     m.create_tables()
 
 
     u = m.User.get_or_none(name="gdritter")
     u = m.User.get_or_none(name="gdritter")

lc/static/jquery-3.5.0.min.js → static/jquery-3.5.0.min.js


lc/static/lc.js → static/lc.js


lc/static/lc_128.png → static/lc_128.png


lc/static/lc_16.png → static/lc_16.png


lc/static/lc_32.png → static/lc_32.png


lc/static/lc_64.png → static/lc_64.png


lc/static/leaguespartan-bold.ttf → static/leaguespartan-bold.ttf


lc/static/main.css → static/main.css


+ 15 - 0
stubs/environ.py

@@ -0,0 +1,15 @@
+from typing import Any, TypeVar, Type
+
+T = TypeVar("T")
+
+
+def config(prefix: str) -> Any:
+    pass
+
+
+def var(default: str = "") -> str:
+    pass
+
+
+def to_config(klass: Type[T]) -> T:
+    pass

+ 3 - 1
tasks.py

@@ -15,7 +15,9 @@ def run(c, port=8080, host="127.0.0.1"):
         f"poetry run python -m flask run -p {port} -h {host}",
         f"poetry run python -m flask run -p {port} -h {host}",
         env={
         env={
             "FLASK_APP": "lament-configuration.py",
             "FLASK_APP": "lament-configuration.py",
-            "APP_PATH": f"http://{host}:{port}",
+            "LC_APP_PATH": f"http://{host}:{port}",
+            "LC_DB_PATH": f"test.db",
+            "LC_SECRET_KEY": f"TESTING_KEY",
         },
         },
     )
     )
 
 

+ 8 - 5
tests/model.py

@@ -2,7 +2,10 @@ import os
 import peewee
 import peewee
 import pytest
 import pytest
 
 
-os.environ["APP_PATH"] = "test"
+os.environ["LC_DB_PATH"] = ":memory:"
+os.environ["LC_SECRET_KEY"] = "TEST_KEY"
+os.environ["LC_APP_PATH"] = "localhost"
+
 import lc.config as c
 import lc.config as c
 import lc.error as e
 import lc.error as e
 import lc.request as r
 import lc.request as r
@@ -11,11 +14,11 @@ import lc.model as m
 
 
 class Testdb:
 class Testdb:
     def setup_method(self, _):
     def setup_method(self, _):
-        c.db.init(":memory:")
-        c.db.create_tables(m.MODELS)
+        c.app.in_memory_db()
+        m.create_tables()
 
 
     def teardown_method(self, _):
     def teardown_method(self, _):
-        c.db.close()
+        c.app.close_db()
 
 
     def mk_user(self, name="gdritter", password="foo") -> m.User:
     def mk_user(self, name="gdritter", password="foo") -> m.User:
         return m.User.from_request(r.User(name=name, password=password,))
         return m.User.from_request(r.User(name=name, password=password,))
@@ -118,7 +121,7 @@ class Testdb:
         assert invite.claimed_at is None
         assert invite.claimed_at is None
 
 
         # deserializing the unique token should reveal the encrypted data
         # deserializing the unique token should reveal the encrypted data
-        raw_data = c.serializer.loads(invite.token)
+        raw_data = c.app.load_token(invite.token)
         assert raw_data["created_by"] == u.name
         assert raw_data["created_by"] == u.name
 
 
     def test_use_invite(self):
     def test_use_invite(self):

+ 9 - 4
tests/routes.py

@@ -1,5 +1,10 @@
+import os
 import json
 import json
 
 
+os.environ["LC_DB_PATH"] = ":memory:"
+os.environ["LC_SECRET_KEY"] = "TEST_KEY"
+os.environ["LC_APP_PATH"] = "localhost"
+
 import lc.config as c
 import lc.config as c
 import lc.model as m
 import lc.model as m
 import lc.request as r
 import lc.request as r
@@ -8,12 +13,12 @@ import lc.app as a
 
 
 class TestRoutes:
 class TestRoutes:
     def setup_method(self, _):
     def setup_method(self, _):
-        c.db.init(":memory:")
-        c.db.create_tables(m.MODELS)
+        c.app.in_memory_db()
+        m.create_tables()
         self.app = a.app.test_client()
         self.app = a.app.test_client()
 
 
     def teardown_method(self, _):
     def teardown_method(self, _):
-        c.db.close()
+        c.app.close_db()
 
 
     def mk_user(self, username="gdritter", password="foo") -> m.User:
     def mk_user(self, username="gdritter", password="foo") -> m.User:
         return m.User.from_request(r.User(name=username, password=password,))
         return m.User.from_request(r.User(name=username, password=password,))
@@ -28,7 +33,7 @@ class TestRoutes:
         u = self.mk_user(username=username, password=password)
         u = self.mk_user(username=username, password=password)
         result = self.app.post("/auth", json={"name": username, "password": password})
         result = self.app.post("/auth", json={"name": username, "password": password})
         assert result.status == "200 OK"
         assert result.status == "200 OK"
-        decoded_token = c.serializer.loads(result.json["token"])
+        decoded_token = c.app.load_token(result.json["token"])
         assert decoded_token["name"] == username
         assert decoded_token["name"] == username
 
 
     def test_failed_api_login(self):
     def test_failed_api_login(self):