diff --git a/server/djrandom/model/markov.py b/server/djrandom/model/markov.py index bc956019a73ffac536688ced1e0a8ee4571d26cf..8bc7ff7ada7f2e0001f80536eeb926df44f36fa8 100644 --- a/server/djrandom/model/markov.py +++ b/server/djrandom/model/markov.py @@ -14,18 +14,20 @@ log = logging.getLogger(__name__) class MarkovModel(object): def __init__(self): - self._hash2i = {} - self._i2hash = [] + self._hash2i = {'nil': 0} + self._i2hash = ['nil'] self._map = {} self._rnd = random.Random() def _to_i(self, sha1): + if sha1 is None: + sha1 = 'nil' if sha1 not in self._hash2i: n = len(self._i2hash) self._i2hash.append(sha1) self._hash2i[sha1] = n else: - n = self._i2hash[sha1] + n = self._hash2i[sha1] return n def save(self, filename): @@ -73,6 +75,7 @@ class MarkovModel(object): def main(): parser = optparse.OptionParser() parser.add_option('--db_url') + parser.add_option('--output', default='markov.dat') utils.read_config_defaults( parser, os.getenv('DJRANDOM_CONF', '/etc/djrandom.conf')) opts, args = parser.parse_args() @@ -86,7 +89,10 @@ def main(): markov_model = MarkovModel() markov_model.create(PlayLog.generate_tuples()) markov_model.normalize() - markov_model.save('markov.dat') + markov_model.save(opts.output) + + from pprint import pprint + pprint(markov_model._map) if __name__ == '__main__':