From eed4cfe39fb3dcdb3390bb9580f91be48167a6a2 Mon Sep 17 00:00:00 2001
From: ale <ale@incal.net>
Date: Sun, 30 Oct 2011 10:48:08 +0000
Subject: [PATCH] add a faster MP3.get_many() method

---
 server/djrandom/frontend/views.py  |  2 +-
 server/djrandom/model/mp3.py       |  6 ++++++
 server/djrandom/test/test_model.py | 11 +++++++++++
 3 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/server/djrandom/frontend/views.py b/server/djrandom/frontend/views.py
index 69e4b52..80de066 100644
--- a/server/djrandom/frontend/views.py
+++ b/server/djrandom/frontend/views.py
@@ -42,7 +42,7 @@ def autocomplete_search():
 def songs_fragment():
     hashes = request.form.get('h', '')
     if hashes:
-        mp3s = [MP3.query.get(h) for h in hashes.split(',')]
+        mp3s = MP3.get_many(hashes.split(','))
     else:
         mp3s = []
     return render_template('songs_fragment.html', songs=mp3s)
diff --git a/server/djrandom/model/mp3.py b/server/djrandom/model/mp3.py
index 2e50cd4..b8cc718 100644
--- a/server/djrandom/model/mp3.py
+++ b/server/djrandom/model/mp3.py
@@ -69,6 +69,12 @@ class MP3(Base):
             data['track_num'] = self.track_num
         return data
 
+    @classmethod
+    def get_many(cls, hashes):
+        order = dict((sha1, idx) for idx, sha1 in enumerate(hashes))
+        mp3s = cls.query.filter(cls.sha1.in_(hashes))
+        return sorted(mp3s, key=lambda x: order[x.sha1])
+
     def mark_as_duplicate(self, duplicate_of):
         self.state = self.DUPLICATE
         self.duplicate_of = duplicate_of
diff --git a/server/djrandom/test/test_model.py b/server/djrandom/test/test_model.py
index 428c9df..0ad2dc7 100644
--- a/server/djrandom/test/test_model.py
+++ b/server/djrandom/test/test_model.py
@@ -75,6 +75,17 @@ class MP3Test(DbTestCase):
         result_ids = set(x.sha1 for x in results)
         self.assertEquals(set(['1001', '1002']), result_ids)
 
+    def test_mp3_get_many(self):
+        for i in range(1, 4):
+            mp3, _ = create_mp3(unicode(i))
+            Session.add(mp3)
+        Session.commit()
+
+        hashes = [u'3', u'1', u'2']
+        results = MP3.get_many(hashes)
+        result_hashes = [x.sha1 for x in results]
+        self.assertEquals(hashes, result_hashes)
+
     def test_mp3_get_with_bad_metadata(self):
         pass
 
-- 
GitLab