Coverage for rfpy/web/base.py: 99%
124 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-24 10:52 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-24 10:52 +0000
1import logging
2from zlib import crc32
3import importlib.resources
4from contextlib import contextmanager
5from typing import Dict, List, Union, Callable, Optional, TYPE_CHECKING
7import orjson
8import webob.exc
9from webob import Response
10from webob.dec import wsgify
11from pydantic.main import BaseModel
12from semantic_version import Version # type: ignore[import]
14from rfpy.utils import json_default, benchmark
15from rfpy import conf
16from rfpy.conf.settings import RunMode
17from rfpy.auth.policy import DevHeaderPolicy, AbstractIdentityPolicy, JwtBearerPolicy
18from rfpy.templates import init_jinja, get_template
19from rfpy.web.request import HttpRequest
20from rfpy.web.exception import resolve_exception
22if TYPE_CHECKING:
23 from rfpy.suxint import Sux
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

JS_ERR_MSG = "Server Error"

# Name of the HTTP response header that carries the API version.
API_VERSION_HTTP_HEADER = "X-RFPY-API-VERSION"

# Read the version string once, at import time, from the ``api_version``
# resource that ships inside the ``rfpy`` package, then parse it.
v = (
    importlib.resources.files("rfpy")
    .joinpath("api_version")
    .read_text()
    .strip()
)
API_VERSION = Version(v)
def jsonify_models(api_output) -> Union[List, Dict]:
    """Turn a collection of pydantic models into JSON-ready dictionaries.

    If ``api_output`` is a non-empty ``list`` or ``set`` and the element
    inspected is a pydantic ``BaseModel``, every element is dumped with
    ``model_dump(by_alias=True)`` and the results are returned as a new list.
    Any other value (empty collections, dicts, scalars, ``None``) is returned
    unchanged.

    The caller's collection is never modified.
    """
    if not isinstance(api_output, (list, set)) or not api_output:
        return api_output

    # Peek at one element without consuming it. Using ``set.pop()`` here
    # would remove that element from the caller's set, so it would be
    # missing from the dumped output below.
    sample = next(iter(api_output))
    if isinstance(sample, BaseModel):
        log.debug("dumping %s models %s", len(api_output), sample)
        return [model.model_dump(by_alias=True) for model in api_output]
    return api_output
def render(request: HttpRequest, api_output):
    """Build the HTTP response for a value returned by an API handler.

    The value is first serialised to JSON bytes:

    * ``None`` becomes the fixed document ``{"result": "ok"}``;
    * a single pydantic model is dumped with ``model_dump_json(by_alias=True)``;
    * anything else goes through ``jsonify_models`` and ``orjson.dumps``.

    When ``request.prefers_json`` is true the bytes are sent as
    ``application/json``; otherwise they are embedded in the ``api.html``
    template.

    If the request has a truthy ``generate_etag`` attribute, the response gets
    an ETag derived from the CRC32 of the JSON bytes and
    ``Cache-Control: must-revalidate``; a matching ``If-None-Match`` yields an
    ``HTTPNotModified`` response instead. Without ``generate_etag`` the
    response is marked ``Cache-Control: no-cache``.
    """
    if api_output is None:
        json_bytes = b'{"result": "ok"}'
    elif isinstance(api_output, BaseModel):
        json_bytes = api_output.model_dump_json(by_alias=True).encode("utf-8")
    else:
        json_bytes = orjson.dumps(jsonify_models(api_output), default=json_default)

    if request.prefers_json:
        res = Response(json_bytes, charset="utf-8", content_type="application/json")
    else:
        html_output = get_template("api.html").render(
            js_doc=json_bytes.decode("utf-8"), url=request.path, request=request
        )
        res = Response(html_output)

    if not getattr(request, "generate_etag", False):
        res.headers.add("Cache-Control", "no-cache")
        return res

    res.headers.add("Cache-Control", "must-revalidate")
    # The tag is always computed from the JSON bytes, including when the
    # body actually sent is the HTML rendering.
    res.etag = str(crc32(json_bytes))
    if res.etag in request.if_none_match:
        return webob.exc.HTTPNotModified(etag=res.etag)
    return res
@contextmanager
def commit_or_rollback(session):
    """Run the ``with`` body as one unit of work on ``session``.

    If the body raises an ``Exception`` the session is rolled back and the
    exception propagates. If the body completes normally the session is
    committed. The session is closed on every exit path.
    """
    try:
        try:
            yield
        except Exception:  # nopep8
            session.rollback()
            raise
        # Reached only when the body finished without raising.
        session.commit()
    finally:
        session.close()
