diff --git a/sandbox/fractions/run_dual_decision_tree.py b/sandbox/fractions/run_dual_decision_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6443f34752e8eca0a700281ef148d807937312 --- /dev/null +++ b/sandbox/fractions/run_dual_decision_tree.py @@ -0,0 +1,93 @@ +from tutorenvs.fractions import FractionArithSymbolic + +from sklearn.tree import DecisionTreeClassifier +from sklearn.feature_extraction import DictVectorizer + + +def train_tree(n=10, logger=None): + X = [] + y_sel = [] + y_inp = [] + dv = DictVectorizer() + selections = [] + selection_mapping = {} + rev_selection_mapping = {} + selection_tree = DecisionTreeClassifier() + + inputs = [] + input_mapping = {} + rev_input_mapping = {} + input_tree = DecisionTreeClassifier() + + env = FractionArithSymbolic() + + p = 0 + hints = 0 + + while p < n: + + # make a copy of the state + state = {a: env.state[a] for a in env.state} + env.render() + + if rev_selection_mapping == {}: + sai = None + else: + vstate = dv.transform([state]) + sel = rev_selection_mapping[selection_tree.predict(vstate)[0]] + if sel == 'done': + act = 'ButtonPressed' + else: + act = "UpdateField" + inp = rev_input_mapping[input_tree.predict(vstate)[0]] + sai = (sel, act, inp) + + if sai is None: + hints += 1 + # print('hint') + sai = env.request_demo() + sai = (sai[0], sai[1], sai[2]['value']) + + reward = env.apply_sai(sai[0], sai[1], {'value': sai[2]}) + # print('reward', reward) + + if reward < 0: + hints += 1 + # print('hint') + sai = env.request_demo() + sai = (sai[0], sai[1], sai[2]['value']) + reward = env.apply_sai(sai[0], sai[1], {'value': sai[2]}) + + X.append(state) + y_sel.append(sai[0]) + y_inp.append(sai[2]) + + Xv = dv.fit_transform(X) + + selections = list(set(y_sel)) + selection_mapping = {l: i for i, l in enumerate(selections)} + rev_selection_mapping = {i: l for i, l in enumerate(selections)} + + inputs = list(set(y_inp)) + input_mapping = {l: i for i, l in enumerate(inputs)} + rev_input_mapping = {i: l for i, l in enumerate(inputs)} + + yv_sel = [selection_mapping[i] for i in y_sel] + yv_inp = [input_mapping[i] for i in y_inp] + + selection_tree.fit(Xv, yv_sel) + input_tree.fit(Xv, yv_inp) + + if sai[0] == "done" and reward == 1.0: + print("Problem %s of %s" % (p, n)) + print("# of hints = {}".format(hints)) + hints = 0 + p += 1 + + return selection_tree, input_tree + + +if __name__ == "__main__": + + for _ in range(1): + tree = train_tree(500) diff --git a/sandbox/fractions/tune_ppo.py b/sandbox/fractions/tune_ppo.py index 0a4e1f66733d6347da50768fe3b92abc966fb0c4..9ec5c0b89e100be13b68a11738cc5b083d3c5fc4 100644 --- a/sandbox/fractions/tune_ppo.py +++ b/sandbox/fractions/tune_ppo.py @@ -134,7 +134,7 @@ class TrialCallback(BaseCallback): log_dir: str, n_eval_episodes: int = 10, eval_freq: int = 10000, - min_eval: float = -600, + min_eval: float = -1500, verbose: int = 0, ): super(TrialCallback, self).__init__(verbose) diff --git a/tutorenvs/fractions.py b/tutorenvs/fractions.py index af81014cde1d176ec68e2452e791e5d22568fffb..1637756750fee9514f49d869ba566e63c85e4a3f 100644 --- a/tutorenvs/fractions.py +++ b/tutorenvs/fractions.py @@ -21,12 +21,15 @@ pil_logger.setLevel(logging.INFO) class FractionArithSymbolic: - def __init__(self): + def __init__(self, logger=None): """ Creates a state and sets a random problem. """ - # self.logger = DataShopLogger('FractionsTutor', extra_kcs=['ptype_field']) - self.logger = StubLogger() + if logger is None: + # self.logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field']) + self.logger = StubLogger() + else: + self.logger = logger self.logger.set_student() self.set_random_problem() # self.reset("", "", "", "", "") diff --git a/tutorenvs/multicolumn.py b/tutorenvs/multicolumn.py index 088f7c8782ab431b7b6ec6973ae4fd7fbbd6330a..1d7194d3f31419ed90eb8cd374893deff1d32a23 100644 --- a/tutorenvs/multicolumn.py +++ b/tutorenvs/multicolumn.py @@ -1,6 +1,7 @@ from random import randint from random import choice from pprint import pprint +import logging import cv2 # pytype:disable=import-error import gym @@ -15,6 +16,10 @@ from tutorenvs.utils import OnlineDictVectorizer from tutorenvs.utils import DataShopLogger from tutorenvs.utils import StubLogger +pil_logger = logging.getLogger('PIL') +pil_logger.setLevel(logging.INFO) + + def custom_add(a, b): if a == '': a = '0' @@ -29,8 +34,8 @@ class MultiColumnAdditionSymbolic: Creates a state and sets a random problem. """ if logger is None: - self.logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field']) - # self.logger = StubLogger() + # self.logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field']) + self.logger = StubLogger() else: self.logger = logger self.logger.set_student()