-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpartner.py
82 lines (64 loc) · 2.61 KB
/
partner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from abc import ABC, abstractmethod
from typing import Tuple
import torch as th
import numpy as np
from stable_baselines3 import PPO
class PartnerPolicy(ABC):
    """Abstract interface for the policy a partner acts with.

    Concrete subclasses must implement :meth:`forward`.
    """

    def __init__(self) -> None:
        """Do nothing; the base class keeps no state of its own."""

    @abstractmethod
    def forward(self, obs, deterministic=True) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Compute the policy output for ``obs``.

        Args:
            obs: Observation handed to the policy; its format is defined by
                the concrete subclass.
            deterministic: Flag made available to the concrete policy.

        Returns:
            A 3-tuple of tensors.
            NOTE(review): presumably (action, value, log-prob) following the
            stable-baselines3 convention -- confirm against callers.
        """
class Partner:
    """Simple holder that associates a partner with its policy object."""

    def __init__(self, policy: PartnerPolicy) -> None:
        """Store ``policy`` on the instance.

        Args:
            policy: The policy object; kept unchanged as ``self.policy``.
        """
        self.policy = policy
class PPOPartnerPolicy(PartnerPolicy):
    """Partner policy backed by a PPO model loaded from disk."""

    def __init__(self, model_path):
        """Load the PPO model stored at ``model_path``.

        Args:
            model_path: Location of the saved model; passed straight to
                ``PPO.load``.
        """
        # Zero-argument super() begins the MRO lookup after *this* class, so
        # PartnerPolicy.__init__ runs. super(PartnerPolicy, self) would begin
        # after PartnerPolicy and silently skip it.
        super().__init__()
        self.model = PPO.load(model_path)
        # f-string instead of "%s" % model_path: the %-form misbehaves when
        # model_path is a tuple; the printed text is otherwise the same.
        print(f"PPO Partner loaded successfully: {model_path}")

    def forward(self, obs, deterministic=True):
        """Delegate to the loaded model's policy.

        Args:
            obs: Observation forwarded unchanged to the model's policy.
            deterministic: Forwarded unchanged to the model's policy.

        Returns:
            Whatever ``self.model.policy.forward`` returns.
        """
        # NOTE(review): partner_idx is not an argument of the stock
        # stable-baselines3 policy forward(); presumably the saved model uses
        # a custom policy class that accepts it -- confirm.
        return self.model.policy.forward(obs, partner_idx=0, deterministic=deterministic)
class BlocksPermutationPartnerPolicy(PartnerPolicy):
    """Scripted partner for an n-by-n blocks grid driven by a fixed permutation.

    The partner locates the red block (cell value 2) in the working grid and
    answers with ``perm`` applied to that cell's row-major index; otherwise it
    passes its turn.
    """

    def __init__(self, perm, n=2):
        """Store the permutation and precompute the cell-index table.

        Args:
            perm: Indexable mapping from a row-major cell index to the action
                to emit.
            n: Side length of the square grid.
        """
        # Zero-argument super() so PartnerPolicy.__init__ is not skipped, as
        # it would be with super(PartnerPolicy, self).
        super().__init__()
        self.perm = perm
        self.n = n
        # action_index[i][j] == i * n + j, the row-major index of cell (i, j).
        self.action_index = [[i * n + j for j in range(n)] for i in range(n)]

    def forward(self, obs, deterministic=True):
        """Pick the partner's action for the first observation in ``obs``.

        Only ``obs[0]`` is read. It must have length ``2 * n**2 + 1``: the
        first ``n**2`` entries are the goal grid (not used here), the next
        ``n**2`` entries are the working grid, and the last entry is the turn
        counter. The slice is reshaped with ``.reshape``, so ``obs[0]`` must
        be an array-like that supports it.

        Args:
            obs: Batch of observations; only the first element is used.
            deterministic: Ignored; the policy is always deterministic.

        Returns:
            ``(action, zero, zero)`` as one-element tensors. The action is
            ``n**2 + 1`` (pass) when no red block is found or ``turn >= 2``,
            otherwise ``perm`` at the red block's row-major index.
        """
        obs = obs[0]
        cells = self.n ** 2
        assert 2 * cells + 1 == len(obs)
        working_grid = obs[cells:2 * cells].reshape(self.n, self.n)
        turn = obs[-1]
        r, c = self.get_red_block_position(working_grid, self.n, self.n)
        if r is None or turn >= 2:
            action = cells + 1  # pass turn
        else:
            action = self.perm[self.action_index[r][c]]
        return th.tensor([action]), th.tensor([0.0]), th.tensor([0.0])

    def get_block_position(self, grid, r, c, target):
        """Return the first ``(row, col)`` whose cell equals ``target``.

        Args:
            grid: 2-D indexable grid, accessed as ``grid[i][j]``.
            r: Number of rows to scan.
            c: Number of columns to scan.
            target: Cell value to look for.

        Returns:
            The position of the first match in row-major scan order, or
            ``(None, None)`` if no cell matches.
        """
        for i in range(r):
            for j in range(c):
                if grid[i][j] == target:
                    return i, j
        return None, None

    def get_blue_block_position(self, grid, r, c):
        """Return the position of the blue block (cell value 3), if any."""
        return self.get_block_position(grid, r, c, 3)

    def get_red_block_position(self, grid, r, c):
        """Return the position of the red block (cell value 2), if any."""
        return self.get_block_position(grid, r, c, 2)
class ArmsPartnerPolicy(PartnerPolicy):
    """Scripted partner that maps the observation through a fixed permutation."""

    def __init__(self, perm):
        """Store ``perm`` as a tensor.

        Args:
            perm: Sequence convertible with ``th.tensor``; indexed by the
                observation in :meth:`forward`.
        """
        # Zero-argument super() so PartnerPolicy.__init__ is not skipped, as
        # it would be with super(PartnerPolicy, self).
        super().__init__()
        self.perm = th.tensor(perm)

    def forward(self, obs, deterministic=True):
        """Look up the permuted action and return it duplicated along dim 1.

        Args:
            obs: Index into ``self.perm``.
                NOTE(review): assumes an integer index tensor whose lookup
                result has at least two dimensions, since the result is
                concatenated along ``dim=1`` -- confirm against callers.
            deterministic: Ignored; the policy is always deterministic.

        Returns:
            ``(actions, zero, zero)`` where ``actions`` is the looked-up
            action concatenated with itself along dimension 1 and the other
            two entries are one-element zero tensors.
        """
        action = self.perm[obs]
        return th.cat((action, action), dim=1), th.tensor([0.0]), th.tensor([0.0])
class LowRankPartnerPolicy(PartnerPolicy):
    """Scripted partner that always emits the same action ``n``."""

    def __init__(self, n):
        """Store the constant action.

        Args:
            n: Value returned as the action on every call to :meth:`forward`.
        """
        # Zero-argument super() so PartnerPolicy.__init__ is not skipped, as
        # it would be with super(PartnerPolicy, self).
        super().__init__()
        self.n = n

    def forward(self, obs, deterministic=True):
        """Return the constant action, ignoring the observation.

        Args:
            obs: Ignored.
            deterministic: Ignored.

        Returns:
            ``(tensor([n]), tensor([0.0]), tensor([0.0]))``.
        """
        action = self.n
        return th.tensor([action]), th.tensor([0.0]), th.tensor([0.0])