Skip to content
Snippets Groups Projects
Select Git revision
  • 9cc39d1884a2f2b20a3981d6f82caf182c0c279c
  • master default protected
  • renovate/github.com-mattn-go-sqlite3-1.x
  • renovate/golang-1.x
  • renovate/github.com-oschwald-maxminddb-golang-1.x
  • renovate/github.com-prometheus-client_golang-1.x
  • renovate/google.golang.org-protobuf-1.x
  • renovate/golang.org-x-sync-digest
  • renovate/github.com-d5-tengo-v2-2.x
  • renovate/google.golang.org-grpc-1.x
  • renovate/gopkg.in-yaml.v3-3.x
  • renovate/github.com-golang-migrate-migrate-v4-4.x
  • renovate/github.com-google-go-cmp-0.x
13 results

driver.go

Blame
  • markov.py 3.25 KiB
    import os
    import optparse
    import logging
    import cPickle as pickle
    import random
    from djrandom import utils
    from djrandom.database import init_db
    from djrandom.model.mp3 import PlayLog
    
    
    # Module-level logger named after this module (not referenced by the
    # code shown in this file).
    log = logging.getLogger(__name__)
    
    
    class MarkovModel(object):
        """Markov chain over song hashes, built from play history.

        Songs are identified by hash strings; ``None`` (no song) is mapped to
        the sentinel ``'nil'``. Every hash is replaced by a small integer
        index, so transition-table keys are tuples of ints.

        Lifecycle: ``create()`` accumulates transition counts, ``normalize()``
        converts them into cumulative probability vectors (call it once), and
        only then can ``suggest()`` / ``generate_sequence()`` be used. A model
        saved after normalization can be restored with ``load()``.
        """

        def __init__(self, seed=None):
            """Create an empty model.

            Args:
              seed: optional seed for the private random generator, to make
                suggestions reproducible. ``None`` (the default) seeds from
                system state, as before.
            """
            # Index 0 is reserved for the 'nil' sentinel (missing song).
            self._hash2i = {'nil': 0}
            self._i2hash = ['nil']
            # Before normalize(): {context_tuple: {target: count}}.
            # After normalize():  {context_tuple: [(cumulative_prob, target)]}.
            self._map = {}
            self._rnd = random.Random(seed)

        def _to_i(self, sha1):
            """Return the index for `sha1`, registering it if unseen.

            ``None`` is treated as the 'nil' sentinel.
            """
            if sha1 is None:
                sha1 = 'nil'
            n = self._hash2i.get(sha1)
            if n is None:
                n = len(self._i2hash)
                self._i2hash.append(sha1)
                self._hash2i[sha1] = n
            return n

        def _lookup_i(self, sha1):
            """Return the index for `sha1` without registering it, or None."""
            if sha1 is None:
                sha1 = 'nil'
            return self._hash2i.get(sha1)

        def save(self, filename):
            """Pickle the model state (index tables and transition map)."""
            with open(filename, 'wb') as fd:
                state = (self._hash2i, self._i2hash, self._map)
                pickle.dump(state, fd, pickle.HIGHEST_PROTOCOL)

        def load(self, filename):
            """Restore state previously written by save().

            NOTE: unpickling can execute arbitrary code; only load files
            produced by a trusted save().
            """
            with open(filename, 'rb') as fd:
                self._hash2i, self._i2hash, self._map = pickle.load(fd)

        def create(self, source):
            """Accumulate transition counts from `source`.

            Args:
              source: iterable of ``(sha1, prev)`` pairs, where `prev` is the
                sequence of hashes (``None`` allowed) preceding `sha1`.
            """
            for sha1, prev in source:
                # Register the target before its context, so index assignment
                # order stays stable.
                n = self._to_i(sha1)
                prev_n = tuple(self._to_i(x) for x in prev)
                target_map = self._map.setdefault(prev_n, {})
                target_map[n] = target_map.get(n, 0) + 1

        def normalize(self):
            """Convert transition counts into cumulative probability vectors.

            For each context, the target equal to the context's last song is
            dropped, so the same song is never suggested twice in a row. A
            context whose only target was its last song gets an empty vector.
            """
            norm_map = {}
            for key, target_map in self._map.items():
                last_song = key[-1]
                candidates = [(target, count)
                              for target, count in target_map.items()
                              if target != last_song]
                tot = sum(count for _, count in candidates)
                norm_vec = []
                cur = 0.0
                for target, count in candidates:
                    cur += float(count) / tot
                    norm_vec.append((cur, target))
                norm_map[key] = norm_vec

            self._map = norm_map

        def suggest(self, prev):
            """Pick a random next song given the preceding songs.

            Must be called after normalize() (or load() of a normalized model).

            Args:
              prev: sequence of hashes (``None`` allowed), oldest first, with
                the same length as the contexts passed to create().

            Returns:
              The suggested hash, or None if the context is unknown or has no
              candidate targets. The model itself is not modified.
            """
            prev_n = []
            for x in prev:
                i = self._lookup_i(x)
                if i is None:
                    # An unseen hash cannot be part of any known context.
                    return None
                prev_n.append(i)
            norm_vec = self._map.get(tuple(prev_n))
            if not norm_vec:
                return None
            r = self._rnd.random()
            for off, value in norm_vec:
                if off > r:
                    return self._i2hash[value]
            # Float rounding can leave the final cumulative value just below
            # r; the last entry owns the remaining probability mass.
            return self._i2hash[norm_vec[-1][1]]

        def generate_sequence(self, prev, n, count):
            """Generate `count` songs by repeatedly calling suggest().

            Args:
              prev: seed songs, oldest first (list or tuple). Not modified.
              n: context length. `prev` is left-padded with None when shorter
                than `n`; only its last `n` entries are used when longer.
              count: number of songs to generate.

            Returns:
              List of `count` suggestions. Entries may be None once the chain
              reaches an unknown context.
            """
            window = [None] * max(0, n - len(prev)) + list(prev)
            if len(window) > n:
                window = window[len(window) - n:]
            out = []
            for _ in range(count):
                song = self.suggest(window)
                # Slide the window: drop the oldest song, append the new one.
                window.append(song)
                del window[0]
                out.append(song)
            return out
    
    
    def main():
        """Build a Markov model from the play log and save it to disk.

        Command-line options: ``--db_url`` (required), ``--output`` (default
        ``markov.dat``) and ``--dump`` (print the normalized transition table
        after saving). ``utils.read_config_defaults`` is given the config path
        from ``$DJRANDOM_CONF`` (default ``/etc/djrandom.conf``); presumably it
        fills option defaults from that file.
        """
        parser = optparse.OptionParser()
        parser.add_option('--db_url')
        parser.add_option('--output', default='markov.dat')
        parser.add_option(
            '--dump', action='store_true', default=False,
            help='print the normalized transition table after saving '
                 '(debugging aid; the output can be very large)')
        utils.read_config_defaults(
            parser, os.getenv('DJRANDOM_CONF', '/etc/djrandom.conf'))
        opts, args = parser.parse_args()
        # parser.error() prints usage and exits the process.
        if not opts.db_url:
            parser.error('Must provide --db_url')
        if args:
            parser.error('Too many arguments')

        init_db(opts.db_url)

        markov_model = MarkovModel()
        markov_model.create(PlayLog.generate_tuples())
        markov_model.normalize()
        markov_model.save(opts.output)

        # The full table has one entry per observed context, so dump it only
        # on request instead of on every run.
        if opts.dump:
            from pprint import pprint
            pprint(markov_model._map)


    if __name__ == '__main__':
        main()