From a2e68a0c74bc215cfb46bbc16fdbd8535405bf69 Mon Sep 17 00:00:00 2001
From: Chris MacLellan <2348-cm3786@users.noreply.gitlab.cci.drexel.edu>
Date: Mon, 21 Dec 2020 11:29:11 -0500
Subject: [PATCH] working version of fractions PPO model and decision tree
 model

---
 sandbox/{ => fractions}/run_al_fractions.py   |   0
 sandbox/{ => fractions}/run_ppo_fractions.py  |   0
 sandbox/{ => multicolumn}/run_cobweb_multi.py |   0
 .../run_decision_tree_multi-v1.py             |   0
 .../run_dual_decision_tree_multi.py           |  15 +-
 sandbox/multicolumn/run_dual_ppo.py           | 147 ++++++++++++++++++
 sandbox/{ => multicolumn}/run_ppo_multi-v0.py |  24 ++-
 sandbox/{ => multicolumn}/run_ppo_multi-v1.py |   0
 sandbox/{ => multicolumn}/run_ppo_multi-v2.py |   0
 sandbox/{ => multicolumn}/run_ppo_multi-v3.py |   0
 .../run_single_decision_tree_multi.py         |  39 +++--
 sandbox/multicolumn/run_single_dqn.py         |  38 +++++
 tutorenvs/__init__.py                         |  12 +-
 tutorenvs/fractions.py                        |   7 +-
 tutorenvs/multicolumn.py                      |  78 ++++------
 tutorenvs/utils.py                            |  62 +++++++-
 16 files changed, 333 insertions(+), 89 deletions(-)
 rename sandbox/{ => fractions}/run_al_fractions.py (100%)
 rename sandbox/{ => fractions}/run_ppo_fractions.py (100%)
 rename sandbox/{ => multicolumn}/run_cobweb_multi.py (100%)
 rename sandbox/{ => multicolumn}/run_decision_tree_multi-v1.py (100%)
 rename sandbox/{ => multicolumn}/run_dual_decision_tree_multi.py (92%)
 create mode 100644 sandbox/multicolumn/run_dual_ppo.py
 rename sandbox/{ => multicolumn}/run_ppo_multi-v0.py (57%)
 rename sandbox/{ => multicolumn}/run_ppo_multi-v1.py (100%)
 rename sandbox/{ => multicolumn}/run_ppo_multi-v2.py (100%)
 rename sandbox/{ => multicolumn}/run_ppo_multi-v3.py (100%)
 rename sandbox/{ => multicolumn}/run_single_decision_tree_multi.py (72%)
 create mode 100644 sandbox/multicolumn/run_single_dqn.py

diff --git a/sandbox/run_al_fractions.py b/sandbox/fractions/run_al_fractions.py
similarity index 100%
rename from sandbox/run_al_fractions.py
rename to sandbox/fractions/run_al_fractions.py
diff --git a/sandbox/run_ppo_fractions.py b/sandbox/fractions/run_ppo_fractions.py
similarity index 100%
rename from sandbox/run_ppo_fractions.py
rename to sandbox/fractions/run_ppo_fractions.py
diff --git a/sandbox/run_cobweb_multi.py b/sandbox/multicolumn/run_cobweb_multi.py
similarity index 100%
rename from sandbox/run_cobweb_multi.py
rename to sandbox/multicolumn/run_cobweb_multi.py
diff --git a/sandbox/run_decision_tree_multi-v1.py b/sandbox/multicolumn/run_decision_tree_multi-v1.py
similarity index 100%
rename from sandbox/run_decision_tree_multi-v1.py
rename to sandbox/multicolumn/run_decision_tree_multi-v1.py
diff --git a/sandbox/run_dual_decision_tree_multi.py b/sandbox/multicolumn/run_dual_decision_tree_multi.py
similarity index 92%
rename from sandbox/run_dual_decision_tree_multi.py
rename to sandbox/multicolumn/run_dual_decision_tree_multi.py
index 21e5fed..0ba9369 100644
--- a/sandbox/run_dual_decision_tree_multi.py
+++ b/sandbox/multicolumn/run_dual_decision_tree_multi.py
@@ -30,7 +30,10 @@ def train_tree(n=10, logger=None):
     env = MultiColumnAdditionSymbolic(logger=logger)
 
     p = 0
+    hints = 0
+
     while p < n:
+
         # make a copy of the state
         state = {a: env.state[a] for a in env.state}
         env.render()
@@ -48,15 +51,17 @@ def train_tree(n=10, logger=None):
             sai = (sel, act, inp)
 
         if sai is None:
-            print('hint')
+            hints += 1
+            # print('hint')
             sai = env.request_demo()
             sai = (sai[0], sai[1], sai[2]['value'])
 
         reward = env.apply_sai(sai[0], sai[1], {'value': sai[2]})
-        print('reward', reward)
+        # print('reward', reward)
 
         if reward < 0:
-            print('hint')
+            hints += 1
+            # print('hint')
             sai = env.request_demo()
             sai = (sai[0], sai[1], sai[2]['value'])
             reward = env.apply_sai(sai[0], sai[1], {'value': sai[2]})
@@ -83,6 +88,8 @@ def train_tree(n=10, logger=None):
 
         if sai[0] == "done" and reward == 1.0:
             print("Problem %s of %s" % (p, n))
+            print("# of hints = {}".format(hints))
+            hints = 0
             p += 1
 
     return selection_tree, input_tree
@@ -91,7 +98,7 @@ if __name__ == "__main__":
 
     logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field'])
     for _ in range(1):
-        tree = train_tree(1000, logger)
+        tree = train_tree(500, logger)
     # env = MultiColumnAdditionSymbolic()
 
     # while True:
diff --git a/sandbox/multicolumn/run_dual_ppo.py b/sandbox/multicolumn/run_dual_ppo.py
new file mode 100644
index 0000000..aade75f
--- /dev/null
+++ b/sandbox/multicolumn/run_dual_ppo.py
@@ -0,0 +1,147 @@
+from typing import Any, Dict
+
+import numpy as np
+import gym
+import optuna
+import torch.nn as nn
+from stable_baselines3 import PPO
+from stable_baselines3.ppo import MlpPolicy
+from stable_baselines3.common.callbacks import EvalCallback
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.vec_env import VecEnv
+
+from tutorenvs.utils import MultiDiscreteToDiscreteWrapper
+
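+# Minimal linear learning-rate schedule (decays from initial_value to 0 over
+# training), matching the helper used in rl-baselines3-zoo; only needed if the
+# 'linear' lr_schedule option in sample_ppo_params is re-enabled.
+def linear_schedule(initial_value):
+    def func(progress_remaining):
+        # progress_remaining goes from 1.0 at the start of training to 0.0 at the end
+        return progress_remaining * initial_value
+    return func
+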
+def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
+    """
+    Sampler for PPO hyperparameters.
+    :param trial: Optuna trial used to sample the hyperparameters.
+    :return: Dict of keyword arguments for the PPO constructor.
+    """
+    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64, 128, 256, 512])
+    n_steps = trial.suggest_categorical("n_steps", [8, 16, 32, 64, 128, 256, 512, 1024, 2048])
+    gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999])
+    learning_rate = trial.suggest_loguniform("lr", 1e-5, 1)
+    lr_schedule = "constant"
+    # Uncomment to enable learning rate schedule
+    # lr_schedule = trial.suggest_categorical('lr_schedule', ['linear', 'constant'])
+    ent_coef = trial.suggest_loguniform("ent_coef", 0.00000001, 0.1)
+    clip_range = trial.suggest_categorical("clip_range", [0.1, 0.2, 0.3, 0.4])
+    n_epochs = trial.suggest_categorical("n_epochs", [1, 5, 10, 20])
+    gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
+    max_grad_norm = trial.suggest_categorical("max_grad_norm", [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 5])
+    vf_coef = trial.suggest_uniform("vf_coef", 0, 1)
+    net_arch = trial.suggest_categorical("net_arch", ["small", "medium"])
+    # Uncomment for gSDE (continuous actions)
+    # log_std_init = trial.suggest_uniform("log_std_init", -4, 1)
+    # Uncomment for gSDE (continuous action)
+    # sde_sample_freq = trial.suggest_categorical("sde_sample_freq", [-1, 8, 16, 32, 64, 128, 256])
+    # Orthogonal initialization
+    ortho_init = False
+    # ortho_init = trial.suggest_categorical('ortho_init', [False, True])
+    # activation_fn = trial.suggest_categorical('activation_fn', ['tanh', 'relu', 'elu', 'leaky_relu'])
+    activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])
+
+    # TODO: account when using multiple envs
+    if batch_size > n_steps:
+        batch_size = n_steps
+
+    if lr_schedule == "linear":
+        learning_rate = linear_schedule(learning_rate)
+
+    # Independent networks usually work best
+    # when not working with images
+    net_arch = {
+        "small": [dict(pi=[64, 64], vf=[64, 64])],
+        "medium": [dict(pi=[256, 256], vf=[256, 256])],
+    }[net_arch]
+
+    activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU, "elu": nn.ELU, "leaky_relu": nn.LeakyReLU}[activation_fn]
+
+    return {
+        "n_steps": n_steps,
+        "batch_size": batch_size,
+        "gamma": gamma,
+        "learning_rate": learning_rate,
+        "ent_coef": ent_coef,
+        "clip_range": clip_range,
+        "n_epochs": n_epochs,
+        "gae_lambda": gae_lambda,
+        "max_grad_norm": max_grad_norm,
+        "vf_coef": vf_coef,
+        # "sde_sample_freq": sde_sample_freq,
+        "policy_kwargs": dict(
+            # log_std_init=log_std_init,
+            net_arch=net_arch,
+            activation_fn=activation_fn,
+            ortho_init=ortho_init,
+        ),
+    }
+
+class TrialEvalCallback(EvalCallback):
+    """
+    Callback used for evaluating and reporting a trial.
+    """
+
+    def __init__(
+        self,
+        eval_env: VecEnv,
+        trial: optuna.Trial,
+        n_eval_episodes: int = 5,
+        eval_freq: int = 10000,
+        deterministic: bool = True,
+        verbose: int = 0,
+    ):
+
+        super(TrialEvalCallback, self).__init__(
+            eval_env=eval_env,
+            n_eval_episodes=n_eval_episodes,
+            eval_freq=eval_freq,
+            deterministic=deterministic,
+            verbose=verbose,
+        )
+        self.trial = trial
+        self.eval_idx = 0
+        self.is_pruned = False
+
+    def _on_step(self) -> bool:
+        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
+            super(TrialEvalCallback, self)._on_step()
+            self.eval_idx += 1
+            # report best or current reward?
+            # report num_timesteps or elapsed time?
+            self.trial.report(self.last_mean_reward, self.eval_idx)
+            # Prune trial if needed
+            if self.trial.should_prune():
+                self.is_pruned = True
+                return False
+        return True
+
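+# NOTE: sample_ppo_params and TrialEvalCallback are helpers for an Optuna
+# hyperparameter search; the training run below does not use them.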
+if __name__ == "__main__":
+
+    # vectorized environment (single process, n_envs=1)
+    env = make_vec_env('MulticolumnArithSymbolic-v0', n_envs=1)
+    model = PPO(MlpPolicy, env, verbose=1,
+                # n_steps=4096,
+                learning_rate=0.0000025,
+                # learning_rate=lambda x: max(x*0.000015, 0.000005),
+                clip_range=0.05,
+                # clip_range=lambda x: max(x*0.1, 0.01),
+                # train_freq=1,
+                # exploration_fraction=0.5,
+                # exploration_initial_eps=0.45,
+                gamma=0.0,
+                # learning_starts=1,
+                policy_kwargs={'net_arch': [{'vf': [65, 65], 'pi': [65, 65]}]},
+                tensorboard_log="./tensorboard_ppo_multi/"
+                )
+            # gamma=0.1,
+            # tensorboard_log="./tensorboard/v0/")
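+    # gamma=0.0 makes the agent purely myopic (only the immediate step reward
+    # counts), presumably because the tutor gives immediate per-step feedback.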
+
+    # while True:
+        # Train
+    model.learn(total_timesteps=5000000)
+
+        # Test
+        # obs = env.reset()
+        # rwd = 0
+        # for _ in range(10000):
+        #     action, _states = model.predict(obs)
+        #     obs, rewards, dones, info = env.step(action)
+        #     rwd += np.sum(rewards)
+        #     env.render()
+        # print(rwd)
diff --git a/sandbox/run_ppo_multi-v0.py b/sandbox/multicolumn/run_ppo_multi-v0.py
similarity index 57%
rename from sandbox/run_ppo_multi-v0.py
rename to sandbox/multicolumn/run_ppo_multi-v0.py
index 42e8566..a2b084e 100644
--- a/sandbox/run_ppo_multi-v0.py
+++ b/sandbox/multicolumn/run_ppo_multi-v0.py
@@ -17,19 +17,15 @@ if __name__ == "__main__":
             tensorboard_log="./tensorboard/v0/")
 
     while True:
+        # Train
         model.learn(total_timesteps=100)
 
-        # To demonstrate saving and loading
-        # model.save("ppo2_multicolumn-v0")
-        # del model
-        # model = PPO2.load("ppo2_multicolumn-v0")
-
-        # Enjoy trained agent
-        obs = env.reset()
-        rwd = 0
-        for _ in range(10000):
-            action, _states = model.predict(obs)
-            obs, rewards, dones, info = env.step(action)
-            rwd += np.sum(rewards)
-            env.render()
-        print(rwd)
+        # Test
+        # obs = env.reset()
+        # rwd = 0
+        # for _ in range(10000):
+        #     action, _states = model.predict(obs)
+        #     obs, rewards, dones, info = env.step(action)
+        #     rwd += np.sum(rewards)
+        #     env.render()
+        # print(rwd)
diff --git a/sandbox/run_ppo_multi-v1.py b/sandbox/multicolumn/run_ppo_multi-v1.py
similarity index 100%
rename from sandbox/run_ppo_multi-v1.py
rename to sandbox/multicolumn/run_ppo_multi-v1.py
diff --git a/sandbox/run_ppo_multi-v2.py b/sandbox/multicolumn/run_ppo_multi-v2.py
similarity index 100%
rename from sandbox/run_ppo_multi-v2.py
rename to sandbox/multicolumn/run_ppo_multi-v2.py
diff --git a/sandbox/run_ppo_multi-v3.py b/sandbox/multicolumn/run_ppo_multi-v3.py
similarity index 100%
rename from sandbox/run_ppo_multi-v3.py
rename to sandbox/multicolumn/run_ppo_multi-v3.py
diff --git a/sandbox/run_single_decision_tree_multi.py b/sandbox/multicolumn/run_single_decision_tree_multi.py
similarity index 72%
rename from sandbox/run_single_decision_tree_multi.py
rename to sandbox/multicolumn/run_single_decision_tree_multi.py
index c360acd..94eaf33 100644
--- a/sandbox/run_single_decision_tree_multi.py
+++ b/sandbox/multicolumn/run_single_decision_tree_multi.py
@@ -8,26 +8,30 @@ from tutorenvs.multicolumn import MultiColumnAdditionSymbolic
 import numpy as np
 
 from sklearn.tree import DecisionTreeClassifier
-from sklearn.feature_extraction import DictVectorizer
+# from sklearn.feature_extraction import DictVectorizer
 
+from tutorenvs.utils import OnlineDictVectorizer
 from tutorenvs.utils import DataShopLogger
 
 def train_tree(n=10, logger=None):
     X = []
     y = []
-    dv = DictVectorizer()
+    dv = OnlineDictVectorizer(110)
     actions = []
     action_mapping = {}
     rev_action_mapping = {}
-    selection_tree = DecisionTreeClassifier()
-    input_tree = DecisionTreeClassifier()
+    tree = DecisionTreeClassifier()
     env = MultiColumnAdditionSymbolic(logger=logger)
 
+    hints = 0
     p = 0
+
+    Xv = None
+
     while p < n:
         # make a copy of the state
         state = {a: env.state[a] for a in env.state}
-        env.render()
+        # env.render()
 
         if rev_action_mapping == {}:
             sai = None
@@ -36,23 +40,30 @@ def train_tree(n=10, logger=None):
             sai = rev_action_mapping[tree.predict(vstate)[0]]
 
         if sai is None:
-            print('hint')
+            hints += 1
+            # print('hint')
             sai = env.request_demo()
             sai = (sai[0], sai[1], sai[2]['value'])
 
         reward = env.apply_sai(sai[0], sai[1], {'value': sai[2]})
-        print('reward', reward)
+        # print('reward', reward)
 
         if reward < 0:
-            print('hint')
+            hints += 1
+            # print('hint')
             sai = env.request_demo()
             sai = (sai[0], sai[1], sai[2]['value'])
             reward = env.apply_sai(sai[0], sai[1], {'value': sai[2]})
 
-        X.append(state)
+        # X.append(state)
         y.append(sai)
 
-        Xv = dv.fit_transform(X)
+        if Xv is None:
+            Xv = dv.fit_transform([state])
+        else:
+            Xv = np.concatenate((Xv, dv.fit_transform([state])))
+
+        # print('shape', Xv.shape)
         actions = set(y)
         action_mapping = {l: i for i, l in enumerate(actions)}
         rev_action_mapping = {i: l for i, l in enumerate(actions)}
@@ -61,6 +72,10 @@ def train_tree(n=10, logger=None):
         tree.fit(Xv, yv)
 
         if sai[0] == "done" and reward == 1.0:
+            print("Problem %s of %s" % (p, n))
+            print("# of hints = {}".format(hints))
+            hints = 0
+
             p += 1
 
     return tree
@@ -68,8 +83,8 @@ def train_tree(n=10, logger=None):
 if __name__ == "__main__":
 
     logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field'])
-    for _ in range(10):
-        tree = train_tree(100, logger)
+    for _ in range(1):
+        tree = train_tree(30000, logger)
     # env = MultiColumnAdditionSymbolic()
 
     # while True:
diff --git a/sandbox/multicolumn/run_single_dqn.py b/sandbox/multicolumn/run_single_dqn.py
new file mode 100644
index 0000000..9aa6f7b
--- /dev/null
+++ b/sandbox/multicolumn/run_single_dqn.py
@@ -0,0 +1,38 @@
+import numpy as np
+import gym
+from stable_baselines3 import DQN
+from stable_baselines3.dqn import MlpPolicy
+
+from tutorenvs.utils import MultiDiscreteToDiscreteWrapper
+
+if __name__ == "__main__":
+
+    # single environment
+    env = gym.make('MulticolumnArithSymbolic-v0')
+    env = MultiDiscreteToDiscreteWrapper(env)
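+    # stable-baselines3's DQN only supports Discrete action spaces, so the
+    # MultiDiscrete (selection, digit) action is flattened to a single index.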
+    model = DQN(MlpPolicy, env, verbose=1,
+                learning_rate=0.0025,
+                train_freq=1,
+                exploration_fraction=0.5,
+                exploration_initial_eps=0.45,
+                gamma=0.0,
+                learning_starts=1,
+                policy_kwargs={'net_arch': [65, 65, 65]}, # {'qf': [65], 'pi': [65]}]},
+                # tensorboard_log="./tensorboard_dqn_multi/"
+                )
+            # gamma=0.1,
+            # tensorboard_log="./tensorboard/v0/")
+
+    while True:
+        # Train
+        model.learn(total_timesteps=1000000)
+
+        # Test
+        # obs = env.reset()
+        # rwd = 0
+        # for _ in range(10000):
+        #     action, _states = model.predict(obs)
+        #     obs, rewards, dones, info = env.step(action)
+        #     rwd += np.sum(rewards)
+        #     env.render()
+        # print(rwd)
diff --git a/tutorenvs/__init__.py b/tutorenvs/__init__.py
index 491b28c..0e7c94d 100644
--- a/tutorenvs/__init__.py
+++ b/tutorenvs/__init__.py
@@ -1,7 +1,6 @@
 from gym.envs.registration import register
 from tutorenvs.fractions import FractionArithDigitsEnv
 from tutorenvs.fractions import FractionArithOppEnv
-from tutorenvs.multicolumn import MultiColumnAdditionOppEnv
 from tutorenvs.multicolumn import MultiColumnAdditionDigitsEnv
 from tutorenvs.multicolumn import MultiColumnAdditionPixelEnv
 from tutorenvs.multicolumn import MultiColumnAdditionPerceptEnv
@@ -23,21 +22,16 @@ register(
 # )
 
 register(
-    id='MultiColumnArith-v0',
-    entry_point='tutorenvs:MultiColumnAdditionOppEnv',
-)
-
-register(
-    id='MultiColumnArith-v1',
+    id='MulticolumnArithSymbolic-v0',
     entry_point='tutorenvs:MultiColumnAdditionDigitsEnv',
 )
 
 register(
-    id='MultiColumnArith-v2',
+    id='MulticolumnArithPixel-v0',
     entry_point='tutorenvs:MultiColumnAdditionPixelEnv',
 )
 
 register(
-    id='MultiColumnArith-v3',
+    id='MulticolumnArithPercept-v0',
     entry_point='tutorenvs:MultiColumnAdditionPerceptEnv',
 )
diff --git a/tutorenvs/fractions.py b/tutorenvs/fractions.py
index 793c8d7..1fc9d44 100644
--- a/tutorenvs/fractions.py
+++ b/tutorenvs/fractions.py
@@ -176,7 +176,9 @@ class FractionArithSymbolic:
             outcome = "INCORRECT"
             self.num_incorrect_steps += 1
 
-        self.logger.log_step(selection, action, inputs['value'], outcome, [self.ptype + '_' + selection])
+        self.logger.log_step(selection, action, inputs['value'], outcome,
+                             step_name=self.ptype + '_' + selection,
+                             kcs=[self.ptype + '_' + selection])
 
         # Render output?
         self.render()
@@ -294,7 +296,8 @@ class FractionArithSymbolic:
         demo = self.get_demo()
         feedback_text = "selection: %s, action: %s, input: %s" % (demo[0],
                 demo[1], demo[2]['value'])
-        self.logger.log_hint(feedback_text, [self.ptype + '_' + demo[0]])
+        self.logger.log_hint(feedback_text, step_name=self.ptype + '_' +
+                             demo[0], kcs=[self.ptype + '_' + demo[0]])
         self.num_hints += 1
 
         return demo
diff --git a/tutorenvs/multicolumn.py b/tutorenvs/multicolumn.py
index c5e70d4..1d34d39 100644
--- a/tutorenvs/multicolumn.py
+++ b/tutorenvs/multicolumn.py
@@ -11,8 +11,9 @@ from sklearn.feature_extraction import DictVectorizer
 import numpy as np
 from PIL import Image, ImageDraw
 
-from tutorenvs.utils import BaseOppEnv
+from tutorenvs.utils import OnlineDictVectorizer
 from tutorenvs.utils import DataShopLogger
+from tutorenvs.utils import StubLogger
 
 def custom_add(a, b):
     if a == '':
@@ -29,7 +30,8 @@ class MultiColumnAdditionSymbolic:
         """
         if logger is None:
             print("CREATING LOGGER")
-            self.logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field'])
+            # self.logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field'])
+            self.logger = StubLogger()
         else:
             self.logger = logger
         self.logger.set_student()
@@ -234,10 +236,12 @@ class MultiColumnAdditionSymbolic:
         return state_output
 
     def set_random_problem(self):
-        # upper = str(randint(1,999))
-        # lower = str(randint(1,999))
-        upper = str(randint(1,9))
-        lower = str(randint(1,9))
+        upper = str(randint(1,999))
+        lower = str(randint(1,999))
+        # upper = str(randint(1,99))
+        # lower = str(randint(1,99))
+        # upper = str(randint(1,9))
+        # lower = str(randint(1,9))
         self.reset(upper=upper, lower=lower)
         self.logger.set_problem("%s_%s" % (upper, lower))
 
@@ -254,13 +258,13 @@ class MultiColumnAdditionSymbolic:
             outcome = "INCORRECT"
             self.num_incorrect_steps += 1
 
-        self.logger.log_step(selection, action, inputs['value'], outcome, [selection])
+        self.logger.log_step(selection, action, inputs['value'], outcome, step_name=selection, kcs=[selection])
 
         if reward == -1.0:
             return reward
 
         if selection == "done":
-            print("DONE! Only took %i steps." % (self.num_correct_steps + self.num_incorrect_steps))
+            # print("DONE! Only took %i steps." % (self.num_correct_steps + self.num_incorrect_steps))
             # self.render()
             # print()
             # pprint(self.state)
@@ -371,7 +375,7 @@ class MultiColumnAdditionSymbolic:
         demo = self.get_demo()
         feedback_text = "selection: %s, action: %s, input: %s" % (demo[0],
                 demo[1], demo[2]['value'])
-        self.logger.log_hint(feedback_text, [demo[0]])
+        self.logger.log_hint(feedback_text, step_name=demo[0], kcs=[demo[0]])
         self.num_hints += 1
 
         return demo
@@ -428,58 +432,42 @@ class MultiColumnAdditionSymbolic:
         raise Exception("request demo - logic missing")
 
 
-class MultiColumnAdditionOppEnv(BaseOppEnv):
-
-    def __init__(self):
-        super().__init__(MultiColumnAdditionSymbolic, max_depth=2)
-
-    def get_rl_operators(self):
-        return [
-                ('copy', 1),
-                ('add', 2),
-                ('mod10', 1),
-                ('div10', 1)
-                ]
-
 class MultiColumnAdditionDigitsEnv(gym.Env):
     metadata = {'render.modes': ['human']}
 
-    def get_dv_training(self):
-        empty = {attr: '' for attr in self.tutor.state if attr != 'operator'}
-
-        training_data = [empty]
-
-        for i in range(1, 10):
-            s = {attr: str(i) for attr in self.tutor.state if attr != 'operator'}
-            training_data.append(s)
-
-        return training_data
-
-    def get_rl_state(self):
-        return self.tutor.state
-
     def __init__(self):
         self.tutor = MultiColumnAdditionSymbolic()
         n_selections = len(self.tutor.get_possible_selections())
-        self.dv = DictVectorizer()
-        transformed_training = self.dv.fit_transform(self.get_dv_training())
-        n_features = transformed_training.shape[1]
-
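+        # Fixed observation width; 110 is assumed to be enough columns for
+        # every state feature the OnlineDictVectorizer will encounter.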
+        n_features = 110
+        self.dv = OnlineDictVectorizer(n_features)
         self.observation_space = spaces.Box(low=0.0,
                 high=1.0, shape=(1, n_features), dtype=np.float32)
         self.action_space = spaces.MultiDiscrete([n_selections, 10])
+        self.n_steps = 0
+        self.max_steps = 5000
+
+    def get_rl_state(self):
+        return self.tutor.state
 
     def step(self, action):
+        self.n_steps += 1
+
         s, a, i = self.decode(action)
         # print(s, a, i)
         # print()
         reward = self.tutor.apply_sai(s, a, i)
+        # self.render()
         # print(reward)
-        
-        state = self.get_rl_state()
+        state = self.tutor.state
         # pprint(state)
-        obs = self.dv.transform([state])[0].toarray()
+        obs = self.dv.fit_transform([state])[0]
         done = (s == 'done' and reward == 1.0)
+
+        # have a max steps for a given problem.
+        # When we hit that we're done regardless.
+        if self.n_steps > self.max_steps:
+            done = True
+
         info = {}
 
         return obs, reward, done, info
@@ -505,9 +493,11 @@ class MultiColumnAdditionDigitsEnv(gym.Env):
         return s, a, i
 
     def reset(self):
+        self.n_steps = 0
         self.tutor.set_random_problem()
+        # self.render()
         state = self.get_rl_state()
-        obs = self.dv.transform([state])[0].toarray()
+        obs = self.dv.fit_transform([state])[0]
         return obs
 
     def render(self, mode='human', close=False):
diff --git a/tutorenvs/utils.py b/tutorenvs/utils.py
index 07decd2..478a8d6 100644
--- a/tutorenvs/utils.py
+++ b/tutorenvs/utils.py
@@ -9,6 +9,25 @@ from gym import error, spaces, utils
 from sklearn.feature_extraction import DictVectorizer
 import numpy as np
 
+
+class StubLogger():
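+    """No-op logger with the same interface as DataShopLogger, for runs where
+    transaction logging is not needed."""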
+
+    def __init__(self):
+        pass
+
+    def set_student(self, student_id=None):
+        pass
+
+    def set_problem(self, problem_name=None):
+        pass
+
+    def log_hint(self, feedback_text="", step_name=None, kcs=None):
+        pass
+
+    def log_step(self, selection="", action="", inp="", outcome="", step_name=None, kcs=None):
+        pass
+
+
 class DataShopLogger():
 
     def __init__(self, domain = "tutorenv", extra_kcs=None):
@@ -64,7 +83,7 @@ class DataShopLogger():
         self.problem_start = datetime.fromtimestamp(self.time).strftime('%m/%d/%Y %H:%M:%S')
         self.step_count = 1
 
-    def log_hint(self, feedback_text, kcs=None):
+    def log_hint(self, feedback_text, step_name=None, kcs=None):
         if self.student_id is None:
             raise Exception("No student ID")
         if self.problem_name is None:
@@ -81,6 +100,9 @@ class DataShopLogger():
         inp = ""
         outcome = "HINT"
 
+        if step_name is None:
+            step_name = self.step_count
+
         datum = [self.student_id,
                  self.session_id,
                  transaction_id,
@@ -91,7 +113,8 @@ class DataShopLogger():
                  self.level_domain,
                  self.problem_name,
                  self.problem_start,
-                 self.step_count,
+                 #self.step_count,
+                 step_name,
                  selection,
                  action,
                  inp,
@@ -107,7 +130,7 @@ class DataShopLogger():
         with open(self.filename, 'a+') as fout:
             fout.write("\t".join(str(v) for v in datum) + "\n")
 
-    def log_step(self, selection, action, inp, outcome, kcs=None):
+    def log_step(self, selection, action, inp, outcome, step_name=None, kcs=None):
         if self.student_id is None:
             raise Exception("No student ID")
         if self.problem_name is None:
@@ -121,6 +144,9 @@ class DataShopLogger():
         self.step_count += 1
         feedback_text = ""
 
+        if step_name is None:
+            step_name = self.step_count
+
         datum = [self.student_id,
                  self.session_id,
                  transaction_id,
@@ -131,7 +157,7 @@ class DataShopLogger():
                  self.level_domain,
                  self.problem_name,
                  self.problem_start,
-                 self.step_count,
+                 step_name,
                  selection,
                  action,
                  inp,
@@ -147,6 +173,34 @@ class DataShopLogger():
         with open(self.filename, 'a+') as fout:
             fout.write("\t".join(str(v) for v in datum) + "\n")
 
+class MultiDiscreteToDiscreteWrapper(gym.ActionWrapper):
+
+    def __init__(self, env):
+        super().__init__(env)
+        assert isinstance(env.action_space, gym.spaces.MultiDiscrete), \
+            "Should only be used to wrap envs with MultiDiscrete actions."
+        self.action_vec = self.action_space.nvec
+        self.action_space = gym.spaces.Discrete(np.prod(self.action_vec))
+
+    # def convert(act):
+    #     discrete_act = 0
+    #     for i, v in enumerate(act):
+    #         discrete_act += (v * np.prod(self.action_vec[i+1:]))
+    #     return discrete_act
+
+    # def unconvert(discrete_act):
+    #     act = np.zeros_like(self.action_vec)
+    #     for i in range(len(self.action_vec)):
+    #         act[i] = discrete_act // np.prod(self.action_vec[i+1:])
+    #         discrete_act = discrete_act % np.prod(self.action_vec[i+1:])
+    #     return act
+
+    def action(self, discrete_act):
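+        # Decode the flat Discrete index back into a MultiDiscrete action by
+        # treating it as a mixed-radix number whose digit sizes come from
+        # self.action_vec.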
+        act = np.zeros_like(self.action_vec)
+        for i in range(len(self.action_vec)):
+            act[i] = discrete_act // np.prod(self.action_vec[i+1:])
+            discrete_act = discrete_act % np.prod(self.action_vec[i+1:])
+        return act
 
 class OnlineDictVectorizer():
 
-- 
GitLab