Browse Source

Refactor remaining uses of config

Getty Ritter 3 years ago
parent
commit
d1433a236e
10 changed files with 44 additions and 28 deletions
  1. 1 1
      lament-configuration.py
  2. 1 1
      lc/app.py
  3. 13 7
      lc/model.py
  4. 1 1
      lc/request.py
  5. 2 2
      lc/view.py
  6. 6 5
      lc/web.py
  7. 0 1
      scripts/populate.py
  8. 3 1
      tasks.py
  9. 8 5
      tests/model.py
  10. 9 4
      tests/routes.py

+ 1 - 1
lament-configuration.py

@@ -4,5 +4,5 @@ from lc.app import app
 import lc.config
 import lc.model
 
-lc.config.db.init(os.getenv("DB_LOC", "test.db"))
+lc.config.app.init_db()
 lc.model.create_tables()

+ 1 - 1
lc/app.py

@@ -10,7 +10,7 @@ import lc.request as r
 import lc.view as v
 from lc.web import Endpoint, endpoint, render
 
-app = c.app
+app = c.app.app
 
 
 @endpoint("/")

+ 13 - 7
lc/model.py

@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from dataclasses import dataclass
 import datetime
 import json
@@ -14,11 +15,16 @@ import lc.view as v
 
 class Model(peewee.Model):
     class Meta:
-        database = c.db
+        database = c.app.db
 
     def to_dict(self) -> dict:
         return playhouse.shortcuts.model_to_dict(self)
 
