
Dev2 #143 (open): wants to merge 20 commits into main
Update Comment For Net
BOBSTK committed Dec 22, 2020

commit 6efcd00c285d3d85adc73b4b2b747dde055fc681
3 changes: 2 additions & 1 deletion .vscode/settings.json
100644 → 100755
@@ -1,5 +1,6 @@
{
"files.associations": {
"*.ipp": "cpp"
}
},
"python.pythonPath": "/usr/local/Caskroom/miniconda/base/envs/moRL/bin/python"
}
21 changes: 16 additions & 5 deletions elf/utils_elf.py
100644 → 100755
@@ -223,20 +223,22 @@ def __init__(self, GC, co, descriptions, use_numpy=False, gpu=None, params=dict(
gpu(int): gpu to use.
params(dict): additional parameters
'''


#self.isPrint = False
self._init_collectors(GC, co, descriptions, use_gpu=gpu is not None, use_numpy=use_numpy)
self.gpu = gpu
self.inputs_gpu = [ self.inputs[gids[0]].cpu2gpu(gpu=gpu) for gids in self.gpu2gid ] if gpu is not None else None
self.params = params
self._cb = { }


def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False):
num_games = co.num_games

total_batchsize = 0
for key, v in descriptions.items():
total_batchsize += v["batchsize"]

if co.num_collectors > 0:
num_recv_thread = co.num_collectors
else:
@@ -269,11 +271,11 @@ def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False):
for i in range(num_recv_thread):
group_id = GC.AddCollectors(batchsize, len(gpu2gid) - 1, timeout_usec, gstat)

input_batch = Batch.load(GC, "input", input, group_id, use_gpu=use_gpu, use_numpy=use_numpy)
input_batch = Batch.load(GC, "input", input, group_id, use_gpu=use_gpu, use_numpy=use_numpy) # load the input Batch
input_batch.batchsize = batchsize
inputs.append(input_batch)
if reply is not None:
reply_batch = Batch.load(GC, "reply", reply, group_id, use_gpu=use_gpu, use_numpy=use_numpy)
reply_batch = Batch.load(GC, "reply", reply, group_id, use_gpu=use_gpu, use_numpy=use_numpy) # load the reply Batch
reply_batch.batchsize = batchsize
replies.append(reply_batch)
else:
@@ -298,6 +300,14 @@ def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False):
self.name2idx = name2idx
self.gid2gpu = gid2gpu
self.gpu2gid = gpu2gid
# if not self.isPrint:
# print("idx2name",self.idx2name)
# print("name2idx",self.name2idx)
# print("gid2gpu",self.gid2gpu)
# print("gpu2gid",self.gpu2gid)
# print("num_collectors: ",co.num_collectors)
# self.isPrint = True


def reg_has_callback(self, key):
return key in self.name2idx
@@ -311,6 +321,7 @@ def reg_callback_if_exists(self, key, cb):

def reg_callback(self, key, cb):
'''Set callback function for key
Register a callback: when a batch of the matching type and required batch size arrives, the corresponding function is invoked.

Parameters:
key(str): the key used to register the callback function.
@@ -332,7 +343,7 @@ def _call(self, infos):
raise ValueError("info.gid[%d] is not in callback functions" % infos.gid)

if self._cb[infos.gid] is None:
return;
return

batchsize = len(infos.s)
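
# A short usage sketch of the registration API documented above, mirroring
# how train.py (later in this diff) wires its callbacks; GC and trainer are
# placeholders taken from that file:
if GC.reg_has_callback("train"):
    GC.reg_callback("train", trainer.train)   # called when a full "train" batch arrives
GC.reg_callback("actor", trainer.actor)       # called for each "actor" batch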

9 changes: 5 additions & 4 deletions rlpytorch/sampler/sample_methods.py
100644 → 100755
@@ -33,10 +33,10 @@ def sample_with_check(probs, greedy=False):
'''
num_action = probs.size(1)
if greedy:
_, actions = probs.max(1)
_, actions = probs.max(1) # greedy: always take the highest-probability action
return actions
while True:
actions = probs.multinomial(1)[:,0]
actions = probs.multinomial(1)[:,0] # sample one action according to the probabilities
cond1 = (actions < 0).sum()
cond2 = (actions >= num_action).sum()
if cond1 == 0 and cond2 == 0:
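
# A self-contained illustration of the bounds check above, with a toy
# probability tensor (uses only torch calls already present in this file):
import torch

probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1]])   # batch of 2 over 3 actions
actions = probs.multinomial(1)[:, 0]      # one sampled action per row
# the check performed above: nothing below 0 or at/above num_action
assert (actions >= 0).all() and (actions < probs.size(1)).all()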
@@ -74,8 +74,9 @@ def sample_eps_with_check(probs, epsilon, greedy=False):
rej_p = probs.new().resize_(2)
rej_p[0] = 1 - epsilon
rej_p[1] = epsilon
# rej draws 0 or 1, batchsize times: a 1 (probability epsilon) means the sampled action is rejected and resampled uniformly below
rej = rej_p.multinomial(batchsize, replacement=True).byte()

# uniform resampling
uniform_p = probs.new().resize_(num_action).fill_(1.0 / num_action)
uniform_sampling = uniform_p.multinomial(batchsize, replacement=True)
actions[rej] = uniform_sampling[rej]
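
# A hedged usage example of the epsilon mixing shown above; the uniform
# probability tensor and the epsilon value are illustrative only:
import torch
from rlpytorch.sampler.sample_methods import sample_eps_with_check

probs = torch.full((4, 5), 0.2)   # batch of 4, 5 equally likely actions
actions = sample_eps_with_check(probs, epsilon=0.1)
# each row keeps its multinomial sample with probability 0.9 and is
# replaced by a uniformly random action with probability 0.1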
@@ -110,7 +111,7 @@ def sample_multinomial(state_curr, args, node="pi", greedy=False):
return actions
else:
probs = state_curr[node].data
return sample_eps_with_check(probs, args.epsilon, greedy=greedy)
return sample_eps_with_check(probs, args.epsilon, greedy=greedy) # sample from probs with epsilon mixing

def epsilon_greedy(state_curr, args, node="pi"):
''' epsilon greedy sampling
Empty file modified rlpytorch/sampler/sampler.py
100644 → 100755
Empty file.
8 changes: 8 additions & 0 deletions rlpytorch/trainer/trainer.py
100644 → 100755
@@ -40,6 +40,7 @@ def __init__(self, name="eval", stats=True, verbose=False, actor_name="actor"):
on_get_args = self._on_get_args,
child_providers = child_providers
)
self.isPrint = False

def _on_get_args(self, _):
if self.stats is not None and not self.stats.is_valid():
@@ -75,6 +76,8 @@ def actor(self, batch):

if self.sampler is not None:
reply_msg = self.sampler.sample(state_curr)
# if not self.isPrint:
# print("sampler reply: ",reply_msg)
else:
reply_msg = dict(pi=state_curr["pi"].data)

@@ -88,6 +91,11 @@ def actor(self, batch):
reply_msg["V"] = state_curr["V"].data

self.actor_count += 1
# if not self.isPrint:
# print("batch: ",batch)
# print("state_curr",state_curr)
# print("reply_msg",reply_msg)
# self.isPrint = True
return reply_msg

def episode_summary(self, i):
1 change: 1 addition & 0 deletions rts/engine/game_env.cc
100644 → 100755
@@ -121,6 +121,7 @@ bool GameEnv::RemoveUnit(const UnitId &id) {
return true;
}

// Find the closest base
UnitId GameEnv::FindClosestBase(PlayerId player_id) const {
// Find closest base. [TODO]: Not efficient here.
for (auto it = _units.begin(); it != _units.end(); ++it) {
Empty file modified rts/game_MC/game.py
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion rts/game_MC/gamedef.cc
100644 → 100755
@@ -35,7 +35,7 @@ bool GameDef::CheckAddUnit(RTSMap *_map, UnitType, const PointF& p) const{
}

void GameDef::GlobalInit() {
reg_engine();
reg_engine(); //
reg_engine_specific();
reg_minirts_specific();

19 changes: 15 additions & 4 deletions rts/game_MC/model.py
100644 → 100755
@@ -18,6 +18,7 @@ class Model_ActorCritic(Model):
def __init__(self, args):
super(Model_ActorCritic, self).__init__(args)
self._init(args)
#self.isPrint = False

def _init(self, args):
params = args.params
@@ -31,9 +32,11 @@ def _init(self, args):
linear_in_dim = last_num_channel
else:
linear_in_dim = last_num_channel * 25



self.linear_policy = nn.Linear(linear_in_dim, params["num_action"])
self.linear_value = nn.Linear(linear_in_dim, 1)
self.linear_policy = nn.Linear(linear_in_dim, params["num_action"]) # policy head
self.linear_value = nn.Linear(linear_in_dim, 1) # value head

self.relu = nn.LeakyReLU(0.1)

@@ -49,13 +52,21 @@ def get_define_args():
def forward(self, x):
if self.params.get("model_no_spatial", False):
# Replace a complicated network with a simple retraction.
# Input: batchsize, channel, height, width
# Input: batchsize, channel, height, width (x is a Batch object)
xreduced = x["s"].sum(2).sum(3).squeeze()
xreduced[:, self.num_unit:] /= 20 * 20
output = self._var(xreduced)
else:
output = self.net(self._var(x["s"]))


#decide = self.decision(output)
#if not self.isPrint:
#print("x: ",x.batch)
# print("output: ",output)
# print("decision: ",decide)
# print("net: ",self)
#self.isPrint = True
#return decide
return self.decision(output)

def decision(self, h):
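
# decision() itself lies outside this hunk; a plausible sketch given the two
# heads commented above (softmax policy plus scalar value), matching the
# "pi" and "V" keys read by trainer.py elsewhere in this diff. This is an
# assumption, not the file's actual implementation:
import torch.nn.functional as F

def decision(self, h):
    # policy head -> action distribution (pi); value head -> state value (V)
    return dict(pi=F.softmax(self.linear_policy(h), dim=1),
                V=self.linear_value(h))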
7 changes: 7 additions & 0 deletions rts/game_MC/python_options.h
100644 → 100755
@@ -22,6 +22,12 @@ struct GameState {
using State = GameState;
using Data = GameState;

/**
* Test: fetch the base position
*/
//float base_x, base_y;


int32_t id;
int32_t seq;
int32_t game_counter;
@@ -133,6 +139,7 @@ struct GameState {

// These fields are used to exchange with Python side using tensor interface.
DECLARE_FIELD(GameState, id, a, V, pi, last_r, s, rv, terminal, seq, game_counter, last_terminal, uloc, tloc, bt, ct, uloc_prob, tloc_prob, bt_prob, ct_prob, reduced_s, reduced_next_s);
//DECLARE_FIELD(GameState, id, a, V, pi, last_r, s, rv, terminal, seq, game_counter, last_terminal, uloc, tloc, bt, ct, uloc_prob, tloc_prob, bt_prob, ct_prob, reduced_s, reduced_next_s,base_x,base_y);
REGISTER_PYBIND_FIELDS(id);
};

3 changes: 3 additions & 0 deletions rts/game_MC/python_wrapper.cc
100644 → 100755
@@ -40,6 +40,7 @@ class GameContext {
}

void Start() {
std::cout << "--------------GameContext Start-----------" << std::endl;
_context->Start(
[this](int game_idx, const ContextOptions &context_options, const PythonOptions &options, const elf::Signal &signal, Comm *comm) {
auto params = this->GetParams();
@@ -88,7 +89,9 @@ class GameContext {
else if (key == "ct_prob") return EntryInfo(key, type_name, { max_unit_cmd, CmdInput::CI_NUM_CMDS });
else if (key == "reduced_s") return EntryInfo(key, type_name, { reduced_size });
else if (key == "reduced_next_s") return EntryInfo(key, type_name, { reduced_size });
else if (key == "base_x" || key == "base_y") return EntryInfo(key, type_name);


return EntryInfo();
}

8 changes: 8 additions & 0 deletions rts/game_MC/state_feature.cc
100644 → 100755
@@ -28,6 +28,14 @@ void MCExtractor::SaveInfo(const RTSState &s, PlayerId player_id, GameState *gs)
gs->terminal = s.env().GetTermination() ? 1 : 0;

gs->last_r = 0.0;

// Test: fetch the base position
// UnitId baseId = s.env().FindClosestBase(player_id);
//const Unit* base = s.env().GetUnit(s.env().FindClosestBase(player_id));
//gs->base_x = s.env().GetUnit(s.env().FindClosestBase(player_id))->GetPointF().x;
//gs->base_y = s.env().GetUnit(s.env().FindClosestBase(player_id))->GetPointF().y;
// Test: fetch the base position

int winner = s.env().GetWinnerId();
if (winner != INVALID) {
if (winner == player_id) gs->last_r = 1.0;
18 changes: 15 additions & 3 deletions rts/game_MC/trunk.py
100644 → 100755
@@ -16,6 +16,7 @@ def __init__(self, args, output1d=True):
super(MiniRTSNet, self).__init__(args)
self._init(args)
self.output1d = output1d
#self.isPrint = False

def _init(self, args):
self.m = args.params.get("num_planes_per_time_stamp", 13)
@@ -64,7 +65,13 @@ def get_define_args():
def forward(self, input):
# BN and LeakyReLU are from Wendy's code.
x = input.view(input.size(0), self.input_channel, self.mapy, self.mapx)

# if not self.isPrint:
# print("input size: ",input.size())
# print("input: ",input)
# print("x size:",x.size())
# print("x: ",x)
# print("Net",self)
# self.isPrint = True
counts = Counter()
for i in range(len(self.arch)):
if self.arch[i] == "c":
@@ -78,5 +85,10 @@ def forward(self, input):

if self.output1d:
x = x.view(x.size(0), -1)

return x

# if not self.isPrint:
# print("x",x)
# print("x.size ",x.size())
# self.isPrint = True

return x # 64 x 550
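
# The "64 x 550" above plausibly reads as batchsize x flattened features:
# 64 matches the --batchsize 64 set in train_minirts.sh at the end of this
# diff, and 550 is consistent with linear_in_dim = last_num_channel * 25 in
# model.py when last_num_channel is 22. An inferred shape walk-through (the
# channel count and 5x5 grid are assumptions):
# x before the view: (64, 22, 5, 5)
# x after x.view(x.size(0), -1): (64, 22 * 25) == (64, 550)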
8 changes: 8 additions & 0 deletions train.py
100644 → 100755
@@ -45,7 +45,15 @@
trainer.setup(sampler=env["sampler"], mi=env["mi"], rl_method=env["method"])

GC.reg_callback("train", trainer.train)
# def train(batch):
# print(batch)
# import pdb
# pdb.set_trace()
# return trainer.train(batch)

# GC.reg_callback("train", train)
GC.reg_callback("actor", trainer.actor)

runner.setup(GC, episode_summary=trainer.episode_summary,
episode_start=trainer.episode_start)
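
# The commented-out lines earlier in this hunk sketch a debugging hook; a
# cleaned-up, runnable version of the same idea (pdb is from the standard
# library), left disabled like the original:
def train_debug(batch):
    import pdb
    print(batch)        # inspect the incoming batch
    pdb.set_trace()     # drop into the debugger before training
    return trainer.train(batch)

# GC.reg_callback("train", train_debug)   # swap in for the registration above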

2 changes: 1 addition & 1 deletion train_minirts.sh
@@ -1,3 +1,3 @@
#!/bin/bash

game=./rts/game_MC/game model=actor_critic model_file=./rts/game_MC/model python3 train.py --batchsize 128 --freq_update 1 --players "type=AI_NN,fs=50,args=backup/AI_SIMPLE|start/500|decay/0.99;type=AI_SIMPLE,fs=20" --num_games 1024 --tqdm --T 20 --additional_labels id,last_terminal --trainer_stats winrate --keys_in_reply V "$@"
game=./rts/game_MC/game model=actor_critic model_file=./rts/game_MC/model python3 train.py --batchsize 64 --freq_update 1 --players "type=AI_NN,fs=50,args=backup/AI_SIMPLE|start/500|decay/0.99;type=AI_SIMPLE,fs=20" --num_games 1024 --tqdm --T 20 --additional_labels id,last_terminal --trainer_stats winrate --keys_in_reply V "$@"