Skip to content
Snippets Groups Projects
Commit db0976b5 authored by Chris MacLellan's avatar Chris MacLellan
Browse files

added code for generating predictions with BKT and DKT

parent fac8d19e
Branches
No related tags found
No related merge requests found
.DS_Store
*.p
*.swp
......
......
with open('dkt_data.csv', 'r') as fin:
p = []
X = []
y = []
for row in fin:
row = row.strip().split(',')
if row[1] == 'correct':
y.append(1)
else:
y.append(0)
X.append({row[-1]: 1})
p.append(row[0])
print(p)
print(X)
print(y)
......@@ -2,32 +2,51 @@ from dkt_torch.model_fitting import predict
if __name__ == "__main__":
test_kcs = [{'M JCommTable6.R0C0': 1},
{'M JCommTable6.R1C0': 1},
{'M done': 1}]
prob = ['AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'M 4_4_times_8_10', 'M 4_4_times_8_10', 'M 4_4_times_8_10', 'M 4_4_times_8_10', 'M 4_4_times_8_10', 'M 1_8_times_4_8', 'M 1_8_times_4_8', 'M 1_8_times_4_8', 'M 8_2_times_1_7', 'M 8_2_times_1_7', 'M 8_2_times_1_7', 'AS 6_5_plus_7_5', 'AS 6_5_plus_7_5', 'AS 6_5_plus_7_5', 'AS 6_5_plus_7_5', 'AS 3_9_plus_9_9', 'AS 3_9_plus_9_9', 'AS 3_9_plus_9_9', 'AS 3_9_plus_9_9', 'AS 1_6_plus_4_6', 'AS 1_6_plus_4_6', 'AS 1_6_plus_4_6', 'M 6_3_times_1_5', 'M 6_3_times_1_5', 'M 6_3_times_1_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2']
test_y = [0, 0, 0]
test_kcs = [{'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD done': 1}, {'AD done': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD done': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD done': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD done': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD done': 1}, {'AD JCommTable6.R1C0': 1}, {'M JCommTable8.R0C0': 1}, {'M JCommTable4.R1C0': 1}, {'M JCommTable6.R0C0': 1}, {'M JCommTable6.R1C0': 1}, {'M done': 1}, {'M JCommTable6.R1C0': 1}, {'M JCommTable6.R0C0': 1}, {'M done': 1}, {'M JCommTable6.R0C0': 1}, {'M JCommTable6.R1C0': 1}, {'M done': 1}, {'AS JCommTable8.R0C0': 1}, {'AS JCommTable6.R0C0': 1}, {'AS JCommTable6.R1C0': 1}, {'AS done': 1}, {'AS JCommTable8.R0C0': 1}, {'AS JCommTable6.R1C0': 1}, {'AS JCommTable6.R0C0': 1}, {'AS done': 1}, {'AS JCommTable6.R1C0': 1}, {'AS JCommTable6.R0C0': 1}, {'AS done': 1}, {'M JCommTable6.R1C0': 1}, {'M JCommTable6.R0C0': 1}, {'M done': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD done': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD done': 1}]
test_y = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1]
# test_kcs = []
# test_y = []
pred = predict('model_params.p', test_kcs, test_y)
from pprint import pprint
print('test 1 all misses')
for p in pred:
pprint(p)
test_kcs = [{'M JCommTable6.R0C0': 1},
{'M JCommTable6.R1C0': 1},
{'M done': 1},
{'M JCommTable6.R0C0': 1},
{'M JCommTable6.R1C0': 1},
{'M done': 1}
]
test_y = [1, 1, 1, 1, 1, 1]
prev = "ZZZZ"
# print(pred[0]['AD JCommTable6.R1C0'])
pred = predict('model_params.p', test_kcs, test_y)
print(len(set(prob)))
from pprint import pprint
print('test 2 all hits')
for p in pred:
pprint(p)
transitions = [pred[0]]
pprint(pred[0])
for i, p in enumerate(pred[1:]):
if prob[i] != prev:
transitions.append(pred[i-1])
if prob[i] != prev and prev[:2] == "AD":
# print(pred[i-1]['AD JCommTable6.R1C0'])
pprint(pred[i-1])
prev = prob[i]
if prev[:2] == "AD":
# print(pred[-1]['AD JCommTable6.R1C0'])
pprint(pred[-1])
transitions.append(pred[-1])
import matplotlib.pyplot as plt
import numpy as np
subset = ['AD JCommTable4.R0C0', 'AD JCommTable4.R1C0', 'AD JCommTable5.R0C0', 'AD JCommTable5.R1C0', 'AD JCommTable6.R0C0', 'AD JCommTable6.R1C0', 'AD JCommTable8.R0C0', 'AD done', 'AS JCommTable6.R0C0', 'AS JCommTable6.R1C0', 'AS done', 'M JCommTable6.R0C0', 'M JCommTable6.R1C0', 'M done']
hm_data = np.array([[p[k] for k in subset] for p in transitions])
plt.imshow(hm_data, cmap='Greens', interpolation='nearest')
plt.show()
from dkt_torch.model_fitting import predict
if __name__ == "__main__":
prob = ['AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 3_6_plus_6_8', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 9_2_plus_4_7', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 5_6_plus_1_4', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 3_3_plus_7_8', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 6_8_plus_1_1', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'AD 5_4_plus_2_7', 'M 4_4_times_8_10', 'M 4_4_times_8_10', 'M 4_4_times_8_10', 'M 4_4_times_8_10', 'M 4_4_times_8_10', 'M 1_8_times_4_8', 'M 1_8_times_4_8', 'M 1_8_times_4_8', 'M 8_2_times_1_7', 'M 8_2_times_1_7', 'M 8_2_times_1_7', 'AS 6_5_plus_7_5', 'AS 6_5_plus_7_5', 'AS 6_5_plus_7_5', 'AS 6_5_plus_7_5', 'AS 3_9_plus_9_9', 'AS 3_9_plus_9_9', 'AS 3_9_plus_9_9', 'AS 3_9_plus_9_9', 'AS 1_6_plus_4_6', 'AS 1_6_plus_4_6', 'AS 1_6_plus_4_6', 'M 6_3_times_1_5', 'M 6_3_times_1_5', 'M 6_3_times_1_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_8_plus_4_5', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AD 9_7_plus_5_2', 'AS 1_5_plus_10_5', 'AS 1_5_plus_10_5', 'AS 1_5_plus_10_5', 'M 8_1_times_5_7', 'M 8_1_times_5_7', 'M 8_1_times_5_7', 'M 6_3_times_8_8', 'M 6_3_times_8_8', 'M 6_3_times_8_8', 'M 2_8_times_1_1', 'M 2_8_times_1_1', 'M 2_8_times_1_1', 'AS 9_4_plus_7_4', 'AS 9_4_plus_7_4', 'AS 9_4_plus_7_4', 'AS 2_10_plus_4_10', 'AS 2_10_plus_4_10', 'AS 2_10_plus_4_10']
test_kcs = [{'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD done': 1}, {'AD done': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD done': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD done': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD done': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD done': 1}, {'AD JCommTable6.R1C0': 1}, {'M JCommTable8.R0C0': 1}, {'M JCommTable4.R1C0': 1}, {'M JCommTable6.R0C0': 1}, {'M JCommTable6.R1C0': 1}, {'M done': 1}, {'M JCommTable6.R1C0': 1}, {'M JCommTable6.R0C0': 1}, {'M done': 1}, {'M JCommTable6.R0C0': 1}, {'M JCommTable6.R1C0': 1}, {'M done': 1}, {'AS JCommTable8.R0C0': 1}, {'AS JCommTable6.R0C0': 1}, {'AS JCommTable6.R1C0': 1}, {'AS done': 1}, {'AS JCommTable8.R0C0': 1}, {'AS JCommTable6.R1C0': 1}, {'AS JCommTable6.R0C0': 1}, {'AS done': 1}, {'AS JCommTable6.R1C0': 1}, {'AS JCommTable6.R0C0': 1}, {'AS done': 1}, {'M JCommTable6.R1C0': 1}, {'M JCommTable6.R0C0': 1}, {'M done': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD done': 1}, {'AD JCommTable4.R1C0': 1}, {'AD JCommTable8.R0C0': 1}, {'AD JCommTable4.R0C0': 1}, {'AD JCommTable5.R1C0': 1}, {'AD JCommTable5.R0C0': 1}, {'AD JCommTable6.R1C0': 1}, {'AD JCommTable6.R0C0': 1}, {'AD done': 1}, {'AS JCommTable6.R0C0': 1}, {'AS JCommTable6.R1C0': 1}, {'AS done': 1}, {'M JCommTable6.R1C0': 1}, {'M JCommTable6.R0C0': 1}, {'M done': 1}, {'M JCommTable6.R1C0': 1}, {'M JCommTable6.R0C0': 1}, {'M done': 1}, {'M JCommTable6.R1C0': 1}, {'M JCommTable6.R0C0': 1}, {'M done': 1}, {'AS JCommTable6.R1C0': 1}, {'AS JCommTable6.R0C0': 1}, {'AS done': 1}, {'AS JCommTable6.R1C0': 1}, {'AS JCommTable6.R0C0': 1}, {'AS done': 1}]
test_y = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
# test_kcs = []
# test_y = []
pred = predict('model_params.p', test_kcs, test_y)
from pprint import pprint
prev = "ZZZZ"
# print(pred[0]['AD JCommTable6.R1C0'])
print(len(set(prob)))
print(set(prob))
transitions = [pred[0]]
p_type = ["I"]
pprint(pred[0])
for i, p in enumerate(pred[1:]):
print(prob[i], test_kcs[i])
if prob[i] != prev and i > 1:
transitions.append(pred[1:][i-1])
p_type.append(prob[i-1].split(" ")[0])
if prob[i] != prev and prev[:2] == "AD":
# print(pred[i-1]['AD JCommTable6.R1C0'])
# print(prev)
pprint(pred[1:][i-1]['AD JCommTable6.R1C0'])
prev = prob[i]
if prev[:2] == "AD":
# print(pred[-1]['AD JCommTable6.R1C0'])
pprint(pred[-1])
# transitions.append(pred[-1])
# p_type.append(prob[-1].split(" ")[0])
import matplotlib.pyplot as plt
import numpy as np
subset = ['AD JCommTable4.R0C0', 'AD JCommTable4.R1C0', 'AD JCommTable5.R0C0', 'AD JCommTable5.R1C0', 'AD JCommTable6.R0C0', 'AD JCommTable6.R1C0', 'AD JCommTable8.R0C0', 'AD done', 'AS JCommTable6.R0C0', 'AS JCommTable6.R1C0', 'AS done', 'M JCommTable6.R0C0', 'M JCommTable6.R1C0', 'M done']
field_mapping = {'AD JCommTable4.R0C0': "AD Left Convert Numerator",
'AD JCommTable4.R1C0': "AD Left Convert Denominator",
'AD JCommTable5.R0C0': "AD Right Convert Numerator",
'AD JCommTable5.R1C0': "AD Right Convert Denominator",
'AD JCommTable6.R0C0': "AD Answer Numerator",
'AD JCommTable6.R1C0': "AD Answer Denominator",
'AD JCommTable8.R0C0': "AD Convert Checkbox",
'AD done': "AD Done",
'AS JCommTable6.R0C0': "AS Answer Numerator",
'AS JCommTable6.R1C0': "AS Answer Denominator",
'AS done': "AS Done",
'M JCommTable6.R0C0': "M Answer Numerator",
'M JCommTable6.R1C0': "M Answer Denominator",
'M done': "M Done",
}
hm_data = np.array([[p[k] for k in subset] for p in transitions]).T
# hm_data = np.array([[p[k] for k in subset] for p in pred]).T
plt.imshow(hm_data, cmap='Greens', interpolation='nearest', vmin=0, vmax=1)
plt.yticks(np.arange(len(subset)), [field_mapping[s] for s in subset])
plt.xticks(np.arange(len(p_type)), p_type)
# plt.xticks(np.arange(len(prob)+1), ["I"] + ["" if p_type == prob[i-1] else p_type.split(" ")[0] for i, p_type in enumerate(prob)])
plt.show()
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment