@ex.main
def my_main(_run, _config, _log):
    # Setting the random seed throughout the modules
    config = config_copy(_config)
    np.random.seed(config["seed"])
    th.manual_seed(config["seed"])
    config['env_args']['seed'] = config["seed"]

    # run the framework
    run(_run, config, _log)
Seeding np and th makes np.random and th.random return the same sequence of values on every run, so the performance of different algorithms can be compared fairly.
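As a quick illustration, the minimal standalone sketch below shows that fixing both seeds makes repeated draws identical (seed_everything is a hypothetical helper for this note, not part of pymarl):

import numpy as np
import torch as th

def seed_everything(seed):
    # Fix both the NumPy and PyTorch RNGs so repeated runs draw the same numbers
    np.random.seed(seed)
    th.manual_seed(seed)

seed_everything(42)
a = np.random.rand(3)   # identical on every run with seed 42
b = th.rand(3)          # likewise for torch

seed_everything(42)
assert np.allclose(a, np.random.rand(3))
assert th.equal(b, th.rand(3))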
if __name__ == '__main__':
    params = deepcopy(sys.argv)

    # Get the defaults from default.yaml
    with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f:
        try:
            config_dict = yaml.safe_load(f)  # modified
        except yaml.YAMLError as exc:
            assert False, "default.yaml error: {}".format(exc)

    # Load algorithm and env base configs
    env_config = _get_config(params, "--env-config", "envs")
    alg_config = _get_config(params, "--config", "algs")
    # config_dict = {**config_dict, **env_config, **alg_config}
    config_dict = recursive_dict_update(config_dict, env_config)
    config_dict = recursive_dict_update(config_dict, alg_config)

    # now add all the config to sacred
    ex.add_config(config_dict)

    # Save to disk by default for sacred
    logger.info("Saving to FileStorageObserver in results/sacred.")
    file_obs_path = os.path.join(results_path, "sacred")
    ex.observers.append(FileStorageObserver.create(file_obs_path))

    ex.run_commandline(params)
This block loads the env and algorithm configurations, merges them into the defaults, registers everything with sacred, and then runs the experiment.
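The merging is done by recursive_dict_update; a minimal reimplementation sketch (behaviour reconstructed from memory, may differ in detail from pymarl's helper in main.py):

import collections.abc

def recursive_dict_update(d, u):
    # Merge dict u into dict d, recursing into nested dicts instead of
    # replacing them wholesale, so alg/env configs only override the keys
    # they actually define (e.g. env_args.map_name but not env_args.seed).
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            d[k] = recursive_dict_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d

defaults = {"env_args": {"seed": 0, "map_name": "3m"}, "lr": 0.0005}
override = {"env_args": {"map_name": "2s3z"}}
merged = recursive_dict_update(defaults, override)
# merged == {"env_args": {"seed": 0, "map_name": "2s3z"}, "lr": 0.0005}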
def run(_run, _config, _log):
    # (args / logger setup omitted)
    run_sequential(args=args, logger=logger)
def run_sequential(args, logger):
    # Init runner so we can get env info
    # runner == class EpisodeRunner
    runner = r_REGISTRY[args.runner](args=args, logger=logger)
<Key features of the EpisodeRunner class assigned to the runner variable; a condensed sketch follows the env_info example below>
Note) example of env_info - this is for the 2s3z map:
{'state_shape': 120, 'obs_shape': 80, 'n_actions': 11, 'n_agents': 5, 'episode_limit': 120,
 'agent_features': ['health', 'energy/cooldown', 'rel_x', 'rel_y', 'shield', 'type_0', 'type_1'],
 'enemy_features': ['health', 'rel_x', 'rel_y', 'shield', 'type_0', 'type_1']}
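A condensed sketch of what EpisodeRunner sets up, reconstructed from pymarl's episode_runner.py (details from memory; check the repo for the exact code):

class EpisodeRunner:
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.batch_size = self.args.batch_size_run  # EpisodeRunner drives a single env, so this is 1

        # Build the environment (e.g. the SMAC StarCraft II env) from the registry
        self.env = env_REGISTRY[self.args.env](**self.args.env_args)
        self.episode_limit = self.env.episode_limit

        self.t = 0      # timestep inside the current episode
        self.t_env = 0  # total env steps over all episodes (drives epsilon decay and the t_max loop)

    def get_env_info(self):
        # Returns the dict shown above: state_shape, obs_shape, n_actions, n_agents, episode_limit, ...
        return self.env.get_env_info()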
# run_sequential continued
buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                      preprocess=preprocess,
                      device="cpu" if args.buffer_cpu_only else args.device)
# episode_buffer.py, EpisodeBatch.__init__ (ReplayBuffer subclasses EpisodeBatch)
self.data = SN()  # SimpleNamespace holding the buffer tensors
self.data.transition_data = {}
self.data.episode_data = {}
self._setup_data(self.scheme, self.groups, batch_size, max_seq_length, self.preprocess)
def _setup_data(self, scheme, groups, batch_size, max_seq_length, preprocess):
    if preprocess is not None:
        # ...
        # vshape and dtype entries are added to the scheme under the new
        # dictionary key 'actions_onehot'

    assert "filled" not in scheme, '"filled" is a reserved key for masking.'
    scheme.update({
        "filled": {"vshape": (1,), "dtype": th.long},
    })  # the 'filled' key is stored in data.transition_data

    for field_key, field_info in scheme.items():
        vshape = field_info["vshape"]
        episode_const = field_info.get("episode_const", False)
        group = field_info.get("group", None)
        dtype = field_info.get("dtype", th.float32)

        if isinstance(vshape, int):
            vshape = (vshape,)  # obs and action space vector sizes are given as int

        if group:
            shape = (groups[group], *vshape)  # (# of agents, *vshape)
        else:
            shape = vshape

        # figure out about this part later...
        if episode_const:
            self.data.episode_data[field_key] = th.zeros((batch_size, *shape), dtype=dtype, device=self.device)
        else:
            self.data.transition_data[field_key] = th.zeros((batch_size, max_seq_length, *shape), dtype=dtype, device=self.device)
Points to note in the episode buffer initialization process:
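For instance, with the 2s3z env_info above and buffer_size B, _setup_data allocates tensors with roughly these shapes (illustrative, assuming max_seq_length = episode_limit + 1 = 121):

# data.transition_data: per-timestep fields, shape (B, 121, ...)
#   "state":          (B, 121, 120)
#   "obs":            (B, 121, 5, 80)   # grouped by "agents" -> extra n_agents dimension
#   "actions":        (B, 121, 5, 1)
#   "actions_onehot": (B, 121, 5, 11)   # added by the OneHot preprocess
#   "avail_actions":  (B, 121, 5, 11)
#   "reward":         (B, 121, 1)
#   "terminated":     (B, 121, 1)
#   "filled":         (B, 121, 1)       # reserved masking key added in _setup_data
# data.episode_data stays empty here because no field sets episode_const=True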
# run_sequential continued
# buffer.scheme is an updated version of scheme with
# 'filled' and 'actions_onehot' added
mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)
# basic_controller.py, BasicMAC.__init__
def __init__(self, scheme, groups, args):
    self.n_agents = args.n_agents
    self.args = args

    # input shape depends on two options:
    # 1) obs_last_action // 2) obs_agent_id -> this is the id embedding
    input_shape = self._get_input_shape(scheme)

    # all agents share network parameters:
    # only one RNN agent is created, and an id vector is appended to each agent's input
    self._build_agents(input_shape)
    self.agent_output_type = args.agent_output_type

    # class EpsilonGreedyActionSelector is assigned
    self.action_selector = action_REGISTRY[args.action_selector](args)

    self.hidden_states = None
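The _get_input_shape call above works roughly as in the sketch below (based on basic_controller.py; the 'actions_onehot' key is the one produced by the preprocess step):

def _get_input_shape(self, scheme):
    # Start from the per-agent observation size
    input_shape = scheme["obs"]["vshape"]
    if self.args.obs_last_action:
        # append the previous action as a one-hot vector
        input_shape += scheme["actions_onehot"]["vshape"][0]
    if self.args.obs_agent_id:
        # append a one-hot agent id (the "id embedding" mentioned above)
        input_shape += self.n_agents
    return input_shape

For 2s3z with both options enabled this gives 80 + 11 + 5 = 96 inputs per agent.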
# components/action_selectors.py, EpsilonGreedyActionSelector.__init__
def __init__(self, args):
    self.args = args
    self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
                                          decay="linear")
    self.epsilon = self.schedule.eval(0)
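With decay="linear", the schedule behaves roughly like the sketch below (a linear-only reimplementation; pymarl's DecayThenFlatSchedule in epsilon_schedules.py also supports exponential decay):

class DecayThenFlatSchedule:
    def __init__(self, start, finish, time_length, decay="linear"):
        self.start = start              # epsilon_start, e.g. 1.0
        self.finish = finish            # epsilon_finish, e.g. 0.05
        self.time_length = time_length  # epsilon_anneal_time, in env steps
        self.delta = (self.start - self.finish) / self.time_length
        self.decay = decay

    def eval(self, T):
        # Anneal linearly from start down to finish over time_length env steps,
        # then stay flat at finish.
        if self.decay == "linear":
            return max(self.finish, self.start - self.delta * T)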
# run_sequential continued
runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)

# episode_runner.py
def setup(self, scheme, groups, preprocess, mac):
    self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                             preprocess=preprocess, device=self.args.device)
    self.mac = mac
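functools.partial pre-binds every EpisodeBatch constructor argument, so reset() can create a fresh, empty batch with a single call (self.batch = self.new_batch()). A toy illustration of the pattern (ToyBatch is a hypothetical stand-in, not pymarl's EpisodeBatch):

from functools import partial

class ToyBatch:
    def __init__(self, scheme, batch_size, max_seq_length, device="cpu"):
        # allocate empty per-timestep storage for every field in the scheme
        self.data = {k: [None] * max_seq_length for k in scheme}
        self.batch_size = batch_size
        self.device = device

new_batch = partial(ToyBatch, {"obs", "reward"}, 1, 121, device="cpu")
b1 = new_batch()  # fresh, empty batch for episode 1
b2 = new_batch()  # another fresh batch for episode 2; b1 is untouched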
# run_sequential continued
# Learner
learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args)
From here on, we walk through how a single episode is run (driven by runner.run).
# run_sequential continued
while runner.t_env <= args.t_max:
    # Run for a whole episode at a time
    episode_batch = runner.run(test_mode=False)
    buffer.insert_episode_batch(episode_batch)
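Immediately after inserting the episode, run_sequential samples from the buffer and runs a training step; roughly (reconstructed from run.py, details from memory):

    if buffer.can_sample(args.batch_size):
        episode_sample = buffer.sample(args.batch_size)

        # Truncate the batch to the longest filled timestep so we do not
        # train on the all-zero padding at the tail of shorter episodes
        max_ep_t = episode_sample.max_t_filled()
        episode_sample = episode_sample[:, :max_ep_t]

        if episode_sample.device != args.device:
            episode_sample.to(args.device)

        learner.train(episode_sample, runner.t_env, episode)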
# episode_runner.py
def run(self, test_mode=False):
    self.reset()  # env reset, allocate a fresh empty batch via self.new_batch()
    terminated = False
    episode_return = 0
    self.mac.init_hidden(batch_size=self.batch_size)
# basic_controller.py
def init_hidden(self, batch_size):
    self.hidden_states = self.agent.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1)  # bav
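The agent.init_hidden() call comes from the shared RNN agent; a sketch of what it returns (based on rnn_agent.py; rnn_hidden_dim is the GRU hidden size from the config):

# rnn_agent.py (sketch)
def init_hidden(self):
    # Zero hidden state created from an existing weight tensor so it lands on the
    # same device/dtype as the model; shape (1, rnn_hidden_dim). The expand() in
    # init_hidden(batch_size) above turns it into (batch_size, n_agents, rnn_hidden_dim),
    # i.e. every agent starts the episode from the same zero state.
    return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()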
# episode_runner.py continued
while not terminated:
    pre_transition_data = {
        "state": [self.env.get_state()],
        "avail_actions": [self.env.get_avail_actions()],
        "obs": [self.env.get_obs()]
    }
The get_state / get_avail_actions / get_obs implementations can be checked in smac/starcraft2.py in the oxwhirl/smac GitHub repository.
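With the 2s3z env_info from earlier, the three calls above return data of the following shapes (illustrative, assuming a SMAC env instance named env):

state = env.get_state()                  # np.ndarray of length 120 (== state_shape), the global state
obs = env.get_obs()                      # list of n_agents = 5 arrays, each of length 80 (== obs_shape)
avail_actions = env.get_avail_actions()  # list of 5 masks, each of length 11 (1 = action currently legal)
# In pre_transition_data each value is wrapped in a list so it carries a leading batch dimension of 1.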