diff --git a/test_al.py b/sandbox/run_al_fractions.py
similarity index 100%
rename from test_al.py
rename to sandbox/run_al_fractions.py
diff --git a/test_ppo.py b/sandbox/run_ppo_fractions.py
similarity index 84%
rename from test_ppo.py
rename to sandbox/run_ppo_fractions.py
index 402ab8ab8961d48a180c7401293e16cd367f6b43..ab52f3cf1173fa1d4464b7534751e893e2b46246 100644
--- a/test_ppo.py
+++ b/sandbox/run_ppo_fractions.py
@@ -1,7 +1,8 @@
import gym
-from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common import make_vec_env
+from stable_baselines.common.policies import MlpPolicy
from stable_baselines import PPO2
+from stable_baselines import SAC
import tutorenvs
import numpy as np
@@ -10,14 +11,12 @@ if __name__ == "__main__":
# multiprocess environment
env = make_vec_env('FractionArith-v1', n_envs=8)
-
- model = PPO2(MlpPolicy, env, verbose=0,
+ model = PPO2(MlpPolicy, env, verbose=1,
gamma=0.5,
tensorboard_log="./ppo_FractionArith-v0/")
-
while True:
- model.learn(total_timesteps=999999999)
+ model.learn(total_timesteps=9999999999)
# model.save("ppo2_cartpole")
# del model # remove to demonstrate saving and loading
@@ -31,5 +30,5 @@ if __name__ == "__main__":
# action, _states = model.predict(obs)
# obs, rewards, dones, info = env.step(action)
# rwd += np.sum(rewards)
- # # env.render()
+ # env.render()
# print(rwd)
diff --git a/sandbox/run_ppo_multicolumn.py b/sandbox/run_ppo_multicolumn.py
new file mode 100644
index 0000000000000000000000000000000000000000..31cea306bd386d789965444c8f3181d010e38116
--- /dev/null
+++ b/sandbox/run_ppo_multicolumn.py
@@ -0,0 +1,35 @@
+import gym
+from stable_baselines.common import make_vec_env
+from stable_baselines.common.policies import MlpPolicy
+from stable_baselines import PPO2
+from stable_baselines import SAC
+import tutorenvs
+import numpy as np
+
+
+if __name__ == "__main__":
+
+ # multiprocess environment
+ env = make_vec_env('MultiColumnArith-v0', n_envs=8)
+ model = PPO2(MlpPolicy, env, verbose=1,
+ gamma=0.5,
+ policy_kwargs={'net_arch': [65, 65, {'vf': [65], 'pi': [65]}]},
+ tensorboard_log="./ppo_MultiColumnArith-v0/")
+
+ while True:
+ model.learn(total_timesteps=9999999999)
+ # model.save("ppo2_cartpole")
+
+ # del model # remove to demonstrate saving and loading
+
+ # model = PPO2.load("ppo2_cartpole")
+
+ # Enjoy trained agent
+ # obs = env.reset()
+ # rwd = 0
+ # for _ in range(100):
+ # action, _states = model.predict(obs)
+ # obs, rewards, dones, info = env.step(action)
+ # rwd += np.sum(rewards)
+ # env.render()
+ # print(rwd)
diff --git a/tutorenvs/__init__.py b/tutorenvs/__init__.py
index a5da67f150c672ad1dfa2949fe9896b3d136c9df..e89834c773ba3a1e01fd593df8a7900127b537c5 100644
--- a/tutorenvs/__init__.py
+++ b/tutorenvs/__init__.py
@@ -1,6 +1,8 @@
from gym.envs.registration import register
from tutorenvs.fractions import FractionArithDigitsEnv
from tutorenvs.fractions import FractionArithOppEnv
+from tutorenvs.multicolumn import MultiColumnAdditionOppEnv
+from tutorenvs.multicolumn import MultiColumnAdditionDigitsEnv
register(
id='FractionArith-v0',
@@ -11,3 +13,13 @@ register(
id='FractionArith-v1',
entry_point='tutorenvs:FractionArithOppEnv',
)
+
+register(
+ id='MultiColumnArith-v0',
+ entry_point='tutorenvs:MultiColumnAdditionDigitsEnv',
+)
+
+register(
+ id='MultiColumnArith-v1',
+ entry_point='tutorenvs:MultiColumnAdditionOppEnv',
+)
diff --git a/tutorenvs/multicolumn.py b/tutorenvs/multicolumn.py
new file mode 100644
index 0000000000000000000000000000000000000000..9649b5a9c969301054b350a3c3973c4be5a96da3
--- /dev/null
+++ b/tutorenvs/multicolumn.py
@@ -0,0 +1,475 @@
+from random import randint
+from random import choice
+from pprint import pprint
+
+import gym
+from gym import error, spaces, utils
+from gym.utils import seeding
+from sklearn.feature_extraction import FeatureHasher
+from sklearn.feature_extraction import DictVectorizer
+import numpy as np
+
+from tutorenvs.utils import BaseOppEnv
+
+def custom_add(a, b):
+ if a == '':
+ a = '0'
+ if b == '':
+ b = '0'
+ return str(int(a) + int(b))
+
+class MultiColumnAdditionSymbolic:
+
+ def __init__(self):
+ """
+ Creates a state and sets a random problem.
+ """
+ self.set_random_problem()
+ # self.reset("", "", "", "", "")
+
+ def reset(self, upper, lower):
+ """
+ Sets the state to a new fraction arithmetic problem as specified by the
+ provided arguments.
+ """
+ correct_answer = str(int(upper) + int(lower))
+ self.correct_thousands = ""
+ self.correct_hundreds = ""
+ self.correct_tens = ""
+ self.correct_ones = ""
+
+ if len(correct_answer) == 4:
+ self.correct_thousands = correct_answer[0]
+ self.correct_hundreds = correct_answer[1]
+ self.correct_tens = correct_answer[2]
+ self.correct_ones = correct_answer[3]
+ elif len(correct_answer) == 3:
+ self.correct_hundreds = correct_answer[0]
+ self.correct_tens = correct_answer[1]
+ self.correct_ones = correct_answer[2]
+ elif len(correct_answer) == 2:
+ self.correct_tens = correct_answer[0]
+ self.correct_ones = correct_answer[1]
+ elif len(correct_answer) == 1:
+ self.correct_ones = correct_answer[0]
+ else:
+ raise ValueError("Something is wrong, correct answer should have 1-4 digits")
+
+ upper_hundreds = ''
+ upper_tens = ''
+ upper_ones = ''
+
+ if len(upper) == 3:
+ upper_hundreds = upper[0]
+ upper_tens = upper[1]
+ upper_ones = upper[2]
+ if len(upper) == 2:
+ upper_tens = upper[0]
+ upper_ones = upper[1]
+ if len(upper) == 1:
+ upper_ones = upper[0]
+
+ lower_hundreds = ''
+ lower_tens = ''
+ lower_ones = ''
+
+ if len(lower) == 3:
+ lower_hundreds = lower[0]
+ lower_tens = lower[1]
+ lower_ones = lower[2]
+ if len(lower) == 2:
+ lower_tens = lower[0]
+ lower_ones = lower[1]
+ if len(lower) == 1:
+ lower_ones = lower[0]
+
+ self.steps = 0
+ self.state = {
+ 'hundreds_carry': '',
+ 'tens_carry': '',
+ 'ones_carry': '',
+ 'upper_hundreds': upper_hundreds,
+ 'upper_tens': upper_tens,
+ 'upper_ones': upper_ones,
+ 'lower_hundreds': lower_hundreds,
+ 'lower_tens': lower_tens,
+ 'lower_ones': lower_ones,
+ 'operator': '+',
+ 'answer_thousands': '',
+ 'answer_hundreds': '',
+ 'answer_tens': '',
+ 'answer_ones': ''
+ }
+
+ def get_possible_selections(self):
+ return ['hundreds_carry',
+ 'tens_carry',
+ 'ones_carry',
+ 'answer_thousands',
+ 'answer_hundreds',
+ 'answer_tens',
+ 'answer_ones',
+ 'done']
+
+ def get_possible_args(self):
+ return [
+ 'hundreds_carry',
+ 'tens_carry',
+ 'ones_carry',
+ 'upper_hundreds',
+ 'upper_tens',
+ 'upper_ones',
+ 'lower_hundreds',
+ 'lower_tens',
+ 'lower_ones',
+ 'answer_thousands',
+ 'answer_hundreds',
+ 'answer_tens',
+ 'answer_ones',
+ ]
+
+ def render(self):
+ state = {attr: " " if self.state[attr] == '' else self.state[attr] for
+ attr in self.state}
+
+ output = " %s%s%s \n %s%s%s\n+ %s%s%s\n-----\n %s%s%s%s\n" % (
+ state["hundreds_carry"],
+ state["tens_carry"],
+ state["ones_carry"],
+ state["upper_hundreds"],
+ state["upper_tens"],
+ state["upper_ones"],
+ state["lower_hundreds"],
+ state["lower_tens"],
+ state["lower_ones"],
+ state["answer_thousands"],
+ state["answer_hundreds"],
+ state["answer_tens"],
+ state["answer_ones"],
+ )
+
+ print("------------------------------------------------------")
+ print(output)
+ print("------------------------------------------------------")
+ print()
+
+ def get_state(self):
+ """
+ Returns the current state as a dict.
+ """
+ state_output = {attr:
+ {'id': attr, 'value': self.state[attr],
+ 'type': 'TextField',
+ 'contentEditable': self.state[attr] == "",
+ 'dom_class': 'CTATTable--cell',
+ 'above': '',
+ 'below': '',
+ 'to_left': '',
+ 'to_right': ''
+ }
+ for attr in self.state}
+ state_output['done'] = {
+ 'id': 'done',
+ 'type': 'Component',
+ 'dom_class': 'CTATDoneButton',
+ 'above': '',
+ 'below': '',
+ 'to_left': '',
+ 'to_right': ''
+ }
+
+ return state_output
+
+ def set_random_problem(self):
+ upper = str(randint(1,999))
+ lower = str(randint(1,999))
+ self.reset(upper=upper, lower=lower)
+
+ def apply_sai(self, selection, action, inputs):
+ """
+ Give a SAI, it applies it. This method returns feedback (i.e., -1 or 1).
+ """
+ self.steps += 1
+ reward = self.evaluate_sai(selection, action, inputs)
+
+ if reward == -1.0:
+ return reward
+
+ if selection == "done":
+ print("DONE! Only took %i steps." % self.steps)
+ self.render()
+ print()
+ # pprint(self.state)
+ self.set_random_problem()
+
+ else:
+ self.state[selection] = inputs['value']
+
+ return reward
+
+
+
+ def evaluate_sai(self, selection, action, inputs):
+ """
+ Given a SAI, returns whether it is correct or incorrect.
+ """
+ # done step
+ if selection == "done":
+
+ if action != "ButtonPressed":
+ return -1.0
+
+ if (self.state['answer_thousands'] == self.correct_thousands and
+ self.state['answer_hundreds'] == self.correct_hundreds and
+ self.state['answer_tens'] == self.correct_tens and
+ self.state['answer_ones'] == self.correct_ones):
+ return 1.0
+ else:
+ return -1.0
+
+ # we can only edit selections that are editable
+ if self.state[selection] != "":
+ return -1.0
+
+ if (selection == "answer_ones" and
+ inputs['value'] == self.correct_ones):
+ return 1.0
+
+ if (selection == "ones_carry" and
+ len(custom_add(self.state['upper_ones'],
+ self.state['lower_ones'])) == 2 and
+ inputs['value'] == custom_add(self.state['upper_ones'],
+ self.state['lower_ones'])[0]):
+ return 1.0
+
+ if (selection == "answer_tens" and self.state['answer_ones'] != "" and
+ (self.state['ones_carry'] != "" or
+ len(custom_add(self.state['upper_ones'],
+ self.state['lower_ones'])) == 1) and
+ inputs['value'] == self.correct_tens):
+ return 1.0
+
+ if (selection == "tens_carry" and
+ self.state['answer_ones'] != "" and
+ (self.state['ones_carry'] != "" or
+ len(custom_add(self.state['upper_ones'],
+ self.state['lower_ones'])) == 1)):
+
+ if (self.state['ones_carry'] != ""):
+ tens_sum = custom_add(custom_add(self.state['upper_tens'],
+ self.state['lower_tens']), self.state['ones_carry'])
+ else:
+ tens_sum = custom_add(self.state['upper_tens'],
+ self.state['lower_tens'])
+
+ if len(tens_sum) == 2:
+ if inputs['value'] == tens_sum[0]:
+ return 1.0
+
+ if (selection == "answer_hundreds" and
+ self.state['answer_tens'] != "" and
+ (self.state['tens_carry'] != "" or
+ len(custom_add(self.state['upper_tens'],
+ self.state['lower_tens'])) == 1) and
+ inputs['value'] == self.correct_hundreds):
+ return 1.0
+
+ if (selection == "hundreds_carry" and
+ self.state['answer_tens'] != "" and
+ (self.state['tens_carry'] != "" or
+ len(custom_add(self.state['upper_tens'],
+ self.state['lower_tens'])) == 1)):
+
+ if (self.state['tens_carry'] != ""):
+ hundreds_sum = custom_add(custom_add(
+ self.state['upper_hundreds'],
+ self.state['lower_hundreds']),
+ self.state['tens_carry'])
+ else:
+ hundreds_sum = custom_add(
+ self.state['upper_hundreds'],
+ self.state['lower_hundreds'])
+
+ if len(hundreds_sum) == 2:
+ if inputs['value'] == hundreds_sum[0]:
+ return 1.0
+
+ if (selection == "answer_thousands" and
+ self.state['answer_hundreds'] != "" and
+ self.state['hundreds_carry'] != "" and
+ inputs['value'] == self.correct_thousands):
+ return 1.0
+
+ return -1.0
+
+ # TODO still need to rewrite for multi column arith
+ def request_demo(self):
+ """
+ Returns a correct next-step SAI
+ """
+ if (self.state['initial_operator'] == '+' and
+ self.state['initial_denom_left'] == self.state['initial_denom_right']):
+ if self.state['answer_num'] == "":
+ return ('answer_num', "UpdateField",
+ {'value': str(int(self.state['initial_num_left']) +
+ int(self.state['initial_num_right']))})
+
+ if self.state['answer_denom'] == "":
+ return ('answer_denom', "UpdateField",
+ {'value': self.state['initial_denom_left']})
+
+ return ('done', "ButtonPressed", {'value': -1})
+
+ if (self.state['initial_operator'] == "+" and
+ self.state['initial_denom_left'] != self.state['initial_denom_right']):
+
+ if self.state['check_convert'] == "":
+ return ('check_convert', 'UpdateField', {"value": 'x'})
+
+ if self.state['convert_denom_left'] == "":
+ return ('convert_denom_left', "UpdateField",
+ {'value': str(int(self.state['initial_denom_left']) *
+ int(self.state['initial_denom_right']))})
+
+ if self.state['convert_num_left'] == "":
+ return ('convert_num_left', "UpdateField",
+ {'value': str(int(self.state['initial_num_left']) *
+ int(self.state['initial_denom_right']))})
+
+ if self.state['convert_denom_right'] == "":
+ return ('convert_denom_right', "UpdateField",
+ {'value': str(int(self.state['initial_denom_left']) *
+ int(self.state['initial_denom_right']))})
+
+ if self.state['convert_num_right'] == "":
+ return ('convert_num_right', "UpdateField",
+ {'value': str(int(self.state['initial_denom_left']) *
+ int(self.state['initial_num_right']))})
+
+ if self.state['answer_num'] == "":
+ return ('answer_num', "UpdateField",
+ {'value': str(int(self.state['convert_num_left']) +
+ int(self.state['convert_num_right']))})
+
+ if self.state['answer_denom'] == "":
+ return ('answer_denom', "UpdateField",
+ {'value': self.state['convert_denom_right']})
+
+ return ('done', "ButtonPressed", {'value': -1})
+
+ if (self.state['initial_operator'] == "*"):
+ if self.state['answer_num'] == "":
+ return ('answer_num', "UpdateField",
+ {'value': str(int(self.state['initial_num_left']) *
+ int(self.state['initial_num_right']))})
+
+ if self.state['answer_denom'] == "":
+ return ('answer_denom', "UpdateField",
+ {'value': str(int(self.state['initial_denom_left']) *
+ int(self.state['initial_denom_right']))})
+
+ return ('done', "ButtonPressed", {'value': -1})
+
+ raise Exception("request demo - logic missing")
+
+
+class MultiColumnAdditionOppEnv(BaseOppEnv):
+
+ def __init__(self):
+ super().__init__(MultiColumnAdditionSymbolic, max_depth=2)
+
+ def get_rl_operators(self):
+ return [
+ 'copy',
+ 'add',
+ 'mod10',
+ 'div10',
+ ]
+
+class MultiColumnAdditionDigitsEnv(gym.Env):
+ metadata = {'render.modes': ['human']}
+
+ def get_dv_training(self):
+ empty = {attr: '' for attr in self.tutor.state if attr != 'operator'}
+
+ training_data = [empty]
+
+ for i in range(1, 10):
+ s = {attr: str(i) for attr in self.tutor.state if attr != 'operator'}
+ training_data.append(s)
+
+ return training_data
+
+ def get_rl_state(self):
+ # self.state = {
+ # 'hundreds_carry': '',
+ # 'tens_carry': '',
+ # 'ones_carry': '',
+ # 'upper_hundreds': upper_hundreds,
+ # 'upper_tens': upper_tens,
+ # 'upper_ones': upper_ones,
+ # 'lower_hundreds': lower_hundreds,
+ # 'lower_tens': lower_tens,
+ # 'lower_ones': lower_ones,
+ # 'operator': '+',
+ # 'answer_thousands': '',
+ # 'answer_hundreds': '',
+ # 'answer_tens': '',
+ # 'answer_ones': ''
+ # }
+ return self.tutor.state
+
+ def __init__(self):
+ self.tutor = MultiColumnAdditionSymbolic()
+ n_selections = len(self.tutor.get_possible_selections())
+ self.dv = DictVectorizer()
+ transformed_training = self.dv.fit_transform(self.get_dv_training())
+ n_features = transformed_training.shape[1]
+
+ self.observation_space = spaces.Box(low=0.0,
+ high=1.0, shape=(1, n_features), dtype=np.float32)
+ self.action_space = spaces.MultiDiscrete([n_selections, 10])
+
+ def step(self, action):
+ s, a, i = self.decode(action)
+ # print(s, a, i)
+ # print()
+ reward = self.tutor.apply_sai(s, a, i)
+ # print(reward)
+
+ state = self.get_rl_state()
+ # pprint(state)
+ obs = self.dv.transform([state])[0].toarray()
+ done = (s == 'done' and reward == 1.0)
+ info = {}
+
+ return obs, reward, done, info
+
+ def decode(self, action):
+ # print(action)
+ s = self.tutor.get_possible_selections()[action[0]]
+
+ if s == "done":
+ a = "ButtonPressed"
+ else:
+ a = "UpdateField"
+
+ if s == "done":
+ v = -1
+ if s == "check_convert":
+ v = "x"
+ else:
+ v = action[1]
+
+ i = {'value': str(v)}
+
+ return s, a, i
+
+ def reset(self):
+ self.tutor.set_random_problem()
+ state = self.get_rl_state()
+ obs = self.dv.transform([state])[0].toarray()
+ return obs
+
+ def render(self, mode='human', close=False):
+ self.tutor.render()
diff --git a/tutorenvs/utils.py b/tutorenvs/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..42479be76b9a1c3231c858843e5303b84f196e09
--- /dev/null
+++ b/tutorenvs/utils.py
@@ -0,0 +1,227 @@
+from pprint import pprint
+
+import gym
+from gym import error, spaces, utils
+from sklearn.feature_extraction import DictVectorizer
+import numpy as np
+
+class BaseOppEnv(gym.Env):
+ metadata = {'render.modes': ['human']}
+
+ def __init__(self, tutor_class, max_depth=1):
+ print('building env')
+ self.tutor = tutor_class()
+
+ self.max_depth = max_depth
+ self.internal_memory = {}
+
+ self.possible_attr = set(self.tutor.get_possible_args())
+ for _ in range(self.max_depth):
+ new = set()
+ for opp in self.get_rl_operators():
+ for a1 in self.possible_attr:
+ for a2 in self.possible_attr:
+ new.add((opp, a1, a2))
+ self.possible_attr = self.possible_attr.union(new)
+ print('# features = %i' % len(self.possible_attr))
+
+ self.possible_args = list(set([attr[1] if isinstance(attr, tuple) else
+ attr for attr in self.possible_attr]))
+ print('# args = %i' % len(self.possible_args))
+
+ # one additional option to save result internally
+ n_selections = len(self.tutor.get_possible_selections()) + 1
+ print('getting rl state')
+ n_features = len(self.get_rl_state())
+ print('done getting rl state')
+ n_operators = len(self.get_rl_operators())
+ n_args = len(self.possible_args)
+ self.dv = DictVectorizer()
+ self.dv.fit([self.get_rl_state()])
+
+ self.observation_space = spaces.Box(low=0.0,
+ high=1.0, shape=(1, n_features), dtype=np.float32)
+ self.action_space = spaces.MultiDiscrete([n_selections, n_operators,
+ n_args, n_args])
+ print('done')
+
+ def get_rl_operators(self):
+ return [
+ 'copy',
+ 'add',
+ 'multiply',
+ 'mod10',
+ 'div10',
+ ]
+
+ def get_rl_state(self):
+ # self.state = {
+ # 'hundreds_carry': '',
+ # 'tens_carry': '',
+ # 'ones_carry': '',
+ # 'upper_hundreds': upper_hundreds,
+ # 'upper_tens': upper_tens,
+ # 'upper_ones': upper_ones,
+ # 'lower_hundreds': lower_hundreds,
+ # 'lower_tens': lower_tens,
+ # 'lower_ones': lower_ones,
+ # 'operator': '+',
+ # 'answer_thousands': '',
+ # 'answer_hundreds': '',
+ # 'answer_tens': '',
+ # 'answer_ones': ''
+ # }
+
+ state = {}
+ for attr in self.tutor.state:
+
+ # TODO need generic way to handle this.
+ if attr == "operator":
+ continue
+
+ # just whether or not there is a value
+ state[attr] = self.tutor.state[attr] != ""
+
+ # if its in internal memory, then return true, else false.
+ for possible_attr in self.possible_attr:
+ state[possible_attr] = possible_attr in self.internal_memory
+
+ print('done with base attributes in state')
+ print('# of base attributes = %i' % len(state))
+
+ # relations (equality, >10)
+ new_relations = {}
+
+ for attr in state:
+ attr_val = None
+ if attr in self.tutor.state:
+ attr_val = self.tutor.state[attr]
+ elif attr in self.internal_memory:
+ attr_val = self.internal_memory[attr]
+ else:
+ attr_val = ''
+
+ # greater than 9
+ try:
+ new_relations['greater_than_9(%s)' % str(attr)] = float(attr_val) > 9
+ except Exception:
+ new_relations['greater_than_9(%s)' % str(attr)] = False
+
+ # # equality
+ # for attr2 in state:
+ # if str(attr) >= str(attr2):
+ # continue
+
+ # attr2_val = None
+ # if attr2 in self.tutor.state:
+ # attr2_val = self.tutor.state[attr2]
+ # elif attr2 in self.internal_memory:
+ # attr2_val = self.internal_memory[attr2]
+ # else:
+ # attr2_val = ''
+ # new_relations['eq(%s,%s)' % (attr, attr2)] = attr_val == attr2_val
+
+ print('done with creating new relations')
+ print('# of new relations = %i' % len(new_relations))
+
+ for attr in new_relations:
+ state[attr] = new_relations[attr]
+
+ # convert all attributes to strings
+ return {str(attr): state[attr] for attr in state}
+
+ def step(self, action):
+ try:
+ s, a, i = self.decode(action)
+
+ if isinstance(s, tuple):
+ if s in self.internal_memory or i == '':
+ reward = -1
+ else:
+ self.internal_memory[s] = i
+ reward = -0.01
+ else:
+ reward = self.tutor.apply_sai(s, a, i)
+ done = (s == 'done' and reward == 1.0)
+ except ValueError:
+ reward = -1
+ done = False
+
+ # print(s, a, i)
+ # print()
+ # print(reward)
+
+ state = self.get_rl_state()
+ # pprint(state)
+ obs = self.dv.transform([state])[0].toarray()
+ info = {}
+
+ return obs, reward, done, info
+
+
+ def apply_rl_op(self, op, arg1, arg2):
+ a1 = None
+ a2 = None
+
+ if arg1 in self.tutor.state:
+ a1 = self.tutor.state[arg1]
+ elif arg1 in self.internal_memory:
+ a1 = self.internal_memory[arg1]
+ else:
+ raise ValueError('Element not in memory')
+
+ if arg2 in self.tutor.state:
+ a2 = self.tutor.state[arg2]
+ elif arg2 in self.internal_memory:
+ a2 = self.internal_memory[arg2]
+ else:
+ raise ValueError('Element not in memory')
+
+ if op == "copy":
+ return a1
+ elif op == "add":
+ return str(int(a1) + int(a2))
+ elif op == "multiply":
+ return str(int(a1) * int(a2))
+ elif op == "mod10":
+ return str(int(a1) % 10)
+ elif op == "div10":
+ return str(int(a1) // 10)
+
+ def decode(self, action):
+ # print(action)
+
+ op = self.get_rl_operators()[action[1]]
+ arg1 = self.possible_args[action[2]]
+ arg2 = self.possible_args[action[3]]
+
+ if action[0] == len(self.tutor.get_possible_selections()):
+ s = (opp, arg1, arg2)
+ else:
+ s = self.tutor.get_possible_selections()[action[0]]
+
+ if s == "done":
+ a = "ButtonPressed"
+ else:
+ a = "UpdateField"
+
+ if s == "done":
+ v = -1
+ if s == "check_convert":
+ v = "x"
+ else:
+ v = self.apply_rl_op(op, arg1, arg2)
+
+ i = {'value': str(v)}
+
+ return s, a, i
+
+ def reset(self):
+ self.tutor.set_random_problem()
+ state = self.get_rl_state()
+ self.internal_memory = {}
+ obs = self.dv.transform([state])[0].toarray()
+ return obs
+
+ def render(self, mode='human', close=False):
+ self.tutor.render()