diff --git a/sandbox/multicolumn/tune_ppo.py b/sandbox/multicolumn/tune_ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b29c3e2f48ec0ca43d17f774f4c2ffb6962cbc2
--- /dev/null
+++ b/sandbox/multicolumn/tune_ppo.py
@@ -0,0 +1,245 @@
+from typing import Dict
+from typing import Any
+from typing import Union
+from typing import Callable
+import tempfile
+
+import gym
+import optuna
+from torch import nn as nn
+from stable_baselines3 import PPO
+from stable_baselines3.ppo import MlpPolicy
+# from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.vec_env import DummyVecEnv
+# from stable_baselines3.common.vec_env import VecEnv
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.monitor import load_results
+
+import tutorenvs
+
+
+def linear_schedule(
+        initial_value: Union[float, str]) -> Callable[[float], float]:
+    """
+    Linear learning rate schedule.
+
+    :param initial_value: (float or str)
+    :return: (function)
+    """
+    if isinstance(initial_value, str):
+        initial_value = float(initial_value)
+
+    def func(progress_remaining: float) -> float:
+        """
+        Progress will decrease from 1 (beginning) to 0
+        :param progress_remaining: (float)
+        :return: (float)
+        """
+        return progress_remaining * initial_value
+
+    return func
+
+
+def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
+    """
+    Sampler for PPO hyperparams.
+
+    :param trial:
+    :return:
+    """
+    batch_size = trial.suggest_categorical("batch_size",
+                                           [8, 16, 32, 64, 128, 256, 512])
+    n_steps = trial.suggest_categorical(
+        "n_steps", [8, 16, 32, 64, 128, 256, 512, 1024, 2048])
+    gamma = trial.suggest_categorical("gamma", [0.0])
+    # 0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999])
+    learning_rate = trial.suggest_loguniform("lr", 1e-8, 1)
+    # lr_schedule = "constant"
+    # Uncomment to enable learning rate schedule
+    lr_schedule = trial.suggest_categorical('lr_schedule',
+                                            ['linear', 'constant'])
+    ent_coef = trial.suggest_loguniform("ent_coef", 0.00000000001, 0.1)
+    clip_range = trial.suggest_categorical("clip_range",
+                                           [0.05, 0.1, 0.2, 0.3, 0.4])
+    n_epochs = trial.suggest_categorical("n_epochs", [1, 5, 10, 20])
+    gae_lambda = trial.suggest_categorical(
+        "gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
+    max_grad_norm = trial.suggest_categorical(
+        "max_grad_norm", [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 5])
+    vf_coef = trial.suggest_uniform("vf_coef", 0, 1)
+    net_arch = trial.suggest_categorical("net_arch",
+                                         ["tiny", "small", "medium"])
+    shared_arch = trial.suggest_categorical("shared_arch", [True, False])
+    # Uncomment for gSDE (continuous actions)
+    # log_std_init = trial.suggest_uniform("log_std_init", -4, 1)
+    # Uncomment for gSDE (continuous action)
+    # sde_sample_freq = trial.suggest_categorical("sde_sample_freq", [-1, 8,
+    #                                             16, 32, 64, 128, 256])
+    # Orthogonal initialization
+    ortho_init = False
+    # ortho_init = trial.suggest_categorical('ortho_init', [False, True])
+    # activation_fn = trial.suggest_categorical('activation_fn', ['tanh',
+    #                                           'relu', 'elu', 'leaky_relu'])
+    activation_fn = trial.suggest_categorical("activation_fn",
+                                              ["tanh", "relu"])
+
+    # TODO: account when using multiple envs
+    if batch_size > n_steps:
+        batch_size = n_steps
+
+    if lr_schedule == "linear":
+        learning_rate = linear_schedule(learning_rate)
+
+    # Independent networks usually work best
+    # when not working with images
+    net_arch = {
+        True: {
+            "tiny": [32, dict(pi=[32], vf=[32])],
+            "small": [64, dict(pi=[64], vf=[64])],
+            "medium": [128, dict(pi=[128], vf=[128])],
+        },
+        False: {
+            "tiny": [dict(pi=[32, 32], vf=[32, 32])],
+            "small": [dict(pi=[64, 64], vf=[64, 64])],
+            "medium": [dict(pi=[128, 128], vf=[128, 128])],
+        }
+    }[shared_arch][net_arch]
+
+    activation_fn = {
+        "tanh": nn.Tanh,
+        "relu": nn.ReLU,
+        "elu": nn.ELU,
+        "leaky_relu": nn.LeakyReLU
+    }[activation_fn]
+
+    return {
+        "n_steps":
+        n_steps,
+        "batch_size":
+        batch_size,
+        "gamma":
+        gamma,
+        "learning_rate":
+        learning_rate,
+        "ent_coef":
+        ent_coef,
+        "clip_range":
+        clip_range,
+        "n_epochs":
+        n_epochs,
+        "gae_lambda":
+        gae_lambda,
+        "max_grad_norm":
+        max_grad_norm,
+        "vf_coef":
+        vf_coef,
+        # "sde_sample_freq": sde_sample_freq,
+        "policy_kwargs":
+        dict(
+            # log_std_init=log_std_init,
+            net_arch=net_arch,
+            activation_fn=activation_fn,
+            ortho_init=ortho_init,
+        ),
+    }
+
+
+class TrialCallback(BaseCallback):
+    """
+    Callback used for evaluating and reporting a trial.
+    """
+    def __init__(
+        self,
+        trial: optuna.Trial,
+        log_dir: str,
+        n_eval_episodes: int = 10,
+        eval_freq: int = 10000,
+        min_eval: float = -3000,
+        verbose: int = 0,
+    ):
+        super(TrialCallback, self).__init__(verbose)
+
+        self.eval_freq = eval_freq
+        self.n_eval_episodes = n_eval_episodes
+        self.log_dir = log_dir
+        self.trial = trial
+        self.eval_idx = 0
+        self.is_pruned = False
+        self.min_eval = min_eval
+
+    def _on_step(self) -> bool:
+        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
+            results = load_results(self.log_dir)
+            if len(results) < self.n_eval_episodes:
+                return True
+            avg_last_n = results['r'][-self.n_eval_episodes:].mean()
+            self.eval_idx += 1
+            # report best or report current ?
+            # report num_timesteps or elapsed time ?
+            self.trial.report(avg_last_n, self.eval_idx)
+            # print('Idx:', self.eval_idx, 'Avg_last_n', avg_last_n)
+
+            # Prune trial if needed
+            if avg_last_n < self.min_eval or self.trial.should_prune():
+                self.is_pruned = True
+                return False
+
+        return True
+
+
+def objective(trial: optuna.Trial) -> float:
+    n_eval_episodes = 10
+    eval_freq = 5000
+    n_steps = 100000
+
+    with tempfile.TemporaryDirectory() as log_dir:
+        env = DummyVecEnv([
+            lambda: Monitor(gym.make('MulticolumnArithSymbolic-v0'), log_dir)])
+
+        ppo_args = sample_ppo_params(trial)
+
+        model = PPO(MlpPolicy, env,
+                    tensorboard_log="./tensorboard_ppo_multi/",
+                    **ppo_args)
+        # gamma=0.1,
+        # tensorboard_log="./tensorboard/v0/")
+        callback = TrialCallback(trial, log_dir, verbose=1,
+                                 n_eval_episodes=n_eval_episodes,
+                                 eval_freq=eval_freq)
+
+        try:
+            model.learn(total_timesteps=n_steps, callback=callback)
+            model.env.close()
+        except Exception as e:
+            model.env.close()
+            print(e)
+            raise optuna.exceptions.TrialPruned()
+
+        is_pruned = callback.is_pruned
+        del model.env
+        del model
+
+        if is_pruned:
+            raise optuna.exceptions.TrialPruned()
+
+        results = load_results(log_dir)
+        avg_last_n = results['r'][-n_eval_episodes:].mean()
+        # print('Final avg_last_n:', avg_last_n)
+        return avg_last_n
+
+
+if __name__ == "__main__":
+
+    # multiprocess environment
+    # env = make_vec_env('MulticolumnArithSymbolic-v0', n_envs=1)
+
+    study = optuna.create_study(pruner=optuna.pruners.MedianPruner(
+        n_warmup_steps=20000), direction="maximize")
+    study.optimize(objective, n_trials=1000, n_jobs=1)
+
+    print("BEST")
+    print(study.best_params)
+
+    # while True:
+    #     Train
diff --git a/tutorenvs/multicolumn.py b/tutorenvs/multicolumn.py
index 1d34d3966b767a3769788808e09567a657c0473b..940fc0e08ffffe1e2d4a58a81665528ce7ecd984 100644
--- a/tutorenvs/multicolumn.py
+++ b/tutorenvs/multicolumn.py
@@ -29,7 +29,6 @@ class MultiColumnAdditionSymbolic:
         Creates a state and sets a random problem.
         """
         if logger is None:
-            print("CREATING LOGGER")
             # self.logger = DataShopLogger('MulticolumnAdditionTutor', extra_kcs=['field'])
             self.logger = StubLogger()
         else:
diff --git a/tutorenvs/utils.py b/tutorenvs/utils.py
index 478a8d63451bc424fd14d671f8e87cb4cd27fc4a..d2465feaa67b1410b29989dff7602a47b0f1912f 100644
--- a/tutorenvs/utils.py
+++ b/tutorenvs/utils.py
@@ -2,17 +2,19 @@ import os
 import time
 import uuid
 from datetime import datetime
-from pprint import pprint
+import logging
 
 import gym
-from gym import error, spaces, utils
-from sklearn.feature_extraction import DictVectorizer
+from gym import spaces
 import numpy as np
 
+logging.basicConfig(level=logging.DEBUG)
+log = logging.getLogger(__name__)
 
-class StubLogger():
 
+class StubLogger():
     def __init__(self):
+        log.info("StubLogger Created")
         pass
 
     def set_student(self, student_id=None):
@@ -24,36 +26,32 @@ class StubLogger():
     def log_hint(self, feedback_text="", step_name=None, kcs=None):
         pass
 
-    def log_step(self, selection="", action="", inp="", outcome="", step_name=None, kcs=None):
+    def log_step(self,
+                 selection="",
+                 action="",
+                 inp="",
+                 outcome="",
+                 step_name=None,
+                 kcs=None):
         pass
 
 
 class DataShopLogger():
-
-    def __init__(self, domain = "tutorenv", extra_kcs=None):
+    def __init__(self, domain="tutorenv", extra_kcs=None):
+        log.info("DataShop Logger Created")
         # Create log file
         if not os.path.exists("log/"):
             os.mkdir("log/")
-        self.filename = "log/" + domain + "_" + time.strftime("%Y-%m-%d-%H-%M-%s") + ".txt"
-
-        headers = ['Anon Student Id',
-                   'Session Id',
-                   'Transaction Id',
-                   'Time',
-                   'Time Zone',
-                   'Student Response Type',
-                   'Tutor Response Type',
-                   'Level (Domain)',
-                   'Problem Name',
-                   'Problem Start Time',
-                   'Step Name',
-                   'Selection',
-                   'Action',
-                   'Input',
-                   'Feedback Text',
-                   'Outcome',
-                   'CF (Problem Context)',
-                   'KC (Single-KC)']
+        self.filename = "log/" + domain + "_" + time.strftime(
+            "%Y-%m-%d-%H-%M-%s") + ".txt"
+
+        headers = [
+            'Anon Student Id', 'Session Id', 'Transaction Id', 'Time',
+            'Time Zone', 'Student Response Type', 'Tutor Response Type',
+            'Level (Domain)', 'Problem Name', 'Problem Start Time',
+            'Step Name', 'Selection', 'Action', 'Input', 'Feedback Text',
+            'Outcome', 'CF (Problem Context)', 'KC (Single-KC)'
+        ]
 
         if extra_kcs is not None:
             for kc in extra_kcs:
@@ -80,7 +78,8 @@ class DataShopLogger():
             problem_name = uuid.uuid4()
         self.problem_name = problem_name
         self.time += 1
-        self.problem_start = datetime.fromtimestamp(self.time).strftime('%m/%d/%Y %H:%M:%S')
+        self.problem_start = datetime.fromtimestamp(
+            self.time).strftime('%m/%d/%Y %H:%M:%S')
         self.step_count = 1
 
     def log_hint(self, feedback_text, step_name=None, kcs=None):
@@ -103,25 +102,27 @@ class DataShopLogger():
         if step_name is None:
             step_name = self.step_count
 
-        datum = [self.student_id,
-                 self.session_id,
-                 transaction_id,
-                 time,
-                 self.timezone,
-                 student_response,
-                 tutor_response,
-                 self.level_domain,
-                 self.problem_name,
-                 self.problem_start,
-                 #self.step_count,
-                 step_name,
-                 selection,
-                 action,
-                 inp,
-                 feedback_text,
-                 outcome,
-                 "",
-                 "Single-KC"]
+        datum = [
+            self.student_id,
+            self.session_id,
+            transaction_id,
+            time,
+            self.timezone,
+            student_response,
+            tutor_response,
+            self.level_domain,
+            self.problem_name,
+            self.problem_start,
+            # self.step_count,
+            step_name,
+            selection,
+            action,
+            inp,
+            feedback_text,
+            outcome,
+            "",
+            "Single-KC"
+        ]
 
         if kcs is not None:
             for kc in kcs:
@@ -130,7 +131,13 @@ class DataShopLogger():
         with open(self.filename, 'a+') as fout:
             fout.write("\t".join(str(v) for v in datum) + "\n")
 
-    def log_step(self, selection, action, inp, outcome, step_name=None, kcs=None):
+    def log_step(self,
+                 selection,
+                 action,
+                 inp,
+                 outcome,
+                 step_name=None,
+                 kcs=None):
         if self.student_id is None:
             raise Exception("No student ID")
         if self.problem_name is None:
@@ -147,24 +154,12 @@ class DataShopLogger():
         if step_name is None:
             step_name = self.step_count
 
-        datum = [self.student_id,
-                 self.session_id,
-                 transaction_id,
-                 time,
-                 self.timezone,
-                 student_response,
-                 tutor_response,
-                 self.level_domain,
-                 self.problem_name,
-                 self.problem_start,
-                 step_name,
-                 selection,
-                 action,
-                 inp,
-                 feedback_text,
-                 outcome,
-                 "",
-                 "Single-KC"]
+        datum = [
+            self.student_id, self.session_id, transaction_id, time,
+            self.timezone, student_response, tutor_response, self.level_domain,
+            self.problem_name, self.problem_start, step_name, selection,
+            action, inp, feedback_text, outcome, "", "Single-KC"
+        ]
 
         if kcs is not None:
             for kc in kcs:
@@ -173,8 +168,8 @@ class DataShopLogger():
         with open(self.filename, 'a+') as fout:
             fout.write("\t".join(str(v) for v in datum) + "\n")
 
-class MultiDiscreteToDiscreteWrapper(gym.ActionWrapper):
 
+class MultiDiscreteToDiscreteWrapper(gym.ActionWrapper):
     def __init__(self, env):
         super().__init__(env)
         assert isinstance(env.action_space, gym.spaces.MultiDiscrete), \
@@ -198,12 +193,12 @@ class MultiDiscreteToDiscreteWrapper(gym.ActionWrapper):
     def action(self, discrete_act):
         act = np.zeros_like(self.action_vec)
         for i in range(len(self.action_vec)):
-            act[i] = discrete_act // np.prod(self.action_vec[i+1:])
-            discrete_act = discrete_act % np.prod(self.action_vec[i+1:])
+            act[i] = discrete_act // np.prod(self.action_vec[i + 1:])
+            discrete_act = discrete_act % np.prod(self.action_vec[i + 1:])
         return act
 
-class OnlineDictVectorizer():
 
+class OnlineDictVectorizer():
     def __init__(self, n_features):
         self.n_features = n_features
         self.separator = '='
@@ -217,7 +212,7 @@ class OnlineDictVectorizer():
         """
         Given a set of X, it updates the key with any new values.
""" - + for x in X: for f, v in x.items(): if isinstance(v, str): @@ -299,18 +294,15 @@ class BaseOppEnv(gym.Env): self.dv = OnlineDictVectorizer(n_features=n_features) self.observation_space = spaces.Box(low=0.0, - high=1.0, shape=(1, n_features), dtype=np.float32) - self.action_space = spaces.MultiDiscrete([n_selections, n_operators, - n_args, n_args]) + high=1.0, + shape=(1, n_features), + dtype=np.float32) + self.action_space = spaces.MultiDiscrete( + [n_selections, n_operators, n_args, n_args]) def get_rl_operators(self): - return [ - ('copy', 1), - ('add', 2), - ('multiply', 2), - ('mod10', 1), - ('div10', 1) - ] + return [('copy', 1), ('add', 2), ('multiply', 2), ('mod10', 1), + ('div10', 1)] def get_rl_state(self): # self.state = { @@ -358,7 +350,8 @@ class BaseOppEnv(gym.Env): # greater than 9 try: - new_relations['greater_than_9(%s)' % str(attr)] = float(attr_val) > 9 + new_relations['greater_than_9(%s)' % + str(attr)] = float(attr_val) > 9 except Exception: new_relations['greater_than_9(%s)' % str(attr)] = False @@ -377,7 +370,7 @@ class BaseOppEnv(gym.Env): s, a, i = self.decode(action) print(s, a, i) - + if isinstance(s, tuple): if s in self.internal_memory or i == '': reward = -1 @@ -394,7 +387,7 @@ class BaseOppEnv(gym.Env): # print(s, a, i) # print() # print(reward) - + state = self.get_rl_state() # pprint(state) obs = self.dv.fit_transform([state])[0] @@ -402,27 +395,26 @@ class BaseOppEnv(gym.Env): return obs, reward, done, info - def apply_rl_op(self, op, arg1, arg2): a1 = None a2 = None if arg1 in self.tutor.state: - a1 = self.tutor.state[arg1] + a1 = self.tutor.state[arg1] elif arg1 in self.internal_memory: a1 = self.internal_memory[arg1] else: raise ValueError('Element not in memory') if arg2 in self.tutor.state: - a2 = self.tutor.state[arg2] + a2 = self.tutor.state[arg2] elif arg2 in self.internal_memory: a2 = self.internal_memory[arg2] else: raise ValueError('Element not in memory') if op == "copy": - return a1 + return a1 elif op == "add": return str(int(a1) + int(a2)) elif op == "multiply": @@ -453,7 +445,7 @@ class BaseOppEnv(gym.Env): a = "ButtonPressed" else: a = "UpdateField" - + if s == "done": v = -1 if s == "check_convert":