Refactor redis client and app factory to their own files
authorJacob <jobs@jacobcasper.com>
Sun, 7 Sep 2025 20:06:09 +0000 (15:06 -0500)
committerJacob <jobs@jacobcasper.com>
Sun, 7 Sep 2025 20:06:09 +0000 (15:06 -0500)
app.py [new file with mode: 0644]
main.py
redis_client.py [new file with mode: 0644]

diff --git a/app.py b/app.py
new file mode 100644 (file)
index 0000000..1b72227
--- /dev/null
+++ b/app.py
@@ -0,0 +1,98 @@
+import json
+import logging
+import random
+import string
+import uuid
+
+from typing import Optional
+
+from redis_client import get_client
+
+from flask import Flask, Response, request
+
+
+def create_app():
+    logging.basicConfig(level=logging.DEBUG)
+    app = Flask(__name__)
+
+    redis_client = get_client()
+
+    @app.route("/stats/", methods=["GET"])
+    def stats():
+        redis_client.increment_hits("stats")
+        try:
+            stats = redis_client.get_stats()
+            return {"stats": [{stat[0]: stat[1]} for stat in stats][::-1]}
+        except redis.exceptions.ConnectionError as e:
+            app.logger.exception(e)
+            return Response("Internal Server Error", 500)
+        except Exception as e:
+            app.logger.exception(e)
+            return Response("Internal Server Error", 500)
+
+    @app.route("/api/", methods=["GET"])
+    @app.route("/api/<path:subpath>", methods=["GET"])
+    def api(subpath: Optional[str] = None):
+        app.logger.debug("Handling subpath %s", subpath)
+        # ASSUMPTION: As the app code is actually invoked, this path is being "handled".
+        # If paths failing validation needed to not be incremented then I would move this logic after validation.
+        hits = redis_client.increment_hits(subpath)
+        if subpath is None or len(path_parts := subpath.split("/")) > 6:
+            return Response(
+                response=json.dumps(
+                    {
+                        "error": f"Invalid API path ({subpath}). Path must contain 1 to 6 segments."
+                    }
+                ),
+                status=400,
+            )
+        # ASSUMPTION: The test runner will provide the api context to retrieve valid path components from the database
+        test_id = request.args.get("test")
+        if test_id:
+            app.logger.debug("Under test %s", test_id)
+            if test_parts := redis_client.get_test(test_id):
+                if any([part not in test_parts for part in path_parts]):
+                    return Response(
+                        response=json.dumps(
+                            {
+                                "error": f"Invalid API path ({subpath}). All segments in test {test_id} must be in {test_parts}"
+                            }
+                        ),
+                        status=400,
+                    )
+            else:
+                app.logger.debug("Test '%s' not found", test_id)
+        return {"hits": hits}
+
+    @app.route("/test/<int:num_requests>/", methods=["POST"])
+    def test(num_requests: int):
+        redis_client.increment_hits("test")
+        # ASSUMPTION: The requirements don't specify if the test runner or the api path under test specify
+        # which 3 path parts are valid, so I will have the test runner generate and store them.
+        test_id = str(uuid.uuid4())
+
+        # ASSUMPTION: The valid test paths may be ascii lowercase.
+        # ASSUMPTION: Test paths selected as short 3 character strings for ease of manual entry.
+        paths = [
+            "".join([random.choice(string.ascii_lowercase) for i in range(3)])
+            for j in range(3)
+        ]
+
+        redis_client.create_test(test_id, paths)
+
+        # ASSUMPTION: The test may be run in the context of the test creation, without using a background worker.
+        try:
+            for req_index in range(num_requests):
+                with app.test_request_context(query_string={"test": test_id}):
+                    api(
+                        "/".join(
+                            [random.choice(paths) for i in range(random.randint(1, 6))]
+                        )
+                    )
+        except Exception as e:
+            # Report, swallow, and return test ID to user for use in reviewing manual testing behavior.
+            app.logger.exception(e)
+
+        return {"test_id": test_id, "test_paths": paths}
+
+    return app
diff --git a/main.py b/main.py
index 4ccb1f9739074861db3e3badcf1a3f4245cdf108..7212ef5596b450a3bc1ef02fe19c632852e863d5 100644 (file)
--- a/main.py
+++ b/main.py
@@ -1,105 +1,6 @@
-import json
-import logging
-import random
-import uuid
+from app import create_app
 
-from typing import Optional
-
-import redis
-from flask import Flask, Response, request
-
-redis_client = redis.Redis(host="redis", port=6379, decode_responses=True)
-
-logging.basicConfig(level=logging.DEBUG)
-app = Flask(__name__)
-
-
-def increment_hits(url: str) -> int:
-    return redis_client.zincrby(f"hits", 1, url)
-
-
-@app.route("/stats/", methods=["GET"])
-def stats():
-    increment_hits("stats")
-    try:
-        stats = redis_client.zscan_iter("hits")
-        if stats is None:
-            app.logger.debug("Uninitialized stats object")
-            return {}
-        return {"stats": [{stat[0]: stat[1]} for stat in stats][::-1]}
-    except redis.exceptions.ConnectionError as e:
-        app.logger.exception(e)
-        return Response("Internal Server Error", 500)
-    except Exception as e:
-        app.logger.exception(e)
-        return "Internal Server Error", 500
-    return {}
-
-
-@app.route("/api/", methods=["GET"])
-@app.route("/api/<path:subpath>", methods=["GET"])
-def api(subpath: Optional[str] = None):
-    app.logger.debug("Handling subpath %s", subpath)
-    # ASSUMPTION: As the app code is actually invoked, this path is being "handled".
-    # If paths failing validation needed to not be incremented then I would move this logic after validation.
-    hits = increment_hits(subpath)
-    if subpath is None or len(path_parts := subpath.split("/")) > 6:
-        return Response(
-            response=json.dumps(
-                {
-                    "error": f"Invalid API path ({subpath}). Path must contain 1 to 6 segments."
-                }
-            ),
-            status=400,
-        )
-    if any([len(part) != 3 for part in path_parts]):
-        return Response(
-            response=json.dumps(
-                {
-                    "error": f"Invalid API path ({subpath}). All segments must be 3 characters long."
-                }
-            ),
-            status=400,
-        )
-    # ASSUMPTION: The test runner will provide the api context to retrieve valid path components from the database
-    test_id = request.args.get("test")
-    if test_id:
-        app.logger.debug("Under test %s", test_id)
-        if test_parts := redis_client.smembers(f"test:{test_id}"):
-            if any([part not in test_parts for part in path_parts]):
-                return Response(
-                    response=json.dumps(
-                        {
-                            "error": f"Invalid API path ({subpath}). All segments in test {test_id} must be in {test_parts}"
-                        }
-                    ),
-                    status=400,
-                )
-        else:
-            app.logger.debug("Test '%s' not found", test_id)
-    return {"hits": hits}
-
-
-@app.route("/test/<int:num_requests>/", methods=["POST"])
-def test(num_requests: int):
-    increment_hits("test")
-    # ASSUMPTION: The requirements don't specify if the test runner or the api path under test specify
-    # which 3 path parts are valid, so I will have the test runner generate and store them.
-    test_id = str(uuid.uuid4())
-    paths = ["abc", "def", "ghi"]
-    redis_client.sadd(f"test:{test_id}", *paths)
-
-    # ASSUMPTION: The test may be run in the context of the test creation, without using a background worker.
-    try:
-        for req_index in range(num_requests):
-            with app.test_request_context(query_string={"test": test_id}):
-                api(
-                    "/".join(
-                        [random.choice(paths) for i in range(random.randint(1, 6))]
-                    )
-                )
-    except Exception as e:
-        # Report, swallow, and return test ID to user for use in reviewing manual testing behavior.
-        app.logger.exception(e)
-
-    return {"test_id": test_id}
+if __name__ == "__main__":
+    redis = get_client()
+    app = create_app()
+    app.run()
diff --git a/redis_client.py b/redis_client.py
new file mode 100644 (file)
index 0000000..576b993
--- /dev/null
@@ -0,0 +1,39 @@
+import logging
+
+import redis
+
+log = logging.getLogger(__name__)
+
+_client = None
+
+
+def get_client():
+    global _client
+    if not _client:
+        _client = RedisClient()
+    return _client
+
+
+class RedisClient:
+    def __init__(self):
+        self.r = redis.Redis(host="redis", port=6379, decode_responses=True)
+
+    def increment_hits(self, url: str) -> int:
+        return self.r.zincrby(f"hits", 1, url)
+
+    def get_stats(self) -> list[tuple]:
+        """Returns stats values from a Redis sorted set as a tuple of key-value pairs."""
+        stats = self.r.zscan_iter("hits")
+        if stats is None:
+            logger.debug("Uninitialized stats object")
+            return []
+        return stats
+
+    def create_test(self, test_id: str, paths: list[str]):
+        return self.r.sadd(f"test:{test_id}", *paths)
+
+    def get_test(self, test_id: str) -> set[str]:
+        test_parts = self.r.smembers(f"test:{test_id}")
+        if not test_parts:
+            return set()
+        return test_parts