class WSGIApp:
    """
    Base WSGI application.

    Each request is dispatched on its first path segment: the segment equal
    to ``api_path`` goes to ``rest_api``; any other segment is looked up in
    ``routes``. Every request gets its own session from ``session_factory``
    and is authenticated through ``auth_policy`` before the handler runs.

    Subclasses must implement ``build_sux`` and ``validate_user``.
    """

    # First path segment -> handler callable. Defined on the class, so
    # handlers registered through ``route`` land in this shared mapping
    # unless a subclass declares its own ``routes`` dict.
    routes: dict[str, Callable] = {}

    def __init__(
        self,
        session_factory=None,
        auth_policy: AbstractIdentityPolicy | None = None,
        api_path="api",
    ):
        """Configure the application.

        :param session_factory: zero-argument callable returning a new
            session; called once per request.
        :param auth_policy: an ``AbstractIdentityPolicy`` instance, or a
            class that is instantiated with no arguments. When ``None``,
            ``DevHeaderPolicy`` is used in test/development run modes and
            ``JwtBearerPolicy`` otherwise.
        :param api_path: first path segment that selects ``rest_api``.
        :raises TypeError: if ``auth_policy`` is neither a class nor an
            ``AbstractIdentityPolicy`` instance.
        """
        self.session_factory = session_factory
        self.sux_instance: Optional["Sux"] = None
        self.api_path = api_path
        self.auth_policy = auth_policy

        if auth_policy is None:
            if (
                conf.CONF.run_mode is RunMode.test
                or conf.CONF.run_mode is RunMode.development
            ):
                self.auth_policy = DevHeaderPolicy()
            else:
                self.auth_policy = JwtBearerPolicy()
        elif isinstance(auth_policy, type):
            # A policy class was passed instead of an instance.
            self.auth_policy = auth_policy()
        elif not isinstance(auth_policy, AbstractIdentityPolicy):
            raise TypeError("auth_policy must inherit from AbstractIdentityPolicy")

        init_jinja()
        self.build_sux()

        log.info(
            "%s App initialised. Auth: %s. API Version %s",
            self.__class__.__name__,
            self.auth_policy.__class__.__name__,
            API_VERSION,
        )

    def build_sux(self):  # pragma: no cover
        """
        If a subclass wants to serve a suxint.Sux API then it must
        implement this method to assign a value to self.sux_instance
        """
        raise NotImplementedError

    @wsgify(RequestClass=HttpRequest)
    def __call__(self, request):
        """Handle one request: route, open a session, authenticate, dispatch.

        In development run mode any exception is logged and re-raised; in
        every other mode it is converted to a response by
        ``resolve_exception``.
        """
        try:
            handler = self.resolve_route(request)

            request.session = session = self.session_factory()

            # Authentication runs inside the unit of work so that the
            # session is rolled back and closed even when identification or
            # user validation raises, instead of being left open.
            with commit_or_rollback(session):
                self.authenticate(request)
                response = handler(request)

            self.auth_policy.remember(request, response)

            return response

        except Exception as e:
            if conf.CONF.run_mode is RunMode.development:
                log.exception(
                    "Exception in Base webapp, RunMode.development so raising.."
                )
                raise
            else:
                # Set request user to None to avoid detached sqla session
                # errors caused by User object lurking in environ dict
                request.user = None
                return resolve_exception(request, e)

    def authenticate(self, request):
        """Identify the caller via the auth policy, then validate the user."""
        self.auth_policy.identify(request)
        self.validate_user(request)

    def resolve_route(self, request):
        """Return the handler for the first path segment of ``request``.

        :raises webob.exc.HTTPNotFound: if the segment is neither
            ``api_path`` nor a key of ``routes``.
        """
        sub_app = request.path_info_peek() or ""

        if sub_app == self.api_path:
            return self.rest_api

        if sub_app in self.routes:
            return self.routes[sub_app]

        log.warning("No handler found for sub_app: %s", sub_app)
        raise webob.exc.HTTPNotFound

    def rest_api(self, request):
        """Call ``sux_instance`` for the request and return a response.

        A ``Response`` returned by the API is passed through as is; any other
        value is converted with ``render``. The API version header is added
        in both cases.
        """
        path_info = request.path_info
        with benchmark("API call to %s %s" % (request.method, path_info)):
            api_output = self.sux_instance(request)

            if isinstance(api_output, Response):
                response = api_output
            else:
                response = render(request, api_output)

            response.headers.add(API_VERSION_HTTP_HEADER, str(API_VERSION))

            return response

    @classmethod
    def route(cls, url_path):
        """Provides a decorator method for handler functions to register
        at the given URL path

        Leading slashes are stripped from ``url_path`` before registration.

        :raises ValueError: if the path already has a handler.
        """

        def wrapper(handler_function):
            base_path = url_path.lstrip("/")
            if base_path in cls.routes:
                existing_handler = cls.routes[base_path]
                args = (base_path, existing_handler, cls)
                raise ValueError("%s path already taken by %s in %s" % args)
            cls.routes[base_path] = handler_function
            return handler_function

        return wrapper

    def __repr__(self):
        return "App - Base WSGI application"

    def validate_user(self, request):  # pragma: no cover
        """Check the identified user; called by ``authenticate``."""
        raise NotImplementedError("Subclasses to implement")