-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path4*4 Gridworld using Temporal difference method(TD0).py
More file actions
134 lines (123 loc) · 4.22 KB
/
4*4 Gridworld using Temporal difference method(TD0).py
File metadata and controls
134 lines (123 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import numpy as np
class GridWorld:
    """A 4x4 gridworld environment for TD(0) policy evaluation/improvement.

    Layout (S = start, * = penalty cell, O = ordinary cell, T = terminal):
        S O O O
        O O O *
        O * O O
        O * O T
    States are (row, column) tuples; actions are 'U', 'D', 'L', 'R'.
    """
    def __init__(self):
        self.actionSpace = ('U', 'D', 'L', 'R')
        # Legal moves from each non-terminal cell. (3, 3) is deliberately
        # absent, which makes it the terminal state (see is_terminal).
        self.actions = {
            (0, 0): ('D', 'R'),
            (0, 1): ('L', 'D', 'R'),
            (0, 2): ('L', 'D', 'R'),
            (0, 3): ('L', 'D'),
            (1, 0): ('U', 'D', 'R'),
            (1, 1): ('U', 'L', 'D', 'R'),
            (1, 2): ('U', 'L', 'D', 'R'),
            (1, 3): ('U', 'L', 'D'),
            (2, 0): ('U', 'D', 'R'),
            (2, 1): ('U', 'L', 'D', 'R'),
            (2, 2): ('U', 'L', 'D', 'R'),
            (2, 3): ('U', 'L', 'D'),
            (3, 0): ('U', 'R'),
            (3, 1): ('U', 'L', 'R'),
            (3, 2): ('U', 'L', 'R')
        }
        # Cell-specific rewards: goal at (3, 3), penalty cells elsewhere.
        # Every other transition earns the living cost -0.05 (see move).
        self.rewards = {(3, 3): 0.5, (1, 3): -0.3, (2, 1): -0.3, (3, 1): -0.3}
        self.explored = 0   # number of exploratory (random) action choices
        self.exploited = 0  # number of greedy (policy) action choices

    def getRandomPolicy(self):
        """Return a policy mapping every non-terminal state to a random legal action."""
        return {state: np.random.choice(acts) for state, acts in self.actions.items()}

    def reset(self):
        """Return the start state (top-left corner)."""
        return (0, 0)

    def is_terminal(self, s):
        """A state with no legal actions is terminal."""
        return s not in self.actions

    def getNewState(self, state, action):
        """Return the (row, column) reached by taking `action` from `state`.

        Assumes `action` is legal for `state`; no bounds checking is done.
        """
        # Plain tuple unpacking replaces the original zip()/int() round-trip,
        # which produced the same result in a needlessly obscure way.
        row, column = state
        if action == 'U':
            row -= 1
        elif action == 'D':
            row += 1
        elif action == 'L':
            column -= 1
        elif action == 'R':
            column += 1
        return row, column

    def chooseAction(self, state, policy, exploreRate):
        """Epsilon-greedy selection: with probability `exploreRate` pick a
        random legal action, otherwise follow `policy`."""
        if exploreRate > np.random.rand():
            self.explored += 1
            return np.random.choice(self.actions[state])
        self.exploited += 1
        return policy[state]

    def greedyChoose(self, state, values):
        """Return the legal action whose successor state has the highest value.

        Bug fix: the original built a list of candidate values while
        *skipping* successors absent from `values`, then indexed the full
        actions tuple with np.argmax of that (possibly shorter) list —
        misaligning actions and values whenever a successor was skipped.
        Here the best (value, action) pair is tracked together; ties keep
        the earliest action, matching np.argmax's tie-breaking.
        Returns None if no successor has a value estimate.
        """
        best_action = None
        best_value = None
        for act in self.actions[state]:
            nxt = self.getNewState(state, act)
            if nxt not in values:
                continue  # successor has no value estimate yet
            if best_value is None or values[nxt] > best_value:
                best_value = values[nxt]
                best_action = act
        return best_action

    def move(self, state, policy, exploreRate):
        """Take one epsilon-greedy step from `state`.

        Returns (next_state, reward) where reward is the cell-specific
        value from self.rewards or the default living cost of -0.05.
        """
        action = self.chooseAction(state, policy, exploreRate)
        nxt = self.getNewState(state, action)
        return nxt, self.rewards.get(nxt, -0.05)

    def printVaues(self, values):  # (sic) name kept for backward compatibility
        """Print the value table, four entries per row."""
        line = ""
        counter = 0
        for item in values:
            line += f" | {values[item]} | "
            counter += 1
            if counter > 3:  # flush a completed row of four cells
                print(line)
                print("--------------------------------")
                counter = 0
                line = ""
        print(line)
        print("----------------------------")

    def printPolicy(self, policy):
        """Print the policy table, four entries per row."""
        line = ""
        counter = 0
        for item in policy:
            line += f" | {policy[item]} | "
            counter += 1
            if counter > 3:  # flush a completed row of four cells
                print(line)
                print("----------------------------")
                counter = 0
                line = ""
        print(line)
        print("----------------------------")
# TD(0) training loop: evaluate the current policy with one-step temporal
# difference updates, and periodically improve it greedily w.r.t. the values.
enviroment = GridWorld()
policy = enviroment.getRandomPolicy()
# An optimal policy this setup typically converges to:
# {(0, 0): 'R', (0, 1): 'R', (0, 2): 'D', (0, 3): 'D', (1, 0): 'R', (1, 1): 'D',
#  (1, 2): 'D', (1, 3): 'D', (2, 0): 'R', (2, 1): 'D', (2, 2): 'R', (2, 3): 'D',
#  (3, 0): 'R', (3, 1): 'R', (3, 2): 'R'}
values = {state: 0 for state in policy}
values[(3, 3)] = 2  # seed the terminal state with a high value to attract the policy
for episode in range(2001):
    state = enviroment.reset()
    steps = 0
    # Roll out one episode, capped at 20 steps in case the policy loops.
    while not enviroment.is_terminal(state) and steps < 20:
        nextState, reward = enviroment.move(state, policy, exploreRate=0.01)
        # TD(0) update: V(s) <- V(s) + alpha * (r + gamma * V(s') - V(s))
        # with alpha = 0.1 and gamma = 0.9.
        values[state] += 0.1 * (reward + 0.9 * values[nextState] - values[state])
        state = nextState
        steps += 1
    # Policy improvement every 10 episodes: act greedily on current values.
    if episode % 10 == 0:
        for s in policy:
            policy[s] = enviroment.greedyChoose(s, values)
    # Progress report every 200 episodes.
    if episode % 200 == 0:
        print(f"\n\n\n step:{episode}")
        enviroment.printPolicy(policy)
print(f"exploited:{enviroment.exploited} explored:{enviroment.explored}")