+    @contextmanager
+    def atomic(self):
+        with c.app.db.atomic():
+            yield
+
 
 class User(Model):
     """
@@ -87,7 +93,7 @@ class User(Model):
         query = Link.select().where(
             (Link.user == self) & ((self == as_user) | (Link.private == False))
         )
-        links = query.order_by(-Link.created).paginate(page, c.per_page)
+        links = query.order_by(-Link.created).paginate(page, c.app.per_page)
         link_views = [l.to_view(as_user) for l in links]
         pagination = v.Pagination.from_total(page, query.count())
         return link_views, pagination
@@ -138,7 +144,7 @@ class User(Model):
 
                 tags[t] = Tag.get_or_create_tag(self, t)
 
-        with c.db.atomic():
+        with self.atomic():
             for l in links:
                 try:
                     time = datetime.datetime.strptime(l["time"], "%Y-%m-%dT%H:%M:%SZ")
@@ -212,7 +218,7 @@ class Link(Model):
         return l
 
     def update_from_request(self, user: User, link: r.Link):
-        with c.db.atomic():
+        with self.atomic():
             req_tags = set(link.tags)
 
             for hastag in self.tags:  # type: ignore
@@ -277,7 +283,7 @@ class Tag(Model):
         )
         links = [
             ht.link.to_view(as_user)
-            for ht in query.order_by(-Link.created).paginate(page, c.per_page)
+            for ht in query.order_by(-Link.created).paginate(page, c.app.per_page)
         ]
         pagination = v.Pagination.from_total(page, query.count())
         return links, pagination
@@ -351,7 +357,7 @@ class UserInvite(Model):
     @staticmethod
     def manufacture(creator: User) -> "UserInvite":
         now = datetime.datetime.now()
-        token = c.serializer.dumps(
+        token = c.app.serialize_token(
             {"created_at": now.timestamp(), "created_by": creator.name,}
         )
         return UserInvite.create(
@@ -373,4 +379,4 @@ MODELS = [
 
 
 def create_tables():
-    c.db.create_tables(MODELS, safe=True)
+    c.app.db.create_tables(MODELS, safe=True)

+ 1 - 1
lc/request.py

@@ -35,7 +35,7 @@ class User(Request):
         return cls(name=form["username"], password=form["password"],)
 
     def to_token(self) -> str:
-        return c.serializer.dumps({"name": self.name})
+        return c.app.serialize_token({"name": self.name})
 
 
 @dataclass_json

+ 2 - 2
lc/view.py

@@ -26,7 +26,7 @@ class Pagination(View):
 
     @classmethod
     def from_total(cls, current, total) -> "Pagination":
-        return cls(current=current, last=((total - 1) // c.per_page) + 1,)
+        return cls(current=current, last=((total - 1) // c.app.per_page) + 1,)
 
 
 @dataclass
@@ -55,7 +55,7 @@ class Config(View):
     def bookmarklet_link(self):
         return (
             "javascript:(function(){window.open(`"
-            + c.app_path
+            + c.app.config.app_path
             + "/u/"
             + self.username
             + "/l?name=${document.title}&url=${document.URL}`);})();"

+ 6 - 5
lc/web.py

@@ -19,7 +19,8 @@ class ApiOK:
 
 
 class Endpoint:
-    __slots__ = ('user',)
+    __slots__ = ("user",)
+
     def __init__(self):
         self.user = None
 
@@ -38,7 +39,7 @@ class Endpoint:
         # if that exists and we can deserialize it, then make sure
         # it contains a valid user password, too
         try:
-            payload = c.serializer.loads(token)
+            payload = c.app.load_token(token)
         except:
             # TODO: be more specific about what errors we're catching
             # here!
@@ -179,7 +180,7 @@ def endpoint(route: str):
         func.__name__ = endpoint_class.__name__
 
         # finally, use the Flask routing machinery to register our callback
-        return c.app.route(route, methods=methods)(func)
+        return c.app.app.route(route, methods=methods)(func)
 
     return do_endpoint
 
@@ -194,7 +195,7 @@ def render(name: str, data: Optional[v.View] = None) -> str:
     return renderer.render(template, data or {})
 
 
-@c.app.errorhandler(404)
+@c.app.app.errorhandler(404)
 def handle_404(e):
     user = Endpoint.just_get_user()
     url = flask.request.path
@@ -203,7 +204,7 @@ def handle_404(e):
     return render("main", page)
 
 
-@c.app.errorhandler(500)
+@c.app.app.errorhandler(500)
 def handle_500(e):
     user = Endpoint.just_get_user()
     c.log(f"Internal error: {e}")

+ 0 - 1
scripts/populate.py

@@ -10,7 +10,6 @@ import lc.request as r
 
 
 def main():
-    c.db.init("test.db")
     m.create_tables()
 
     u = m.User.get_or_none(name="gdritter")

+ 3 - 1
tasks.py

@@ -15,7 +15,9 @@ def run(c, port=8080, host="127.0.0.1"):
         f"poetry run python -m flask run -p {port} -h {host}",
         env={
             "FLASK_APP": "lament-configuration.py",
-            "APP_PATH": f"http://{host}:{port}",
+            "LC_APP_PATH": f"http://{host}:{port}",
+            "LC_DB_PATH": f"test.db",
+            "LC_SECRET_KEY": f"TESTING_KEY",
         },
     )
 

+ 8 - 5
tests/model.py

@@ -2,7 +2,10 @@ import os
 import peewee
 import pytest
 
-os.environ["APP_PATH"] = "test"
+os.environ["LC_DB_PATH"] = ":memory:"
+os.environ["LC_SECRET_KEY"] = "TEST_KEY"
+os.environ["LC_APP_PATH"] = "localhost"
+
 import lc.config as c
 import lc.error as e
 import lc.request as r
@@ -11,11 +14,11 @@ import lc.model as m
 
 class Testdb:
     def setup_method(self, _):
-        c.db.init(":memory:")
-        c.db.create_tables(m.MODELS)
+        c.app.in_memory_db()
+        m.create_tables()
 
     def teardown_method(self, _):
-        c.db.close()
+        c.app.close_db()
 
     def mk_user(self, name="gdritter", password="foo") -> m.User:
         return m.User.from_request(r.User(name=name, password=password,))
@@ -118,7 +121,7 @@ class Testdb:
         assert invite.claimed_at is None
 
         # deserializing the unique token should reveal the encrypted data
-        raw_data = c.serializer.loads(invite.token)
+        raw_data = c.app.load_token(invite.token)
         assert raw_data["created_by"] == u.name
 
     def test_use_invite(self):

+ 9 - 4
tests/routes.py

@@ -1,5 +1,10 @@
+import os
 import json
 
+os.environ["LC_DB_PATH"] = ":memory:"
+os.environ["LC_SECRET_KEY"] = "TEST_KEY"
+os.environ["LC_APP_PATH"] = "localhost"
+
 import lc.config as c
 import lc.model as m
 import lc.request as r
@@ -8,12 +13,12 @@ import lc.app as a
 
 class TestRoutes:
     def setup_method(self, _):
-        c.db.init(":memory:")
-        c.db.create_tables(m.MODELS)
+        c.app.in_memory_db()
+        m.create_tables()
         self.app = a.app.test_client()
 
     def teardown_method(self, _):
-        c.db.close()
+        c.app.close_db()
 
     def mk_user(self, username="gdritter", password="foo") -> m.User:
         return m.User.from_request(r.User(name=username, password=password,))
@@ -28,7 +33,7 @@ class TestRoutes:
         u = self.mk_user(username=username, password=password)
         result = self.app.post("/auth", json={"name": username, "password": password})
         assert result.status == "200 OK"
-        decoded_token = c.serializer.loads(result.json["token"])
+        decoded_token = c.app.load_token(result.json["token"])
         assert decoded_token["name"] == username
 
     def test_failed_api_login(self):