diff --git a/ecopilot_srcindex/http.py b/ecopilot_srcindex/http.py index f207c0180dafbfd4b095cac602b4b85fa29f28ad..e0067d54bdc00daf97fefcd8c9fdf232f07ffaee 100644 --- a/ecopilot_srcindex/http.py +++ b/ecopilot_srcindex/http.py @@ -2,6 +2,7 @@ import argparse import git import logging import os +import threading from flask import Flask, request, make_response from werkzeug.serving import make_server @@ -10,6 +11,7 @@ from ecopilot_srcindex.repomap import RepoMap app = Flask(__name__) +mutex = threading.Lock() @app.route('/', methods=('POST',)) @@ -33,20 +35,53 @@ class Service: self.repo_map_cache = {} def handle_request(self, git_root, chat_fnames): + if app.watchdog: + app.watchdog.reset() + repo = git.Repo(git_root, search_parent_directories=True) git_root = repo.working_tree_dir abs_src_paths = [ os.path.join(git_root, entry) for entry in repo.git.ls_files().split('\n')] - if git_root not in self.repo_map_cache: - self.repo_map_cache[git_root] = RepoMap( - map_tokens=self.map_tokens, - max_context_window=self.max_context_size, - main_model=self.model, - root=git_root) - repo_map = self.repo_map_cache[git_root] - return repo_map.get_ranked_tags_map(chat_fnames, abs_src_paths) + with mutex: + if git_root not in self.repo_map_cache: + self.repo_map_cache[git_root] = RepoMap( + map_tokens=self.map_tokens, + max_context_window=self.max_context_size, + main_model=self.model, + root=git_root) + repo_map = self.repo_map_cache[git_root] + return repo_map.get_ranked_tags_map(chat_fnames, abs_src_paths) + + +class _watchdog(): + + def __init__(self, fn, timeout): + self._timeout = timeout + self._fn = fn + self._start() + + def _start(self): + self._timer = threading.Timer(self._timeout, self._fn) + self._timer.start() + + def reset(self): + self._timer.cancel() + self._start() + + +def _run_server_with_watchdog(srv, app, timeout): + if timeout <= 0: + app.watchdog = None + else: + def _stop(): + srv.shutdown() + + app.watchdog = _watchdog(_stop, timeout) + + srv.log_startup() + srv.serve_forever() def main(): @@ -66,6 +101,9 @@ HTTP service for contextual code indexing. parser.add_argument( '--map-tokens', metavar='N', type=int, default=1024, help='Map tokens limit (default: %(default)s)') + parser.add_argument( + '--timeout', metavar='SECONDS', type=int, default=600, + help='Exit after idle time (0 to disable)') args = parser.parse_args() logging.basicConfig(level=logging.DEBUG) @@ -81,8 +119,7 @@ HTTP service for contextual code indexing. # systemd activated socket (which we expect will be fd 3). srv = make_server( '127.0.0.1', 7894, app, fd=3) - srv.log_startup() - srv.serve_forever() + _run_server_with_watchdog(srv, app, args.timeout) if __name__ == '__main__':