Lab 2 - Symbolic search#
Imports#
import re
import itertools
import math
import time
import random
from collections import Counter
from pprint import pprint
import numpy as np
import pandas as pd
from numpy.random import choice, randint
from IPython.display import HTML, display, clear_output
import matplotlib.pyplot as plt
import ipywidgets as widgets
from symbolic_utilities import define_bs_DSL, define_lt_DSL
# Utilities for plotting
from symbolic_utilities import \
progress, compute_global_limits_mh, compute_global_limits_smc, plot_mh_trace_upto, plot_state_2d
# Utilities for enumeration
from symbolic_utilities import enumerate_full_sentences, enumerate_trees, enumerate_full_sentences_bottomup, \
enumerate_trees_bottomup, compute_likelihood_bs, compute_likelihood_lt
# Various utils for manipulating trees
from symbolic_utilities import \
generate_tree, tree_to_sentence, get_nonterminal_nodes, set_subtree, mutate_tree, \
compute_tree_probability, compute_unnormalized_posterior
# MHMC sampler
from symbolic_utilities import propose_tree, get_coordinates, mh_sampler
from symbolic_utilities import smc_sampler
DSL I: binary strings (conceptually easy peasy)#
# Instantiate the binary-string DSL: PCFG, nonterminal/terminal sets, and the
# evaluation environment mapping DSL symbols to Python callables.
bsgrammar, BS_NONTERMINALS, BS_TERMINALS, bs_eval_dict = define_bs_DSL()
# Suppose our grammar produced the string expression below:
expr_str = "C(T(1),R(S(I(0,B(1)))))"
# Evaluate the expression to obtain a binary string.
# NOTE: eval() is safe here only because expr_str is a trusted literal.
string = eval(expr_str, bs_eval_dict)
# Fix: label the printed value correctly -- it is the *evaluated string*,
# not the expression itself (the expression is in expr_str).
print("Generated string:", string)
Generated string: 111011
To get a sense of the variety of strings you can get with this grammar, you can use this function that samples a tree from the PCFG:
# Sample one derivation tree from the PCFG and show the string it denotes.
sampled_tree = generate_tree("S", bsgrammar)
sampled_expr = "".join(tree_to_sentence(sampled_tree))
print("Random sentence:", sampled_expr)
print("String: ", eval(sampled_expr, bs_eval_dict))
Random sentence: N(0)
String: 1
DSL II: list transformations (conceptually not so trivial!)#
# Instantiate the list-transformation DSL.
ltgrammar, LT_NONTERMINALS, LT_TERMINALS, lt_eval_dict = define_lt_DSL()
# Suppose our grammar produced the string expression below:
expr_str = "compose(map_(minus(2)), filter_(and_(even, gt(2))))"
# Evaluating the expression yields a Python function over integer lists.
transformation_fn = eval(expr_str, lt_eval_dict)
# Run the transformation on a concrete example list.
sample_list = [3, 1, 4, 1, 5, 9, 2, 6]
result = transformation_fn(sample_list)
for label, value in [
    ("Generated expression:", expr_str),
    ("Input list:", sample_list),
    ("Result after transformation:", result),
]:
    print(label, value)
Generated expression: compose(map_(minus(2)), filter_(and_(even, gt(2))))
Input list: [3, 1, 4, 1, 5, 9, 2, 6]
Result after transformation: [4]
# Draw a random program from the list-transformation grammar and print it.
expression = "".join(tree_to_sentence(generate_tree("T", ltgrammar)))
print("Random transformation:", expression)
Random transformation: sort
Enumeration#
Let’s start looking at the search algorithms. We first produce some data that we can fit with our search strategies:
# Ground-truth program: subtract 2 from each element, then keep results > 2.
expr_str = "compose(map_(minus(2)), filter_(gt(2)))"
transformation_fn = eval(expr_str, lt_eval_dict)
# A few input lists; pairing each with its transformed output gives the dataset.
inputs = [
    [2, 3, 5, 6, 1],
    [6, 2, 3, 1],
    [9, 1, 8, 3, 5, 2],
]
data = []
for inp in inputs:
    data.append((inp, transformation_fn(inp)))
data
[([2, 3, 5, 6, 1], [3, 4]),
([6, 2, 3, 1], [4]),
([9, 1, 8, 3, 5, 2], [7, 6, 3])]
First, let’s look at simple top-down enumeration with the binary string grammar:
%%time
# Top-down enumeration: walk every derivation tree of the binary-string
# grammar up to depth 4 and score each against the target string "001001".
sentences = dict()
# Enumerate sentences from the start symbol "S" with a depth limit.
for candidate in enumerate_trees("S", bsgrammar, max_depth=4):
    rendered = "".join(tree_to_sentence(candidate))
    sentences[eval(rendered, bs_eval_dict)] = compute_unnormalized_posterior(
        candidate, bsgrammar, "001001", bs_eval_dict, compute_likelihood_bs
    )
# Print the most probable solutions
sorted(sentences.items(), key=lambda kv: -kv[1])[:10]
CPU times: user 3.05 s, sys: 0 ns, total: 3.05 s
Wall time: 3.05 s
[('001001', 3.012736478083201e-06),
(0, 2e-07),
(1, 2e-07),
('000000', 1.9211920200000038e-08),
('011011', 7.684768080000018e-10),
('101101', 3.0739072320000074e-10),
('00', 2.0000000000000006e-10),
('000', 2.0000000000000006e-10),
('0000', 2.0000000000000006e-10),
('11', 2.0000000000000006e-10)]
And now bottom-up enumeration:
%%time
# Bottom-up enumeration: assemble trees level by level (max 4 levels) and
# score each candidate against "001001" under a strict, noise-free likelihood.
sentences = dict()
noise_free = {
    # 1 - probability of noise flipping a bit
    'match_prob': 1.0,
    # probability of any string whose length != observed
    'length_mismatch_prob': -np.inf,
}
for candidate in enumerate_trees_bottomup("S", bsgrammar, max_level=4):
    rendered = "".join(tree_to_sentence(candidate))
    sentences[(rendered, eval(rendered, bs_eval_dict))] = compute_unnormalized_posterior(
        candidate,
        bsgrammar,
        "001001",
        bs_eval_dict,
        compute_likelihood_bs,
        lik_params=noise_free,
    )
sorted(sentences.items(), key=lambda kv: -kv[1])[:10]
CPU times: user 4.9 s, sys: 24.7 ms, total: 4.93 s
Wall time: 4.93 s
[(('D(C(D(0),1))', '001001'), 8.000000000000003e-05),
(('D(C(0,C(0,1)))', '001001'), 3.200000000000001e-05),
(('D(C(C(0,0),1))', '001001'), 3.200000000000001e-05),
(('D(C(D(0),R(1)))', '001001'), 8.000000000000005e-06),
(('D(C(D(0),N(0)))', '001001'), 8.000000000000005e-06),
(('D(C(C(0,0),R(1)))', '001001'), 3.2000000000000015e-06),
(('D(C(C(0,0),N(0)))', '001001'), 3.2000000000000015e-06),
(('D(C(R(0),C(0,1)))', '001001'), 3.2000000000000015e-06),
(('D(C(N(1),C(0,1)))', '001001'), 3.2000000000000015e-06),
(('C(C(D(0),1),C(D(0),1))', '001001'), 1.280000000000001e-07)]
Bottom-up is quite a bit slower!
Now we can apply our little trick of progressively removing expressions that are synonymous with previous found ones, and see if it speeds up the search:
%%time
# Same bottom-up search, now pruning semantic duplicates as we go: two trees
# whose expressions evaluate to equal values are treated as synonyms.
sentences = dict()
for candidate in enumerate_trees_bottomup(
        "S", bsgrammar, max_level=4,
        eval_env=bs_eval_dict, are_same=lambda x, y: x == y):
    rendered = "".join(tree_to_sentence(candidate))
    sentences[(rendered, eval(rendered, bs_eval_dict))] = compute_unnormalized_posterior(
        candidate,
        bsgrammar,
        "001001",
        bs_eval_dict,
        compute_likelihood_bs,
        lik_params={
            'match_prob': 1.0,
            'length_mismatch_prob': -np.inf,
        },
    )
