Skip to content
Snippets Groups Projects
Commit 7d187c96 authored by ale's avatar ale
Browse files

add a standalone server to run the timbre vector search db

parent a9591397
No related branches found
No related tags found
No related merge requests found
import heapq
import logging
import marsyas
import optparse
import os
import threading
import time
from djrandom import daemonize
from djrandom import utils
from djrandom.mood import marsyas_utils
from djrandom.mood import marsyas_c_utils
from djrandom.model.mp3 import MP3, Features
from djrandom.database import Session, init_db
from flask import Flask, request, abort, jsonify
from sqlalchemy import select
log = logging.getLogger(__name__)
app = Flask(__name__)
class TimbreDb(object):
def __init__(self):
self._db = []
def load_data(self, dataiter):
self._db = list(dataiter)
log.debug('timbre data loaded')
def search(self, rv, n=10):
scores = [(0, None)] * n
for id, vector in self._db:
score = marsyas_c_utils.euclidean_distance(vector, rv)
heapq.heappushpop(scores, (score, id))
return scores
class DbLoader(threading.Thread):
def __init__(self, engine, timbre_db):
threading.Thread.__init__(self)
self._engine = engine
self._timbre_db = timbre_db
def _updatedb(self):
# Bypass the SQLAlchemy ORM, and just run a huge SELECT query
# to reduce the memory footprint.
q = select([Features.sha1, Features.timbre_vector],
(MP3.sha1 == Features.sha1)
& (MP3.state == MP3.READY)
& (MP3.has_features == True))
features_iter = (
(x.sha1, marsyas_utils.deserialize_realvec(x.timbre_vector))
for x in self._engine.execute(q))
self._timbre_db.load_data(features_iter)
def run(self):
while True:
try:
self._updatedb()
except Exception, e:
log.error('error updating the features db: %s', e)
time.sleep(3600)
@app.teardown_request
def shutdown_dbsession(exception=None):
Session.remove()
@app.route('/search/<sha1>')
def search_handler(sha1):
n = int(request.args.get('n', 10))
mp3 = MP3.query.get(sha1)
if not mp3 or not mp3.has_features:
abort(404)
vector = marsyas_utils.deserialize_realvec(mp3.features.timbre_vector)
return jsonify(results=[
{'score': x[0], 'sha1': x[1]}
for x in app.timbre_db.search(vector, n)])
def run_timbre_db(db_url, port):
engine = init_db(db_url)
timbre_db = TimbreDb()
loader = DbLoader(engine, timbre_db)
loader.setDaemon(True)
loader.start()
app.timbre_db = timbre_db
app.run(port)
def main():
parser = optparse.OptionParser()
parser.add_option('--db_url')
parser.add_option('--port', type='int', default=4001)
daemonize.add_standard_options(parser)
utils.read_config_defaults(
parser, os.getenv('DJRANDOM_CONF', '/etc/djrandom.conf'))
opts, args = parser.parse_args()
if not opts.db_url:
parser.error('Must provide --db_url')
if args:
parser.error('Too many arguments')
daemonize.daemonize(opts, run_timbre_db,
(opts.db_url, opts.port))
if __name__ == '__main__':
main()
......@@ -33,6 +33,7 @@ setup(
"djrandom-metadata-fixer = djrandom.metadata_fixer.metadata_fixer:main",
"djrandom-solr-fixer = djrandom.model.verify:main",
"djrandom-mood-scanner = djrandom.mood.mood_scanner:main",
"djrandom-mood-db = djrandom.mood.mood_db:main",
],
},
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment