diff --git a/.gitignore b/.gitignore
index 03315ebdff262413233c8e2b7b1dd46883d5f9bd..add424ca20cbd1c8c5be413519ef91718c72f3d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+.python-version
+
 AUTHORS
 ChangeLog
 
diff --git a/TutorEnvs/fractions.py b/TutorEnvs/fractions.py
index 181a9666f36c2d3c5a7f127740c6cddeae02fa12..382f31cb05139f8c29535d238abbe61301fef165 100644
--- a/TutorEnvs/fractions.py
+++ b/TutorEnvs/fractions.py
@@ -41,12 +41,24 @@ class FractionArithSymbolic:
         Returns the current state as a dict.
         """
         state_output = {attr:
-                        {'id': attr, 'value': self.state[attr], 'type': 'TextField',
-                         'contentEditable': self.state[attr] == ""}
+                        {'id': attr, 'value': self.state[attr],
+                         'type': 'TextField',
+                         'contentEditable': self.state[attr] == "",
+                         'dom_class': 'CTATTable--cell',
+                         'above': '',
+                         'below': '',
+                         'to_left': '',
+                         'to_right': ''
+                         }
                         for attr in self.state}
         state_output['done'] = {
             'id': 'done',
-            'type': 'Button'
+            'type': 'Component',
+            'dom_class': 'CTATDoneButton',
+            'above': '',
+            'below': '',
+            'to_left': '',
+            'to_right': ''
         }
         return state_output
 
diff --git a/test_al.py b/test_al.py
new file mode 100644
index 0000000000000000000000000000000000000000..824866f68aaff243b84574303acec26fc56ed4b2
--- /dev/null
+++ b/test_al.py
@@ -0,0 +1,54 @@
+from apprentice.agents.ModularAgent import ModularAgent
+from apprentice.working_memory.representation import Sai
+
+from TutorEnvs.fractions import FractionArithSymbolic
+
+
+
+def run_training(agent, n=10):
+
+    env = FractionArithSymbolic()
+
+    p = 0
+
+    while p < n:
+
+        state = env.get_state()
+        response = agent.request(state)
+
+        if response == {}:
+            print('hint')
+            selection, action, inputs = env.request_demo()
+            sai = Sai(selection=selection,
+                      action=action,
+                      inputs=inputs)
+
+        else:
+            sai = Sai(selection=response['selection'],
+                      action=response['action'],
+                      inputs=response['inputs'])
+
+        reward = env.apply_sai(sai.selection, sai.action, sai.inputs)
+        print('reward', reward)
+
+        agent.train(state, sai, reward)
+
+        if sai.selection == "done" and reward == 1.0:
+            p += 1
+
+if __name__ == "__main__":
+    args = {"function_set" : ["RipFloatValue","Add",
+                              'Multiply',
+                              "Subtract",
+                              # "Numerator_Multiply", "Cross_Multiply",
+                              "Divide"],
+
+            "feature_set" : ["Equals"], "planner" : "numba", "search_depth" : 2,
+            "when_learner": "trestle", "where_learner": "FastMostSpecific",
+            "state_variablization" : "whereappend", "strip_attrs" :
+            ["to_left","to_right","above","below","type","id","offsetParent","dom_class"],
+            "when_args" : { "cross_rhs_inference" : "none" } }
+
+    agent = ModularAgent(**args)
+
+    run_training(agent, n = 100)