sorted(sentences.items(), key=lambda kv: -kv[1])[:10]
CPU times: user 259 ms, sys: 0 ns, total: 259 ms
Wall time: 257 ms
[(('C(C(0,0),C(C(1,0),C(0,1)))', '001001'), 2.0480000000000015e-08),
(('C(T(0),T(0))', '000000'), 0.0),
(('C(T(0),T(1))', '000111'), 0.0),
(('C(T(1),T(0))', '111000'), 0.0),
(('C(T(1),T(1))', '111111'), 0.0),
(('T(C(0,1))', '010101'), 0.0),
(('T(C(1,0))', '101010'), 0.0),
(('C(0,C(C(0,1),T(0)))', '001000'), 0.0),
(('C(0,C(C(0,1),T(1)))', '001111'), 0.0),
(('C(0,C(C(1,0),T(0)))', '010000'), 0.0)]
Woah, that was a lot faster!
Now suppose that you got a bit cocky and you wanted to induce a formula for a more complex expression:
%%time
# Deduplicated bottom-up search again, but with a longer target string
# ("00100100010") that no depth-4 tree manages to reproduce exactly.
sentences = dict()
for candidate in enumerate_trees_bottomup(
        "S", bsgrammar, max_level=4,
        eval_env=bs_eval_dict, are_same=lambda x, y: x == y):
    rendered = "".join(tree_to_sentence(candidate))
    sentences[(rendered, eval(rendered, bs_eval_dict))] = compute_unnormalized_posterior(
        candidate,
        bsgrammar,
        "00100100010",
        bs_eval_dict,
        compute_likelihood_bs,
        lik_params={
            'match_prob': 1.0,
            'length_mismatch_prob': -np.inf,
        },
    )
sorted(sentences.items(), key=lambda kv: -kv[1])[:10]
CPU times: user 270 ms, sys: 0 ns, total: 270 ms
Wall time: 268 ms
[(('C(C(0,0),T(T(0)))', '00000000000'), 0.0),
(('C(C(0,0),T(T(1)))', '00111111111'), 0.0),
(('C(C(0,1),T(T(0)))', '01000000000'), 0.0),
(('C(C(0,1),T(T(1)))', '01111111111'), 0.0),
(('C(C(1,0),T(T(0)))', '10000000000'), 0.0),
(('C(C(1,0),T(T(1)))', '10111111111'), 0.0),
(('C(C(1,1),T(T(0)))', '11000000000'), 0.0),
(('C(C(1,1),T(T(1)))', '11111111111'), 0.0),
(('C(C(C(0,0),T(0)),C(T(0),T(1)))', '00000000111'), 0.0),
(('C(C(C(0,0),T(0)),C(T(1),T(0)))', '00000111000'), 0.0)]
You might think “Uhm, we need to explore deeper than 4 levels of depth to find at least one solution” - this is not so easy. Try it out!
Now let’s look at the top-down enumeration strategy for the more complicated domain of list transformation:
%%time
# Top-down enumeration over the list-transformation grammar, scored on `data`.
sentences = dict()
for candidate in enumerate_trees("T", ltgrammar, max_depth=4):
    key = "".join(tree_to_sentence(candidate))
    sentences[key] = compute_unnormalized_posterior(
        candidate, ltgrammar, data, lt_eval_dict, compute_likelihood_lt
    )
# order by unnormalized posterior
sorted(sentences.items(), key=lambda kv: -kv[1])[:10]
CPU times: user 22.7 ms, sys: 9.81 ms, total: 32.5 ms
Wall time: 31.4 ms
[('compose(map_(minus(2)),filter_(gt(1)))', 5.648880896406001e-05),
('compose(map_(minus(2)),filter_(gt(2)))', 5.648880896406001e-05),
('compose(filter_(gt(3)),map_(minus(2)))', 5.648880896406001e-05),
('compose(filter_(gt(4)),map_(minus(2)))', 5.648880896406001e-05),
('compose(truncate(3),filter_(not_(even)))', 1.9800000000000042e-14),
('filter_(gt(3))', 6.000000000000033e-15),
('filter_(gt(4))', 6.000000000000033e-15),
('compose(filter_(gt(1)),filter_(not_(even)))', 5.940000000000029e-15),
('compose(filter_(gt(2)),filter_(not_(even)))', 5.940000000000029e-15),
('compose(filter_(not_(even)),filter_(gt(1)))', 5.940000000000029e-15)]
And bottom-up:
%%time
# Bottom-up enumeration of list transformations (4 levels), same scoring.
sentences = dict()
for candidate in enumerate_trees_bottomup("T", ltgrammar, max_level=4):
    key = "".join(tree_to_sentence(candidate))
    sentences[key] = compute_unnormalized_posterior(
        candidate, ltgrammar, data, lt_eval_dict, compute_likelihood_lt
    )
