# -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals from nltk.tag import hmm def _wikipedia_example_hmm(): # Example from wikipedia # (http://en.wikipedia.org/wiki/Forward%E2%80%93backward_algorithm) states = ['rain', 'no rain'] symbols = ['umbrella', 'no umbrella'] A = [[0.7, 0.3], [0.3, 0.7]] # transition probabilities B = [[0.9, 0.1], [0.2, 0.8]] # emission probabilities pi = [0.5, 0.5] # initial probabilities seq = ['umbrella', 'umbrella', 'no umbrella', 'umbrella', 'umbrella'] seq = list(zip(seq, [None] * len(seq))) model = hmm._create_hmm_tagger(states, symbols, A, B, pi) return model, states, symbols, seq def test_forward_probability(): from numpy.testing import assert_array_almost_equal # example from p. 385, Huang et al model, states, symbols = hmm._market_hmm_example() seq = [('up', None), ('up', None)] expected = [[0.35, 0.02, 0.09], [0.1792, 0.0085, 0.0357]] fp = 2 ** model._forward_probability(seq) assert_array_almost_equal(fp, expected) def test_forward_probability2(): from numpy.testing import assert_array_almost_equal model, states, symbols, seq = _wikipedia_example_hmm() fp = 2 ** model._forward_probability(seq) # examples in wikipedia are normalized fp = (fp.T / fp.sum(axis=1)).T wikipedia_results = [ [0.8182, 0.1818], [0.8834, 0.1166], [0.1907, 0.8093], [0.7308, 0.2692], [0.8673, 0.1327], ] assert_array_almost_equal(wikipedia_results, fp, 4) def test_backward_probability(): from numpy.testing import assert_array_almost_equal model, states, symbols, seq = _wikipedia_example_hmm() bp = 2 ** model._backward_probability(seq) # examples in wikipedia are normalized bp = (bp.T / bp.sum(axis=1)).T wikipedia_results = [ # Forward-backward algorithm doesn't need b0_5, # so .backward_probability doesn't compute it. # [0.6469, 0.3531], [0.5923, 0.4077], [0.3763, 0.6237], [0.6533, 0.3467], [0.6273, 0.3727], [0.5, 0.5], ] assert_array_almost_equal(wikipedia_results, bp, 4) def setup_module(module): from nose import SkipTest try: import numpy except ImportError: raise SkipTest("numpy is required for nltk.test.test_hmm")