Browse Source

Merge branch 'gdritter/tag-fixes' of getty/lament-configuration into master

getty 4 years ago
parent
commit
ff9835a11e
3 changed files with 45 additions and 6 deletions
  1. 11 0
      lc/error.py
  2. 20 6
      lc/model.py
  3. 14 0
      tests/model.py

+ 11 - 0
lc/error.py

@@ -127,3 +127,14 @@ class MismatchedPassword(LCException):
 
     def http_code(self) -> int:
         return 400
+
+
+@dataclass
+class BadTagName(LCException):
+    tag_name: str
+
+    def __str__(self):
+        return f"'{self.tag_name}' is not a valid tag name, for Reasons."
+
+    def http_code(self) -> int:
+        return 400

+ 20 - 6
lc/model.py

@@ -3,7 +3,7 @@ import datetime
 from passlib.apps import custom_app_context as pwd
 import peewee
 import playhouse.shortcuts
-from typing import List, Optional, Tuple
+from typing import Iterator, List, Optional, Tuple
 
 import lc.config as c
 import lc.error as e
@@ -144,10 +144,9 @@ class Link(Model):
             user=user,
         )
         for tag_name in link.tags:
-            t = Tag.get_or_create_tag(user, tag_name)
-            HasTag.create(
-                link=l, tag=t,
-            )
+            tag = Tag.get_or_create_tag(user, tag_name)
+            for t in tag.get_family():
+                HasTag.get_or_create(link=l, tag=t)
         return l
 
     def update_from_request(self, user: User, link: r.Link):
@@ -222,11 +221,26 @@ class Tag(Model):
         )
         return links, pagination
 
+    def get_family(self) -> Iterator["Tag"]:
+        yield self
+        p = self
+        while (p := p.parent) :
+            yield p
+
+    BAD_TAG_CHARS = set("{}[]\\()#?")
+
     @staticmethod
-    def get_or_create_tag(user: User, tag_name: str):
+    def is_valid_tag_name(tag_name: str) -> bool:
+        return all((c not in Tag.BAD_TAG_CHARS for c in tag_name))
+
+    @staticmethod
+    def get_or_create_tag(user: User, tag_name: str) -> "Tag":
         if (t := Tag.get_or_none(name=tag_name, user=user)) :
             return t
 
+        if not Tag.is_valid_tag_name(tag_name):
+            raise e.BadTagName(tag_name)
+
         parent = None
         if "/" in tag_name:
             parent_name = tag_name[: tag_name.rindex("/")]

+ 14 - 0
tests/model.py

@@ -93,6 +93,20 @@ class Testdb:
         assert t.id == m.Tag.get(name="food/bread/rye").id
         assert t2.id == m.Tag.get(name="food/bread/baguette").id
 
+    def test_add_hierarchy(self):
+        u = self.mk_user()
+        req = r.Link("http://foo.com", "foo", "", False, ["food/bread/rye"])
+        l = m.Link.from_request(u, req)
+        assert l.name == req.name
+        tag_names = {t.tag.name for t in l.tags}  # type: ignore
+        assert tag_names == {"food", "food/bread", "food/bread/rye"}
+
+    def test_bad_tag(self):
+        u = self.mk_user()
+        req = r.Link("http://foo.com", "foo", "", False, ["foo{bar}"])
+        with pytest.raises(e.BadTagName):
+            l = m.Link.from_request(u, req)
+
     def test_create_invite(self):
         u = self.mk_user()
         invite = m.UserInvite.manufacture(u)