sorted(sentences.items(), key=lambda kv: -kv[1])[:10]
CPU times: user 324 ms, sys: 0 ns, total: 324 ms
Wall time: 323 ms
[('compose(map_(minus(2)),filter_(gt(1)))', 5.648880896406001e-05),
('compose(map_(minus(2)),filter_(gt(2)))', 5.648880896406001e-05),
('compose(filter_(gt(3)),map_(minus(2)))', 5.648880896406001e-05),
('compose(filter_(gt(4)),map_(minus(2)))', 5.648880896406001e-05),
('compose(truncate(3),filter_(not_(even)))', 1.9800000000000042e-14),
('filter_(gt(3))', 6.000000000000033e-15),
('filter_(gt(4))', 6.000000000000033e-15),
('compose(filter_(gt(1)),filter_(not_(even)))', 5.940000000000029e-15),
('compose(filter_(gt(2)),filter_(not_(even)))', 5.940000000000029e-15),
('compose(filter_(not_(even)),filter_(gt(1)))', 5.940000000000029e-15)]
Unfortunately, there is no easy way to check whether two list transformation functions are equivalent, so we cannot apply the little trick of successive synonym pruning.
Metropolis-Hastings#
Now we can run the MH algorithm, with subtree-regeneration as our proposal function.
# Run the Metropolis-Hastings sampler on the binary-string grammar,
# targeting the observation "1110010".
num_iterations = 5000
mh_trace = mh_sampler(
    bsgrammar,
    # a datapoint to fit
    "1110010",
    eval_dict=bs_eval_dict,
    starting='S',
    num_iterations=num_iterations,
    likelihoodf=compute_likelihood_bs,
)
# Approximate posterior: tally how often the chain visited each formula.
c = Counter("".join(tree_to_sentence(step['current_tree'])) for step in mh_trace)
c.most_common(10)
[('C(T(1),C(0,C(C(0,1),0)))', 466),
('C(T(1),C(0,C(R(C(1,0)),0)))', 250),
('C(T(1),C(D(0),C(1,0)))', 224),
('C(T(1),C(D(0),C(1,R(0))))', 160),
('C(T(1),C(R(0),C(R(C(1,0)),0)))', 107),
('C(T(1),C(D(R(0)),C(1,0)))', 92),
('C(T(1),C(0,C(C(R(0),1),0)))', 84),
('C(N(N(T(1))),C(D(R(R(0))),C(1,N(1))))', 80),
('C(R(T(1)),C(D(0),C(1,0)))', 77),
('C(T(1),C(0,C(C(N(1),1),0)))', 76)]
# Interactive replay of the MH trace: step through iterations with
# forward/backward buttons, a Play widget, and a linked slider.
current_index = 0  # iteration currently displayed (module-level UI state)
global_xlims, global_ylims = compute_global_limits_mh(mh_trace)
# Create interactive buttons.
button_forward = widgets.Button(description="Forward")
button_backward = widgets.Button(description="Backward")
state_label = widgets.Label(value=f"Iteration: {current_index} / {len(mh_trace)-1}")
output_plot = widgets.Output()
def update_plot():
    # Redraw the trace up to `current_index` inside the Output widget,
    # then refresh the textual iteration label.
    with output_plot:
        clear_output(wait=True)
        plot_mh_trace_upto(mh_trace, current_index, global_xlims, global_ylims)
    state_label.value = f"Iteration: {current_index} / {len(mh_trace)-1}"
def on_forward_clicked(b):
    # Advance one iteration (clamped at the end of the trace) and redraw.
    global current_index
    if current_index < len(mh_trace) - 1:
        current_index += 1
    update_plot()
def on_backward_clicked(b):
    # Step back one iteration (clamped at 0) and redraw.
    global current_index
    if current_index > 0:
        current_index -= 1
    update_plot()
# Create a Play widget.
play_widget = widgets.Play(
    value=0,
    min=0,
    max=len(mh_trace)-1,
    step=1,
    # time in milliseconds between steps
    interval=200,
    description="Press play",
    disabled=False
)
# Create a slider (if you want to display it too).
iteration_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(mh_trace)-1,
    step=1,
    description='Iteration:',
    continuous_update=True
)
# Link the play widget to the slider so that they move together.
widgets.jslink((play_widget, 'value'), (iteration_slider, 'value'))
# Now, update your update_plot function to observe changes to the slider:
def on_slider_change(change):
    # Slider (or the linked Play widget) moved: jump to that iteration.
    global current_index
    current_index = change['new']
    update_plot()
iteration_slider.observe(on_slider_change, names='value')
button_forward.on_click(on_forward_clicked)
button_backward.on_click(on_backward_clicked)
# Finally, display the play widget (and slider) alongside your forward/backward buttons.
controls = widgets.HBox([button_backward, play_widget, iteration_slider, button_forward, state_label])
display(controls, output_plot)
update_plot()
Remember our complicated string above? Let’s see if MHMC can manage it!
# MH again, now targeting the longer string "00100100010" with more iterations.
num_iterations = 10000
mh_trace = mh_sampler(
    bsgrammar,
    # a datapoint to fit
    "00100100010",
    eval_dict=bs_eval_dict,
    starting='S',
    num_iterations=num_iterations,
    likelihoodf=compute_likelihood_bs,
)
# Tally how many iterations the chain spent in each formula.
c = Counter("".join(tree_to_sentence(step['current_tree'])) for step in mh_trace)
c.most_common(10)
[('1', 1626),
('0', 1452),
('N(0)', 189),
('N(C(R(N(C(R(C(C(1,C(D(D(0)),D(0))),1)),D(0)))),1))', 180),
('D(1)', 179),
('N(C(R(N(C(R(C(C(1,C(C(0,C(0,R(C(0,1)))),N(C(1,1)))),1)),C(N(1),0)))),1))',
170),
('N(1)', 163),
('D(0)', 158),
('T(1)', 156),
('R(1)', 144)]
The algorithm finds one of the formulas encoding the string (but doesn't spend much time on it, because the formula has a very low prior):
eval('N(C(R(N(C(R(C(C(1,C(C(0,C(0,R(C(0,1)))),N(C(1,1)))),1)),C(N(1),0)))),1))', bs_eval_dict)
'00100100010'
Let’s plot the process:
# Interactive replay for the second MH run; same controls as before:
# forward/backward buttons, a Play widget, and a linked slider.
current_index = 0  # iteration currently displayed (module-level UI state)
global_xlims, global_ylims = compute_global_limits_mh(mh_trace)
# Create interactive buttons.
button_forward = widgets.Button(description="Forward")
button_backward = widgets.Button(description="Backward")
state_label = widgets.Label(value=f"Iteration: {current_index} / {len(mh_trace)-1}")
output_plot = widgets.Output()
def update_plot():
    # Redraw the trace up to `current_index`, then refresh the label.
    with output_plot:
        clear_output(wait=True)
        plot_mh_trace_upto(mh_trace, current_index, global_xlims, global_ylims)
    state_label.value = f"Iteration: {current_index} / {len(mh_trace)-1}"
def on_forward_clicked(b):
    # Advance one iteration (clamped at the end of the trace) and redraw.
    global current_index
    if current_index < len(mh_trace) - 1:
        current_index += 1
    update_plot()
def on_backward_clicked(b):
    # Step back one iteration (clamped at 0) and redraw.
    global current_index
    if current_index > 0:
        current_index -= 1
    update_plot()
# Create a Play widget.
play_widget = widgets.Play(
    value=0,
    min=0,
    max=len(mh_trace)-1,
    step=1,
    # time in milliseconds between steps
    interval=200,
    description="Press play",
    disabled=False
)
# Create a slider (if you want to display it too).
iteration_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(mh_trace)-1,
    step=1,
    description='Iteration:',
    continuous_update=True
)
# Link the play widget to the slider so that they move together.
widgets.jslink((play_widget, 'value'), (iteration_slider, 'value'))
# Now, update your update_plot function to observe changes to the slider:
def on_slider_change(change):
    # Slider (or the linked Play widget) moved: jump to that iteration.
    global current_index
    current_index = change['new']
    update_plot()
