Coverage for rfpy/web/base.py: 99%

124 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-24 10:52 +0000

1import logging 

2from zlib import crc32 

3import importlib.resources 

4from contextlib import contextmanager 

5from typing import Dict, List, Union, Callable, Optional, TYPE_CHECKING 

6 

7import orjson 

8import webob.exc 

9from webob import Response 

10from webob.dec import wsgify 

11from pydantic.main import BaseModel 

12from semantic_version import Version # type: ignore[import] 

13 

14from rfpy.utils import json_default, benchmark 

15from rfpy import conf 

16from rfpy.conf.settings import RunMode 

17from rfpy.auth.policy import DevHeaderPolicy, AbstractIdentityPolicy, JwtBearerPolicy 

18from rfpy.templates import init_jinja, get_template 

19from rfpy.web.request import HttpRequest 

20from rfpy.web.exception import resolve_exception 

21 

22if TYPE_CHECKING: 

23 from rfpy.suxint import Sux 

24 

25 

log = logging.getLogger(__name__)

JS_ERR_MSG = "Server Error"

# Response header carrying the server's API version (added by WSGIApp.rest_api).
API_VERSION_HTTP_HEADER = "X-RFPY-API-VERSION"

# Read the API version once, at import time, from the ``api_version`` data
# file packaged inside ``rfpy``. Decode explicitly as UTF-8 rather than with
# the locale's default encoding so the result does not depend on the host.
v = (
    importlib.resources.files("rfpy")
    .joinpath("api_version")
    .read_text(encoding="utf-8")
    .strip()
)
API_VERSION = Version(v)

35 

36 

def jsonify_models(api_output) -> Union[List, Dict]:
    """Convert a collection of pydantic models into JSON-ready dicts.

    If ``api_output`` is a non-empty ``list`` or ``set`` and a sampled element
    (the first element of a list, an arbitrary element of a set) is a pydantic
    ``BaseModel``, every element is dumped with ``model_dump(by_alias=True)``
    and a new list of dicts is returned. Any other input is returned unchanged.

    The input collection is never mutated.
    """
    if isinstance(api_output, (list, set)) and api_output:
        # Peek at one element without consuming it. ``next(iter(...))`` yields
        # ``api_output[0]`` for a list; for a set it avoids ``set.pop()``,
        # which would silently drop an element from the caller's set.
        sample = next(iter(api_output))
        if isinstance(sample, BaseModel):
            log.debug("dumping %s models %s", len(api_output), sample)
            return [model.model_dump(by_alias=True) for model in api_output]
    return api_output

44 

45 

def render(request: HttpRequest, api_output):
    """Serialise an API result and wrap it in a webob ``Response``.

    Serialisation:
      * ``None`` becomes the literal body ``{"result": "ok"}``;
      * a pydantic ``BaseModel`` is dumped with ``model_dump_json(by_alias=True)``;
      * anything else goes through ``jsonify_models`` and then ``orjson.dumps``
        with ``json_default`` as the fallback encoder.

    Representation: when ``request.prefers_json`` is true the JSON bytes are
    returned as ``application/json``. Otherwise the ``api.html`` template is
    rendered with the JSON text as ``js_doc``.

    Caching: if the request has a truthy ``generate_etag`` attribute, the
    response is marked ``must-revalidate`` and given an ETag of
    ``crc32(json_bytes)``. A matching ``If-None-Match`` short-circuits to
    ``HTTPNotModified``. Without ``generate_etag`` the response is ``no-cache``.
    """
    if api_output is None:
        payload = b'{"result": "ok"}'
    elif isinstance(api_output, BaseModel):
        payload = api_output.model_dump_json(by_alias=True).encode("utf-8")
    else:
        payload = orjson.dumps(jsonify_models(api_output), default=json_default)

    if not request.prefers_json:
        page = get_template("api.html").render(
            js_doc=payload.decode("utf-8"), url=request.path, request=request
        )
        res = Response(page)
    else:
        res = Response(payload, charset="utf-8", content_type="application/json")

    if not getattr(request, "generate_etag", False):
        res.headers.add("Cache-Control", "no-cache")
        return res

    res.headers.add("Cache-Control", "must-revalidate")
    # NOTE(review): the ETag is derived from the JSON payload even for the
    # HTML representation, so both representations share one tag.
    res.etag = str(crc32(payload))
    if res.etag in request.if_none_match:
        return webob.exc.HTTPNotModified(etag=res.etag)
    return res

74 

75 

@contextmanager
def commit_or_rollback(session):
    """Scope a unit of work on ``session``.

    On normal exit of the ``with`` body the session is committed. If the body
    raises an ``Exception``, the session is rolled back and the exception is
    re-raised. In every case (including a failing commit, or a
    ``BaseException`` that is neither committed nor rolled back) the session is
    closed afterwards.
    """
    try:
        try:
            yield
        except Exception:  # nopep8
            session.rollback()
            raise
        # Only reached when the body finished without raising.
        session.commit()
    finally:
        session.close()

87 

88 

