Zahra Rajabi
pymdptoolbox
Commits
50719db4
Commit
50719db4
authored
Mar 13, 2014
by
Steven Cordwell
Browse files
rewrite tictactoe.py
parent
ce04771e
No files found.
src/examples/tictactoe.py
View file @
50719db4
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
#import mdp
import
numpy
as
np
from
scipy.sparse
import
dok_matrix
as
spdok
def
str_base
(
num
,
base
,
numerals
=
'0123456789abcdefghijklmnopqrstuvwxyz'
):
from
mdptoolbox
import
mdp
if
base
<
2
or
base
>
len
(
numerals
):
raise
ValueError
(
"str_base: base must be between 2 and %i"
%
ACTIONS
=
9
len
(
numerals
))
# Every one of the ACTIONS cells holds one of three values, so there are
# 3 ** ACTIONS distinct board encodings.
STATES = 3 ** ACTIONS

# Values stored in a board cell for each side (0 marks an empty cell).
PLAYER = 1
OPPONENT = 2

# Winning lines, one 0/1 mask per line over the nine cells. Reading the
# cells as a 3x3 grid row by row these are the three rows, the three
# columns and the two diagonals.
WINS = (
    [1, 1, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 1, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 1, 1],
    [1, 0, 0, 1, 0, 0, 1, 0, 0],
    [0, 1, 0, 0, 1, 0, 0, 1, 0],
    [0, 0, 1, 0, 0, 1, 0, 0, 1],
    [1, 0, 0, 0, 1, 0, 0, 0, 1],
    [0, 0, 1, 0, 1, 0, 1, 0, 0],
)

# The valid number of cells belonging to either the player or the opponent:
# (player, opponent)
OWNED_CELLS = (
    (0, 0),
    (1, 1),
    (2, 2),
    (3, 3),
    (4, 4),
    (0, 1),
    (1, 2),
    (2, 3),
    (3, 4),
)
def convertIndexToTuple(state):
    """Convert a state index into a tuple of nine base-3 digits.

    Parameters
    ----------
    state : int
        Non-negative state index.

    Returns
    -------
    tuple of int
        The nine least-significant base-3 digits of ``state``, most
        significant digit first.
    """
    # The padding guarantees at least nine characters; the negative slice
    # then keeps exactly the last nine, whatever the width of the unpadded
    # representation.
    digits = np.base_repr(state, 3, 9)[-9:]
    return tuple(int(x) for x in digits)
def convertTupleToIndex(state):
    """Convert a sequence of base-3 digits into its integer state index.

    Parameters
    ----------
    state : sequence of int
        Digits, most significant first.

    Returns
    -------
    int
        The value of the digit string interpreted in base 3.
    """
    digits = "".join(map(str, state))
    return int(digits, 3)
def getLegalActions(state):
    """Return the indices of all cells of ``state`` that hold 0.

    Parameters
    ----------
    state : sequence of int
        Board indexable by cell number in ``range(ACTIONS)``.

    Returns
    -------
    tuple of int
        Cell indices, in increasing order, whose value is 0.
    """
    legal = []
    for cell in range(ACTIONS):
        if state[cell] == 0:
            legal.append(cell)
    return tuple(legal)
def getTransitionAndRewardArrays():
    """Build the transition matrices and reward matrix for every state/action.

    Returns
    -------
    (P, R) : tuple
        ``P`` is a list with one (STATES, STATES) sparse matrix per action,
        each converted to CSR. ``R`` is a (STATES, ACTIONS) sparse matrix
        converted to CSC.
    """
    P = [spdok((STATES, STATES)) for a in range(ACTIONS)]
    R = spdok((STATES, ACTIONS))
    # Naive approach, iterate through all possible combinations
    for a in range(ACTIONS):
        transitions = P[a]
        for s in range(STATES):
            state = convertIndexToTuple(s)
            if isValid(state):
                s1, p, r = getTransitionProbabilities(state, a)
                transitions[s, s1] = p
                R[s, a] = r
            else:
                # There are no defined moves from an invalid state, so
                # transition probabilities cannot be calculated. However,
                # P must be a square stochastic matrix, so an invalid state
                # transitions back to itself with probability one. Its
                # reward entry is left at 0.
                transitions[s, s] = 1
        # The matrix for this action is complete; store it in CSR form.
        P[a] = transitions.tocsr()
    R = R.tocsc()
    return (P, R)
def
getTransitionProbabilities
(
state
,
action
):
"""
Parameters

state : tuple
The state
action : int
The action
if
num
==
0
:
Returns
return
'0'

s1, p, r : tuple of two lists and an int
s1 are the next states, p are the probabilities, and r is the reward
if
num
<
0
:
"""
sign
=
''
#assert isValid(state)
num
=

num
assert
0
<=
action
<
ACTIONS
if
not
isLegal
(
state
,
action
):
# If the action is illegal, then transition back to the same state but
# incur a high negative reward
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],

10
)
# Update the state with the action
state
=
list
(
state
)
state
[
action
]
=
PLAYER
if
isWon
(
state
,
PLAYER
):
# If the player's action is a winning move then transition to the
# winning state and receive a reward of 1.
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],
1
)
elif
isDraw
(
state
):
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],
0
)
# Now we search through the opponent's moves, and calculate transition
# probabilities based on maximising the opponent's chance of winning.
s1
=
[]
p
=
[]
legal_a
=
getLegalActions
(
state
)
for
a
in
legal_a
:
state
[
a
]
=
OPPONENT
# If the opponent is going to win, we assume that the winning move will
# be chosen:
if
isWon
(
state
,
OPPONENT
):
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],

1
)
elif
isDraw
(
state
):
s1
=
[
convertTupleToIndex
(
state
)]
return
(
s1
,
[
1
],
0
)
# Otherwise we assume the opponent will select a move with uniform
# probability across potential moves:
s1
.
append
(
convertTupleToIndex
(
state
))
p
.
append
(
1.0
/
len
(
legal_a
))
state
[
a
]
=
0
# During nonterminal play states the reward is 0.
return
(
s1
,
p
,
0
)
def
getReward
(
state
,
action
):
""""""
if
not
isLegal
(
state
,
action
):
return

100
state
=
list
(
state
)
state
[
action
]
=
PLAYER
if
isWon
(
state
,
PLAYER
):
return
1
elif
isWon
(
state
,
OPPONENT
):
return

1
else
:
else
:
sign
=
''
return
0
result
=
''
while
num
:
result
=
numerals
[
num
%
(
base
)]
+
result
num
//=
base
return
sign
+
result
def isDraw(state):
    """Report whether ``state`` contains no 0 cell.

    Parameters
    ----------
    state : list or tuple
        Board; must provide an ``index`` method.

    Returns
    -------
    bool
        True when no element equals 0, False otherwise.
    """
    try:
        state.index(0)
    except ValueError:
        # index() found no 0, so every cell is taken.
        return True
    return False