iteration_slider.observe(on_slider_change, names='value')
button_forward.on_click(on_forward_clicked)
button_backward.on_click(on_backward_clicked)
# Finally, display the play widget (and slider) alongside your forward/backward buttons.
controls = widgets.HBox([button_backward, play_widget, iteration_slider, button_forward, state_label])
display(controls, output_plot)
update_plot()
Now we can define some data with the transformation DSL to fit:
# Rebuild the ground-truth dataset for the transformation DSL
# (same target program as before).
expr_str = "compose(map_(minus(2)), filter_(gt(2)))"
transformation_fn = eval(expr_str, lt_eval_dict)
inputs = [
    [2, 3, 5, 6, 1],
    [6, 2, 3, 1],
    [9, 1, 8, 3, 5, 2],
]
ltdata = []
for inp in inputs:
    ltdata.append((inp, transformation_fn(inp)))
Let’s run the MH algorithm, with subtree-regeneration as our proposal function.
# Run MH with subtree-regeneration proposals on the list-transformation DSL,
# fitting the (input, output) pairs in `ltdata`.
num_iterations = 5000
mh_trace = mh_sampler(
    ltgrammar,
    ltdata,
    starting='T',
    eval_dict=lt_eval_dict,
    num_iterations=num_iterations,
    likelihoodf=compute_likelihood_lt,
)
And plot:
# Interactive replay for the list-transformation MH run; same controls:
# forward/backward buttons, a Play widget, and a linked slider.
current_index = 0  # iteration currently displayed (module-level UI state)
global_xlims, global_ylims = compute_global_limits_mh(mh_trace)
# Create interactive buttons.
button_forward = widgets.Button(description="Forward")
button_backward = widgets.Button(description="Backward")
state_label = widgets.Label(value=f"Iteration: {current_index} / {len(mh_trace)-1}")
output_plot = widgets.Output()
def update_plot():
    # Redraw the trace up to `current_index`, then refresh the label.
    with output_plot:
        clear_output(wait=True)
        plot_mh_trace_upto(mh_trace, current_index, global_xlims, global_ylims)
    state_label.value = f"Iteration: {current_index} / {len(mh_trace)-1}"
def on_forward_clicked(b):
    # Advance one iteration (clamped at the end of the trace) and redraw.
    global current_index
    if current_index < len(mh_trace) - 1:
        current_index += 1
    update_plot()
def on_backward_clicked(b):
    # Step back one iteration (clamped at 0) and redraw.
    global current_index
    if current_index > 0:
        current_index -= 1
    update_plot()
# Create a Play widget.
play_widget = widgets.Play(
    value=0,
    min=0,
    max=len(mh_trace)-1,
    step=1,
    # time in milliseconds between steps
    interval=200,
    description="Press play",
    disabled=False
)
# Create a slider (if you want to display it too).
iteration_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(mh_trace)-1,
    step=1,
    description='Iteration:',
    continuous_update=True
)
# Link the play widget to the slider so that they move together.
widgets.jslink((play_widget, 'value'), (iteration_slider, 'value'))
# Now, update your update_plot function to observe changes to the slider:
def on_slider_change(change):
    # Slider (or the linked Play widget) moved: jump to that iteration.
    global current_index
    current_index = change['new']
    update_plot()
iteration_slider.observe(on_slider_change, names='value')
button_forward.on_click(on_forward_clicked)
button_backward.on_click(on_backward_clicked)
# Finally, display the play widget (and slider) alongside your forward/backward buttons.
controls = widgets.HBox([button_backward, play_widget, iteration_slider, button_forward, state_label])
display(controls, output_plot)
update_plot()
Let’s see the approximated posterior, i.e., the proportion of steps that the chain spent in each formula:
# Approximate posterior over formulas: the fraction of MH steps spent in each.
c = Counter("".join(tree_to_sentence(step['current_tree'])) for step in mh_trace)
c.most_common(10)
[('compose(map_(minus(2)),filter_(gt(1)))', 2448),
('compose(map_(minus(2)),filter_(gt(2)))', 2104),
('truncate(1)', 114),
('compose(map_(minus(2)),filter_(not_(not_(gt(2)))))', 92),
('compose(map_(minus(2)),filter_(not_(not_(gt(1)))))', 72),
('compose(filter_(gt(3)),reverse)', 23),
('compose(filter_(gt(3)),truncate(4))', 13),
('compose(reverse,filter_(even))', 10),
('compose(truncate(2),reverse)', 10),
('compose(filter_(gt(3)),map_(times(2)))', 10)]
Sequential Monte Carlo#
Next, we can have a look at the SMC algorithm, where a cloud of particles explores a space with a sequence of mutations, reweighting, and sampling steps. Let’s first define our data again to keep things local:
# Define the target data once more so this section is self-contained.
expr_str = "compose(map_(minus(2)), filter_(gt(2)))"
transformation_fn = eval(expr_str, lt_eval_dict)
inputs = [
    [2, 3, 5, 6, 1],
    [6, 2, 3, 1],
    [9, 1, 8, 3, 5, 2],
]
data = []
for inp in inputs:
    data.append((inp, transformation_fn(inp)))
