Coverage for yasmon/tasks.py: 97%

184 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-28 10:57 +0000

1from yasmon.callbacks import AbstractCallback 

2from yasmon.callbacks import CallbackAttributeError 

3from yasmon.callbacks import CallbackCircularAttributeError 

4 

5from loguru import logger 

6from abc import ABC, abstractmethod 

7from typing import Self, Optional 

8import watchfiles 

9import asyncio 

10import signal 

11import yaml 

12import pathlib 

13 

14 

class TaskSyntaxError(Exception):
    """
    Signals that a task definition is syntactically invalid.

    :param message: description of the problem; also exposed as the
        ``message`` attribute
    """

    def __init__(self, message="task syntax error"):
        super().__init__(message)
        self.message = message

23 

24 

class TaskNotImplementedError(Exception):
    """
    Signals that the requested task type has no implementation.

    :param type: name of the requested task type
    :param message: template for the error text; ``{type}`` is replaced
        with the requested type
    """

    def __init__(self, type, message="task {type} not implemented"):
        rendered = message.format(type=type)
        self.message = rendered
        super().__init__(rendered)

33 

34 

class AbstractTask(ABC):
    """
    Abstract class from which all task classes are derived.

    Derived tasks are functors calling the assigned callback coroutine
    and can be used for :class:`yasmon.tasks.TaskRunner`.

    The preferred way to instantiate a task is from class
    method :func:`~from_yaml`.
    """
    @abstractmethod
    def __init__(self):
        """
        Fill in defaults for ``name`` and ``attrs`` and log initialization.

        Subclasses are expected to assign ``self.name`` and ``self.attrs``
        before delegating here; missing or empty values are replaced by
        ``"Generic Task"`` and ``{}`` respectively.
        """
        # ``getattr`` keeps the fallbacks reachable when a subclass did not
        # assign the attribute at all (plain attribute access would raise
        # AttributeError before the default could be applied).
        if not getattr(self, 'name', None):
            self.name = "Generic Task"
        if not getattr(self, 'attrs', None):
            self.attrs = {}
        logger.info(f'{self.name} ({self.__class__}) initialized')

    @abstractmethod
    async def __call__(self, callback: AbstractCallback):
        """
        Coroutine called by :class:`TaskRunner`.

        The base implementation only logs that the task was scheduled
        with the given callback.

        :param callback: callback the task is scheduled with
        """
        logger.info(f'{self.name} ({self.__class__}) scheduled with '
                    f'{callback.name} ({callback.__class__})')

    @classmethod
    @abstractmethod
    def from_yaml(cls, name: str, data: str,
                  callbacks: list[AbstractCallback]):
        """
        A class method for constructing a task from a YAML document.

        The base implementation only logs the definition; subclasses
        perform the actual parsing and construction.

        :param name: unique identifier
        :param data: yaml data
        :param callbacks: collection of callbacks

        :return: new instance
        """
        logger.debug(f'{name} defined from yaml \n{data}')

75 

76 

class TaskList(list):
    """
    A ``list`` subclass reserved for holding tasks.
    """
    def __init__(self, iterable: Optional[list[AbstractTask]] = None):
        """
        :param iterable: tasks to populate the list with; ``None`` creates
            an empty list
        """
        if iterable is None:
            super().__init__()
            return
        super().__init__(iterable)

86 

87 

class TaskError(Exception):
    """
    Signals that a task cannot continue, e.g. because a watched path
    is missing and the retry budget is exhausted.

    :param task: identifier of the failing task
    :param message: template for the error text; ``{task}`` is replaced
        with the task identifier
    """

    def __init__(self, task, message="error in task {task}"):
        rendered = message.format(task=task)
        self.message = rendered
        super().__init__(rendered)

96 

97 