class
TicTacToeMDP
(
object
):
def
isLegal
(
state
,
action
):
""""""
""""""
if
state
[
action
]
==
0
:
def
__init__
(
self
):
return
True
""""""
else
:
self
.
P
=
[
None
]
*
9
for
a
in
xrange
(
9
):
self
.
P
[
a
]
=
{}
self
.
R
=
{}
# some board states are equal, just rotations of other states
self
.
rotorder
=
[]
#self.rotorder.append([0, 1, 2, 3, 4, 5, 6, 7, 8])
self
.
rotorder
.
append
([
6
,
3
,
0
,
7
,
4
,
1
,
8
,
5
,
2
])
self
.
rotorder
.
append
([
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
])
self
.
rotorder
.
append
([
2
,
5
,
8
,
1
,
4
,
7
,
0
,
3
,
6
])
# The valid number of cells belonging to either the player or the
# opponent: (player, opponent)
self
.
nXO
=
((
0
,
0
),
(
1
,
1
),
(
2
,
2
),
(
3
,
3
),
(
4
,
4
),
(
0
,
1
),
(
1
,
2
),
(
2
,
3
),
(
3
,
4
))
# The winning positions
self
.
wins
=
([
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
1
,
1
,
1
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
],
[
1
,
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
],
[
0
,
1
,
0
,
0
,
1
,
0
,
0
,
1
,
0
],
[
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
,
1
],
[
1
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
1
],
[
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
])
def rotate(self, state):
    """Return the smallest base-3 index among a board and its rotations.

    Parameters
    ----------
    state : sequence of int
        Nine cell values, each a single base-3 digit.

    Returns
    -------
    int
        The minimum, over the unrotated board and the boards obtained by
        applying the first three permutations in ``self.rotorder``, of the
        board read as a base-3 number.
    """
    def encode(cells):
        # Read the cell values as the digits of a base-3 number.
        return int("".join(str(x) for x in cells), 3)

    identity = [encode(state)]
    for k in range(3):
        order = self.rotorder[k]
        # Cell kk of the rotated board is cell order[kk] of the original.
        identity.append(encode(state[order[kk]] for kk in range(9)))
    # The smallest index serves as the canonical identity of the board.
    return min(identity)
def unrotate(self, move, rotation):
    """Map a move on a rotated board back to the unrotated board.

    Parameters
    ----------
    move : int
        Cell index on the rotated board.
    rotation : int
        Which rotation was applied. 0 is taken to mean "no rotation" and
        ``k >= 1`` to mean ``self.rotorder[k - 1]``.
        NOTE(review): this numbering assumes the unrotated board comes
        first, followed by the ``rotorder`` entries in order -- confirm
        against callers.

    Returns
    -------
    int
        The corresponding cell index on the unrotated board.
    """
    if rotation == 0:
        # The board was not rotated, so the move is unchanged.
        return move
    # Previously the argument was overwritten with a constant, so every call
    # used the same permutation regardless of the rotation requested.
    return self.rotorder[rotation - 1][move]
def isLegal(self, state, action):
    """Report whether cell ``action`` of ``state`` holds 0.

    Parameters
    ----------
    state : sequence of int
        Board cell values.
    action : int
        Index of the cell to test.

    Returns
    -------
    bool
        True if ``state[action] == 0``, otherwise False.
    """
    return bool(state[action] == 0)
def isWon(self, state, who):
    """Report whether ``who`` occupies all three cells of any winning line.

    Parameters
    ----------
    state : sequence of int
        Nine cell values.
    who : int
        The cell value to look for.

    Returns
    -------
    bool
        True if, for some mask in ``self.wins``, three cells flagged 1 in
        the mask hold ``who``; False otherwise.
    """
    # Check to see if there are any wins
    for w in self.wins:
        owned = sum(1 for k in range(9) if w[k] == 1 and state[k] == who)
        if owned == 3:
            # We have a win
            return True
    # There were no wins so return False
    return False
return
False
def
isWon
(
state
,
who
):
"""Test if a tictactoe game has been won.
def
isDraw
(
self
,
state
):
Assumes that the board is in a legal state.
""""""
Will test if the value 1 is in any winning combination.
try
:
state
.
index
(
0
)
return
False
except
ValueError
:
return
True
except
:
raise
def
isValid
(
self
,
state
):
"""
""""""
for
w
in
WINS
:
# S1 is the sum of the player's cells
S
=
sum
(
1
if
(
w
[
k
]
==
state
[
k
]
==
who
)
else
0
S1
=
sum
(
1
if
x
==
1
else
0
for
x
in
state
)
for
k
in
range
(
ACTIONS
))
# S2 is the sum of the opponent's cells
if
S
==
3
:
S2
=
sum
(
1
if
x
==
2
else
0
for
x
in
state
)
# We have a win
if
(
S1
,
S2
)
in
self
.
nXO
:
return
True
return
True
else
:
# There were no wins so return False
return
False
return
False
def
getReward
(
self
,
s
):
def
isValid
(
state
):
if
self
.
isWon
(
s
,
1
):
""""""
return
1
# S1 is the sum of the player's cells
elif
self
.
isWon
(
s
,
2
):
S1
=
sum
(
1
if
x
==
PLAYER
else
0
for
x
in
state
)
return

1
# S2 is the sum of the opponent's cells
else
:
S2
=
sum
(
1
if
x
==
OPPONENT
else
0
for
x
in
state
)
return
0
if
(
S1
,
S2
)
in
OWNED_CELLS
:
return
True
def
run
(
self
):
else
:
""""""
return
False
l
=
(
0
,
1
,
2
)
# Iterate through a generator of all the combinations
for
s
in
((
a0
,
a1
,
a2
,
a3
,
a4
,
a5
,
a6
,
a7
,
a8
)
for
a0
in
l
for
a1
in
l
for
a2
in
l
for
a3
in
l
for
a4
in
l
for
a5
in
l
for
a6
in
l
for
a7
in
l
for
a8
in
l
):
if
self
.
isValid
(
s
):
s_idn
=
self
.
rotate
(
s
)
if
not
self
.
R
.
has_key
(
s_idn
):
self
.
R
[
s_idn
]
=
self
.
getReward
(
s
)
self
.
transition
(
s
)
# Convert P and R to ijv lists
# Iterate up to the theoretical maximum value of s
for
s
in
xrange
(
int
(
'222211110'
,
3
)):
print
s
# return (P, R)
def toTuple(self, state):
    """Convert a state index into a tuple of base-3 digits.

    Parameters
    ----------
    state : int
        State index, passed to ``str_base`` for conversion to base 3.

    Returns
    -------
    tuple of int
        The base-3 digits, left-padded with zeros to at least nine entries.
    """
    digits = str_base(state, 3)
    # Left-pad with '0' up to nine characters; longer strings are unchanged.
    digits = digits.rjust(9, '0')
    return tuple(int(x) for x in digits)
def transition(self, state):
    """Record the boards reachable from ``state`` after one move by each side.

    For every empty cell ``a`` the value 1 is placed there; then, for every
    cell ``m`` still empty, the value 2 is placed there. Each resulting
    board is reduced to an index with ``self.rotate`` and, if not already
    present, ``self.P[a][(index of state, index of result)]`` is set to the
    number of replies that were available.

    Parameters
    ----------
    state : sequence of int
        Nine cell values; 0 marks an empty cell.

    Returns
    -------
    None
        The dictionaries in ``self.P`` are updated in place.
    """
    #TODO: the state needs to be rotated before anything else is done!!!
    idn_s = int("".join(str(x) for x in state), 3)
    legal_a = [x for x in range(9) if state[x] == 0]
    for a in legal_a:
        s = list(state)
        s[a] = 1
        # NOTE(review): a win by this move was previously computed here but
        # never used, so replies are still enumerated after a winning move.
        legal_m = [x for x in range(9) if s[x] == 0]
        for m in legal_m:
            s_new = list(s)
            s_new[m] = 2
            idn_s_new = self.rotate(s_new)
            # Keep the first value stored for a given (from, to) pair.
            self.P[a].setdefault((idn_s, idn_s_new), len(legal_m))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
P
,
R
=
TicTacToeMDP
().
run
()
P
,
R
=
getTransitionAndRewardArrays
()
#ttt = mdp.ValueIteration(P, R, 1)
ttt
=
mdp
.
ValueIteration
(
P
,
R
,
1
)
print
(
ttt
.
policy
)
