From 34a5103e2b8039f0e16cc4b43e9e5ece5ae8a14b Mon Sep 17 00:00:00 2001 From: Chris MacLellan <2348-cm3786@users.noreply.gitlab.cci.drexel.edu> Date: Fri, 15 Jan 2021 15:30:22 -0500 Subject: [PATCH] set lower min eval for fraction ppo tuning, disabled logging for multicolumn, added dual decision tree model for fractions --- sandbox/fractions/run_dual_decision_tree.py | 93 +++++++++++++++++++++ sandbox/fractions/tune_ppo.py | 2 +- tutorenvs/fractions.py | 9 +- tutorenvs/multicolumn.py | 9 +- 4 files changed, 107 insertions(+), 6 deletions(-) create mode 100644 sandbox/fractions/run_dual_decision_tree.py diff --git a/sandbox/fractions/run_dual_decision_tree.py b/sandbox/fractions/run_dual_decision_tree.py new file mode 100644 index 0000000..5a6443f --- /dev/null +++ b/sandbox/fractions/run_dual_decision_tree.py @@ -0,0 +1,93 @@ +from tutorenvs.fractions import FractionArithSymbolic + +from sklearn.tree import DecisionTreeClassifier +from sklearn.feature_extraction import DictVectorizer + + +def train_tree(n=10, logger=None): + X = [] + y_sel = [] + y_inp = [] + dv = DictVectorizer() + selections = [] + selection_mapping = {} + rev_selection_mapping = {} + selection_tree = DecisionTreeClassifier() + + inputs = [] + input_mapping = {} + rev_input_mapping = {} + input_tree = DecisionTreeClassifier() + + env = FractionArithSymbolic() + + p = 0 + hints = 0 + + while p < n: + + # make a copy of the state + state = {a: env.state[a] for a in env.state} + env.render() + + if rev_selection_mapping == {}: + sai = None + else: + vstate = dv.transform([state]) + sel = rev_selection_mapping[selection_tree.predict(vstate)[0]] + if sel == 'done': + act = 'ButtonPressed' + else: + act = "UpdateField" + inp = rev_input_mapping[input_tree.predict(vstate)[0]] + sai = (sel, act, inp) + + if sai is None: + hints += 1 + # print('hint') + sai = env.request_demo() + sai = (sai[0], sai[1], sai[2]['value']) + + reward = env.apply_sai(sai[0], sai[1], {'value': sai[2]}) + # print('reward', reward) + + if reward < 0: + hints += 1 + # print('hint') + sai = env.request_demo() + sai = (sai[0], sai[1], sai[2]['value']) + reward = env.apply_sai(sai[0], sai[1], {'value': sai[2]}) + + X.append(state) + y_sel.append(sai[0]) + y_inp.append(sai[2]) + + Xv = dv.fit_transform(X) + + selections = list(set(y_sel)) + selection_mapping = {l: i for i, l in enumerate(selections)} + rev_selection_mapping = {i: l for i, l in enumerate(selections)} + + inputs = list(set(y_inp)) + input_mapping = {l: i for i, l in enumerate(inputs)} + rev_input_mapping = {i: l for i, l in enumerate(inputs)} + + yv_sel = [selection_mapping[i] for i in y_sel] + yv_inp = [input_mapping[i] for i in y_inp] + + selection_tree.fit(Xv, yv_sel) + input_tree.fit(Xv, yv_inp) + + if sai[0] == "done" and reward == 1.0: + print("Problem %s of %s" % (p, n)) + print("# of hints = {}".format(hints)) + hints = 0 + p += 1 + + return selection_tree, input_tree + + +if __name__ == "__main__": + + for _ in range(1): + tree = train_tree(500) diff --git a/sandbox/fractions/tune_ppo.py b/sandbox/fractions/tune_ppo.py index 0a4e1f6..9ec5c0b 100644 --- a/sandbox/fractions/tune_ppo.py +++ b/sandbox/fractions/tune_ppo.py @@ -134,7 +134,7 @@ class TrialCallback(BaseCallback): log_dir: str, n_eval_episodes: int = 10, eval_freq: int = 10000, - min_eval: float = -600, + min_eval: float = -1500, verbose: int = 0, ): super(TrialCallback, self).__init__(verbose) diff --git a/tutorenvs/fractions.py b/tutorenvs/fractions.py index af81014..1637756 100644 --- a/tutorenvs/fractions.py +++ b/tutorenvs/fractions.py @@ -21,12 +21,15 @@ pil_logger.setLevel(logging.INFO) class FractionArithSymbolic: - def __init__(self): + def __init__(self, logger=None): """ Creates a state and sets a random problem. """ - # self.logger = DataShopLogger('FractionsTutor', extra_kcs=['ptype_field']) - self.logger = StubLogger() + if logger is None: + # self.logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field']) + self.logger = StubLogger() + else: + self.logger = logger self.logger.set_student() self.set_random_problem() # self.reset("", "", "", "", "") diff --git a/tutorenvs/multicolumn.py b/tutorenvs/multicolumn.py index 088f7c8..1d7194d 100644 --- a/tutorenvs/multicolumn.py +++ b/tutorenvs/multicolumn.py @@ -1,6 +1,7 @@ from random import randint from random import choice from pprint import pprint +import logging import cv2 # pytype:disable=import-error import gym @@ -15,6 +16,10 @@ from tutorenvs.utils import OnlineDictVectorizer from tutorenvs.utils import DataShopLogger from tutorenvs.utils import StubLogger +pil_logger = logging.getLogger('PIL') +pil_logger.setLevel(logging.INFO) + + def custom_add(a, b): if a == '': a = '0' @@ -29,8 +34,8 @@ class MultiColumnAdditionSymbolic: Creates a state and sets a random problem. """ if logger is None: - self.logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field']) - # self.logger = StubLogger() + # self.logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field']) + self.logger = StubLogger() else: self.logger = logger self.logger.set_student() -- GitLab