class WatchfilesTask(AbstractTask):
    """
    Task watching files/directories via ``watchfiles`` and invoking the
    scheduled callback for every change of a requested kind.
    """

    def __init__(self, name: str, changes: list[watchfiles.Change],
                 callbacks: list[AbstractCallback], paths: list[str],
                 timeout: int, max_retry: int,
                 attrs: Optional[dict[str, str]] = None) -> None:
        """
        :param name: unique identifier
        :param changes: list of watchfiles events
        :param callbacks: assigned callbacks
        :param paths: paths to watch (files/directories)
        :param timeout: seconds to sleep before re-checking missing paths
        :param max_retry: number of re-checks allowed for missing paths
            (``-1`` retries indefinitely)
        :param attrs: (static) attributes
        """
        self.name = name
        self.changes = changes
        self.callbacks = callbacks
        self.paths = paths
        self.abs_paths: list[str] = []  # resolved upon self.__call__()
        self.attrs = {} if attrs is None else attrs
        self.max_retry = max_retry
        self.timeout = timeout
        super().__init__()

    def _resolve_paths(self) -> list[str]:
        """
        Return the resolved form of every watched path.

        :raises FileNotFoundError: for the first path that does not exist
        """
        resolved = []
        for path in self.paths:
            candidate = pathlib.Path(path)
            if not candidate.exists():
                raise FileNotFoundError(path)
            resolved.append(str(candidate.resolve()))
        return resolved

    async def __call__(self, callback):
        """
        Watch the configured paths and dispatch changes to ``callback``.

        Each pass first verifies that all paths exist. Missing paths are
        re-checked after ``timeout`` seconds, at most ``max_retry`` times
        (indefinitely for ``-1``). Once all paths exist the watch loop
        runs; if it reports a deleted watched path, the paths are
        verified again with a fresh retry budget.

        :param callback: callback coroutine to invoke on changes

        :raises TaskError: if a path is missing and ``max_retry`` was
            reached
        """
        await super().__call__(callback)
        retry = 0
        max_retry = 'inf' if self.max_retry == -1 else self.max_retry

        while True:
            try:
                # Rebuild instead of appending: the list would otherwise
                # gain duplicates on every pass of this loop.
                self.abs_paths = self._resolve_paths()
            except FileNotFoundError as missing:
                if retry == self.max_retry and self.max_retry > -1:
                    logger.error(f'in task {self.name}'
                                 f' path {missing} does not exist '
                                 f'and max_retry {self.max_retry} '
                                 'was reached')
                    raise TaskError(self.name) from missing

                retry += 1
                logger.warning(f'in task {self.name}'
                               f' path {missing} does not exist (anymore)'
                               f'... retrying callback {callback.name}'
                               f' after {self.timeout} sec timeout '
                               f' ({retry}/{max_retry} retries)')
                await asyncio.sleep(self.timeout)
                continue

            # run awatch loop if all paths exist
            try:
                await self.awatch_loop(callback)
            except FileNotFoundError:
                continue
            finally:
                # the paths were reachable, so start over with a full
                # retry budget
                retry = 0

    async def awatch_loop(self, callback):
        """
        Forward matching filesystem changes to ``callback``.

        For each reported change whose kind is in ``self.changes`` the
        callback is awaited with the static attributes merged with
        ``change`` (``added``/``modified``/``deleted``) and ``path``.

        :param callback: callback coroutine to invoke

        :raises FileNotFoundError: if one of the watched paths itself
            was deleted
        :raises CallbackAttributeError: re-raised from the callback
        :raises CallbackCircularAttributeError: re-raised from the callback
        """
        async for changes in watchfiles.awatch(*self.paths):
            for (change, path) in changes:
                if change in self.changes:
                    # member names are exactly the strings accepted by
                    # from_yaml: added / modified / deleted
                    call_attrs = {'change': change.name, 'path': path}
                    try:
                        await callback(self, self.attrs | call_attrs)
                    except (CallbackAttributeError,
                            CallbackCircularAttributeError) as err:
                        logger.error(f'in task {self.name} callback {callback.name} raised {err}')  # noqa
                        raise

                # A deleted watched path ends this watch; __call__ catches
                # the exception and verifies the paths again.
                if (change == watchfiles.Change.deleted and
                        path in self.abs_paths):
                    raise FileNotFoundError

    @classmethod
    def from_yaml(cls, name: str, data: str,
                  callbacks: list[AbstractCallback]) -> Self:
        """
        :class:`WatchfilesTask` can be also constructed from a YAML snippet.

        .. code:: yaml

            changes:
                - added
            callbacks:
                - callback0
            paths:
                - /some/path/to/file1
                - /some/path/to/file2
            attrs:
                myattr: value
            timeout: 10
            max_retry: 3

        Possible changes are ``added``, ``modified`` and ``deleted``.
        ``timeout`` defaults to 30 and ``max_retry`` to 5 when omitted.

        :param name: unique identifier
        :param data: YAML snippet
        :param callbacks: list of associated callbacks

        :raises TaskSyntaxError: if the snippet is not valid YAML, is not
            a mapping, or any entry fails validation

        :return: new instance
        :rtype: WatchfilesTask
        """
        super().from_yaml(name, data, callbacks)
        try:
            yamldata = yaml.safe_load(data)
        except yaml.YAMLError as err:
            raise TaskSyntaxError(err) from err

        # An empty document loads as None and a scalar/sequence document
        # is not a mapping; report both as syntax errors instead of
        # failing on the membership tests below.
        if not isinstance(yamldata, dict):
            raise TaskSyntaxError(f"in task {name}: "
                                  "task definition must be a dictionary")

        if 'timeout' in yamldata:
            try:
                timeout = int(yamldata['timeout'])
            except ValueError as err:
                raise TaskSyntaxError(
                    f"in task {name}: invalid timeout") from err
            except TypeError as err:
                raise TaskSyntaxError(
                    f"in task {name}: timeout not integer") from err
            if timeout <= 0:
                raise TaskSyntaxError(f"in task {name}: invalid timeout")
        else:
            timeout = 30

        if 'max_retry' in yamldata:
            try:
                max_retry = int(yamldata['max_retry'])
            except ValueError as err:
                raise TaskSyntaxError(
                    f"in task {name}: invalid max_retry") from err
            except TypeError as err:
                raise TaskSyntaxError(
                    f"in task {name}: max_retry not integer") from err
        else:
            max_retry = 5

        if 'changes' not in yamldata:
            raise TaskSyntaxError(f"in task {name}: "
                                  "missing changes list")

        if 'paths' not in yamldata:
            raise TaskSyntaxError(f"in task {name}: "
                                  "missing paths list")

        paths = yamldata['paths']
        if not isinstance(paths, list):
            raise TaskSyntaxError(f"in task {name}: "
                                  "paths must be a list")

        if not all(isinstance(path, str) for path in paths):
            raise TaskSyntaxError(f"in task {name}: "
                                  "paths must be strings")

        if not paths:
            raise TaskSyntaxError(f"in task {name}: "
                                  "at least one path required")

        attrs = yamldata.get('attrs')
        if 'attrs' in yamldata and not isinstance(attrs, dict):
            raise TaskSyntaxError(f"in task {name}: "
                                  "attrs must be a dictionary")

        # Check the container type before looking at its items, so a
        # scalar value is reported as "must be a list".
        if not isinstance(yamldata['changes'], list):
            raise TaskSyntaxError(f"in task {name}: "
                                  "changes must be a list")

        imp_changes = ['added', 'modified', 'deleted']
        for change in yamldata['changes']:
            if change not in imp_changes:
                raise TaskSyntaxError(f"in task {name}: "
                                      f"invalid change {change}")

        changes = [getattr(watchfiles.Change, change)
                   for change in yamldata['changes']]

        return cls(name, changes, callbacks, paths, timeout, max_retry, attrs)

