From 5f9d2516ad2c210f538ff553bc445638d7c57ae1 Mon Sep 17 00:00:00 2001
From: Chris MacLellan <2348-cm3786@users.noreply.gitlab.cci.drexel.edu>
Date: Wed, 30 Sep 2020 21:45:55 -0400
Subject: [PATCH] working example with AL
---
.gitignore | 2 ++
TutorEnvs/fractions.py | 18 +++++++++++---
test_al.py | 54 ++++++++++++++++++++++++++++++++++++++++++
3 files changed, 71 insertions(+), 3 deletions(-)
create mode 100644 test_al.py
diff --git a/.gitignore b/.gitignore
index 03315eb..add424c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+.python-version
+
AUTHORS
ChangeLog
diff --git a/TutorEnvs/fractions.py b/TutorEnvs/fractions.py
index 181a966..382f31c 100644
--- a/TutorEnvs/fractions.py
+++ b/TutorEnvs/fractions.py
@@ -41,12 +41,24 @@ class FractionArithSymbolic:
Returns the current state as a dict.
"""
state_output = {attr:
- {'id': attr, 'value': self.state[attr], 'type': 'TextField',
- 'contentEditable': self.state[attr] == ""}
+ {'id': attr, 'value': self.state[attr],
+ 'type': 'TextField',
+ 'contentEditable': self.state[attr] == "",
+ 'dom_class': 'CTATTable--cell',
+ 'above': '',
+ 'below': '',
+ 'to_left': '',
+ 'to_right': ''
+ }
for attr in self.state}
state_output['done'] = {
'id': 'done',
- 'type': 'Button'
+ 'type': 'Component',
+ 'dom_class': 'CTATDoneButton',
+ 'above': '',
+ 'below': '',
+ 'to_left': '',
+ 'to_right': ''
}
return state_output
diff --git a/test_al.py b/test_al.py
new file mode 100644
index 0000000..824866f
--- /dev/null
+++ b/test_al.py
@@ -0,0 +1,54 @@
+from apprentice.agents.ModularAgent import ModularAgent
+from apprentice.working_memory.representation import Sai
+
+from tutorenvs.fractions import FractionArithSymbolic
+
+
+
+def run_training(agent, n=10):
+    """Drive ``agent`` through the fractions tutor for ``n`` completions.
+
+    Each step asks the agent for a response to the current state; an empty
+    response is replaced by the tutor's ``request_demo()`` result. The SAI is
+    applied to the tutor and the agent is trained on the returned reward.
+    """
+    tutor = FractionArithSymbolic()
+    completed = 0
+
+    while completed < n:
+        state = tutor.get_state()
+        response = agent.request(state)
+
+        if response != {}:
+            selection = response['selection']
+            action = response['action']
+            inputs = response['inputs']
+        else:
+            # No proposal from the agent: print a marker and ask for a demo.
+            print('hint')
+            selection, action, inputs = tutor.request_demo()
+        sai = Sai(selection=selection, action=action, inputs=inputs)
+
+        reward = tutor.apply_sai(sai.selection, sai.action, sai.inputs)
+        print('reward', reward)
+        agent.train(state, sai, reward)
+
+        # Progress counts only when the "done" selection earns reward 1.0.
+        if sai.selection == "done" and reward == 1.0:
+            completed += 1
+
+if __name__ == "__main__":
+    # Agent settings; "Numerator_Multiply"/"Cross_Multiply" stay disabled.
+    agent = ModularAgent(
+        function_set=["RipFloatValue", "Add", "Multiply", "Subtract",
+                      "Divide"],
+        feature_set=["Equals"],
+        planner="numba",
+        search_depth=2,
+        when_learner="trestle",
+        where_learner="FastMostSpecific",
+        state_variablization="whereappend",
+        strip_attrs=["to_left", "to_right", "above", "below", "type", "id",
+                     "offsetParent", "dom_class"],
+        when_args={"cross_rhs_inference": "none"},
+    )
+    run_training(agent, n=100)
--
GitLab