class WSGIApp:
    """
    Base WSGI application that dispatches each request by its first path segment.

    A first segment equal to ``api_path`` is served by :meth:`rest_api`, which
    delegates to ``self.sux_instance``. Any other first segment is looked up in
    :attr:`routes`, and unknown segments raise ``HTTPNotFound``. Every request
    gets a fresh session from ``session_factory`` that is committed when the
    handler succeeds and rolled back otherwise.

    Subclasses must implement :meth:`build_sux` (called from ``__init__``) and
    :meth:`validate_user` (called on every request during authentication).
    """

    # First path segment -> handler callable, populated via ``route``.
    # NOTE(review): this dict is defined on the class and mutated in place by
    # ``route``, so it is shared by WSGIApp and every subclass that does not
    # rebind ``routes`` itself.
    routes: dict[str, Callable] = {}

    def __init__(
        self,
        session_factory=None,
        auth_policy: (
            AbstractIdentityPolicy | type[AbstractIdentityPolicy] | None
        ) = None,
        api_path="api",
    ):
        """
        :param session_factory: zero-argument callable returning a new session
            per request. It is called unconditionally in ``__call__``, so it
            must be set before the app serves requests.
        :param auth_policy: an ``AbstractIdentityPolicy`` instance, or a
            subclass of it that is instantiated with no arguments. When
            omitted, ``DevHeaderPolicy`` is used in the test and development
            run modes and ``JwtBearerPolicy`` otherwise.
        :param api_path: first URL path segment routed to :meth:`rest_api`.
        :raises TypeError: if ``auth_policy`` is neither an
            ``AbstractIdentityPolicy`` instance nor a subclass of it.
        """
        self.session_factory = session_factory
        self.sux_instance: Optional[Sux] = None
        self.api_path = api_path

        if auth_policy is None:
            if (
                conf.CONF.run_mode is RunMode.test
                or conf.CONF.run_mode is RunMode.development
            ):
                auth_policy = DevHeaderPolicy()
            else:
                auth_policy = JwtBearerPolicy()
        elif isinstance(auth_policy, type):
            # Validate the class before instantiating it, consistent with the
            # instance check below.
            if not issubclass(auth_policy, AbstractIdentityPolicy):
                raise TypeError("auth_policy must inherit from AbstractIdentityPolicy")
            auth_policy = auth_policy()
        elif not isinstance(auth_policy, AbstractIdentityPolicy):
            raise TypeError("auth_policy must inherit from AbstractIdentityPolicy")
        self.auth_policy = auth_policy

        init_jinja()
        self.build_sux()

        log.info(
            "%s App initialised. Auth: %s. API Version %s",
            self.__class__.__name__,
            self.auth_policy.__class__.__name__,
            API_VERSION,
        )

    def build_sux(self):  # pragma: no cover
        """
        If a subclass wants to serve a suxint.Sux API then it must
        implement this method to assign a value to self.sux_instance.
        The base implementation raises, and ``__init__`` always calls it.
        """
        raise NotImplementedError

    @wsgify(RequestClass=HttpRequest)
    def __call__(self, request):
        """Handle one WSGI request.

        Resolves the handler, opens a session on ``request.session``, then
        authenticates and runs the handler inside ``commit_or_rollback``.
        The auth policy's ``remember`` runs after the session is finished.
        In development run mode, exceptions are logged and re-raised.
        Otherwise they are converted to a response by ``resolve_exception``.
        """
        try:
            handler = self.resolve_route(request)

            request.session = session = self.session_factory()

            # Authenticate inside the transaction scope so that the session is
            # rolled back and closed even when identification or user
            # validation fails. Previously the session leaked on auth errors.
            with commit_or_rollback(session):
                self.authenticate(request)
                response = handler(request)

            self.auth_policy.remember(request, response)

            return response

        except Exception as e:
            if conf.CONF.run_mode is RunMode.development:
                log.exception(
                    "Exception in Base webapp, RunMode.development so raising.."
                )
                raise
            # Set request user to None to avoid detached sqla session
            # errors caused by User object lurking in environ dict
            request.user = None
            return resolve_exception(request, e)

    def authenticate(self, request):
        """Identify the caller via the auth policy, then validate the user."""
        self.auth_policy.identify(request)
        self.validate_user(request)

    def resolve_route(self, request):
        """Return the handler for the request's first path segment.

        :raises webob.exc.HTTPNotFound: if no handler matches.
        """
        sub_app = request.path_info_peek() or ""

        if sub_app == self.api_path:
            return self.rest_api
        if sub_app in self.routes:
            return self.routes[sub_app]

        log.warning("No handler found for sub_app: %s", sub_app)
        raise webob.exc.HTTPNotFound

    def rest_api(self, request):
        """Serve an API call through ``self.sux_instance``.

        A ``Response`` returned by the Sux instance is passed through as-is.
        Any other value is rendered with ``render``. The API version header is
        added in both cases.
        """
        path_info = request.path_info
        with benchmark("API call to %s %s" % (request.method, path_info)):
            api_output = self.sux_instance(request)

        if isinstance(api_output, Response):
            response = api_output
        else:
            response = render(request, api_output)

        response.headers.add(API_VERSION_HTTP_HEADER, str(API_VERSION))

        return response

    @classmethod
    def route(cls, url_path):
        """Provides a decorator method for handler functions to register
        at the given URL path.

        Leading slashes are stripped from ``url_path``. Registering a path
        twice raises ``ValueError``.

        NOTE(review): ``resolve_route`` compares only the first path segment,
        so a multi-segment ``url_path`` (e.g. "a/b") would never be matched.
        """

        def wrapper(handler_function):
            base_path = url_path.lstrip("/")
            if base_path in cls.routes:
                existing_handler = cls.routes[base_path]
                args = (base_path, existing_handler, cls)
                raise ValueError("%s path already taken by %s in %s" % args)
            cls.routes[base_path] = handler_function
            return handler_function

        return wrapper

    def __repr__(self):
        return "App - Base WSGI application"

    def validate_user(self, request):  # pragma: no cover
        """Validate ``request``'s identified user; subclasses must override."""
        raise NotImplementedError("Subclasses to implement")