282 

283 

class TaskRunner:
    """
    Schedules every task/callback pair on an `asyncio` event loop.
    Instances are awaitable callables (functors).
    """
    def __init__(self, tasks: TaskList, testenv: bool = False):
        """
        :param tasks: tasks to schedule, one asyncio task per
            (task, callback) pair
        :param testenv: if true, use a fresh event loop and schedule
            :func:`cancle_tasks` so that the run terminates on its own
        """
        logger.debug('task runner started...')
        self.loop = asyncio.get_event_loop()
        self.tasks = []

        if testenv:
            self.loop = asyncio.new_event_loop()
            canceller = self.loop.create_task(self.cancle_tasks())
            self.tasks.append(canceller)

        # ``sig=sig`` binds the current signal at definition time; a plain
        # closure would see only the last value of the loop variable.
        for sig in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT):
            self.loop.add_signal_handler(
                sig,
                lambda sig=sig: asyncio.create_task(
                    self.signal_handler(sig)))

        self.tasks.extend(
            self.loop.create_task(task(callback))
            for task in tasks
            for callback in task.callbacks)

    async def __call__(self):
        """
        Await all scheduled tasks in order of completion.

        Any exception raised by a task propagates to the caller, except
        ``RuntimeError("Already borrowed")`` which is suppressed.
        """
        for finished in asyncio.as_completed(self.tasks):
            try:
                await finished
            except RuntimeError as err:
                # Suppress RuntimeError("Already borrowed"), to
                # work around this issue:
                # https://github.com/samuelcolvin/watchfiles/issues/200
                if str(err) != "Already borrowed":
                    raise

    async def signal_handler(self, sig: signal.Signals):
        """
        Signal handler: cancels all scheduled tasks.

        :param sig: received signal
        """
        logger.debug(f'received {sig.name}')
        for scheduled in self.tasks:
            scheduled.cancel()

    async def cancle_tasks(self, delay=2):
        """
        Cancel all tasks except the calling one (for testing purposes).

        :param delay: seconds to wait before cancelling
        """
        await asyncio.sleep(delay)
        own_task = asyncio.current_task()
        for scheduled in self.tasks:
            if scheduled is not own_task:
                scheduled.cancel()