Next, we run the algorithm:
# Sequential Monte Carlo: evolve a population of particles over program
# space via repeated mutate / reweight / resample steps.
smc_config = dict(
    num_particles=1000,
    num_iterations=500,
    resample_prop=0.5,
    likelihoodf=compute_likelihood_lt,
)
smc_states = smc_sampler(ltgrammar, data, 'T', lt_eval_dict, **smc_config)
Let’s see how many particles are on each formula at the final step:
# Count how many particles sit on each formula at the final SMC step.
c = Counter(''.join(p['sentence']) for p in smc_states[-1]['particles'])
c.most_common(10)
[('compose(filter_(even),map_(minus(2)))', 61),
('compose(filter_(gt(3)),map_(minus(2)))', 38),
('compose(filter_(gt(4)),map_(minus(2)))', 29),
('compose(reverse,map_(minus(2)))', 25),
('compose(filter_(gt(5)),map_(minus(2)))', 24),
('compose(filter_(gt(1)),map_(minus(2)))', 23),
('reverse', 22),
('compose(sort,map_(minus(2)))', 20),
('compose(filter_(gt(2)),map_(minus(2)))', 19),
('compose(filter_(gt(4)),map_(minus(3)))', 18)]
And let’s visualize it!
# Interactive replay of the SMC run: a slider/Play widget steps through the
# recorded particle states; forward/backward buttons drive the slider.
#
# BUG FIX: the original line unpacked both return values into the same name
# (`global_lims, global_lims = ...`), discarding the x-limits and leaving
# plot_state_2d below to use the stale `global_xlims`/`global_ylims` computed
# for the earlier MH section. Unpack into the two names actually used.
global_xlims, global_ylims = compute_global_limits_smc(
    smc_states, ltgrammar, data, lt_eval_dict, likelihoodf=compute_likelihood_lt
)
# Create a slider to select the state index.
state_slider = widgets.IntSlider(value=0, min=0, max=len(smc_states)-1, step=1, description='State:')
# Create a Play widget (interval in ms)
play_widget = widgets.Play(value=0, min=0, max=len(smc_states)-1, step=1, interval=300, description="Press play")
# Link the Play widget to the slider.
widgets.jslink((play_widget, 'value'), (state_slider, 'value'))
# Create Forward/Backward buttons.
button_forward = widgets.Button(description="Forward")
button_backward = widgets.Button(description="Backward")
state_label = widgets.Label(value=f"State: 0 / {len(smc_states)-1}")
output_plot = widgets.Output()
# Global index
current_state_index = 0
def update_plot(index):
    # Render the particle cloud of state `index` into the Output widget
    # and refresh the textual state label.
    with output_plot:
        clear_output(wait=True)
        state = smc_states[index]
        plot_state_2d(state, ltgrammar, data, global_xlims, global_ylims, lt_eval_dict, compute_likelihood_lt)
    state_label.value = f"State: {index} / {len(smc_states)-1}"
# Update when slider value changes.
def on_slider_change(change):
    global current_state_index
    current_state_index = change['new']
    update_plot(current_state_index)
state_slider.observe(on_slider_change, names='value')
def on_forward_clicked(b):
    # Advance by writing to the slider so its observer triggers the redraw.
    global current_state_index
    if current_state_index < len(smc_states)-1:
        current_state_index += 1
        state_slider.value = current_state_index
def on_backward_clicked(b):
    # Step back via the slider, same mechanism as forward.
    global current_state_index
    if current_state_index > 0:
        current_state_index -= 1
        state_slider.value = current_state_index
button_forward.on_click(on_forward_clicked)
button_backward.on_click(on_backward_clicked)
controls = widgets.HBox([button_backward, play_widget, state_slider, button_forward, state_label])
display(controls, output_plot)
update_plot(0)