QMIX 코드리뷰

이두현·2024년 3월 17일
0

main.py

@ex.main
def my_main(_run, _config, _log):
    # Setting the random seed throughout the modules
    config = config_copy(_config)
    np.random.seed(config["seed"])
    th.manual_seed(config["seed"])
    config['env_args']['seed'] = config["seed"]

    # run the framework
    run(_run, config, _log)

np , th seed → np.random & th.random 실행시 일정한 랜덤값을 반환하도록 해 공평하게 성능을 비교할 수 있도록 한다

if __name__ == '__main__':
    params = deepcopy(sys.argv)

    # Get the defaults from default.yaml
    with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f:
        try:
            config_dict = yaml.safe_load(f) # modified
        except yaml.YAMLError as exc:
            assert False, "default.yaml error: {}".format(exc)

    # Load algorithm and env base configs
    env_config = _get_config(params, "--env-config", "envs")
    alg_config = _get_config(params, "--config", "algs")
    # config_dict = {**config_dict, **env_config, **alg_config}
    config_dict = recursive_dict_update(config_dict, env_config)
    config_dict = recursive_dict_update(config_dict, alg_config)

    # now add all the config to sacred
    ex.add_config(config_dict)

    # Save to disk by default for sacred
    logger.info("Saving to FileStorageObserver in results/sacred.")
    file_obs_path = os.path.join(results_path, "sacred")
    ex.observers.append(FileStorageObserver.create(file_obs_path))

    ex.run_commandline(params)

env, other configurations setting


run.py

def run(_run, _config, _log):
		run_sequential(args=args, logger=logger)
def run_sequential(args, logger):
		# Init runner so we can get env info
	  # runner == class EpisodeRunner
    runner = r_REGISTRY[args.runner](args=args, logger=logger)

<runner 변수에 들어간 EpisodeRunner class 의 주요 기능>

  • def setup : reset 시 필요한 빈 ReplayBuffer 를 생성 // mac(Multi Agent Controller) 객체 생성
  • def get_env_info : class 내부에 env 객체를 갖고 있음
  • def run : (1:N) 으로 env 시뮬레이션, batch sample 을 통한 NN 업데이트 과정을 진행

참고) env_info 의 예시 - 2ss3z 의 경우임

{'state_shape': 120, 'obs_shape': 80, 'n_actions': 11, 'n_agents': 5, 'episode_limit': 120, 'agent_features': ['health', 'energy/cooldown', 'rel_x', 'rel_y', 'shield', 'type_0', 'type_1'], 'enemy_features': ['health', 'rel_x', 'rel_y', 'shield', 'type_0', 'type_1']}

		# run_sequential continued
		buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)

episode_buffer.py

		self.data = SN()
    self.data.transition_data = {}
    self.data.episode_data = {}
    self._setup_data(self.scheme, self.groups, batch_size, max_seq_length, self.preprocess)

def _setup_data(self, scheme, groups, batch_size, max_seq_length, preprocess):
		if preprocess is not None:
		# ...
		# vshape, dtype term are added for new dictionary key 'action_onehot'
		assert "filled" not in scheme, '"filled" is a reserved key for masking.'
		scheme.update({
            "filled": {"vshape": (1,), "dtype": th.long},
    }) # 'filled' key is used for data.transition_data
		
		for field_key, field_info in scheme.items():
				if isinstance(vshape, int):
            vshape = (vshape,) # obs, and action space vector size are given as int
				if group:
						shape = (groups[group], *vshape) # (# of agents, *vshape)
				# figure out about this part later...
				if episode_const:
            self.data.episode_data[field_key] = th.zeros((batch_size, *shape), dtype=dtype, device=self.device)
				else:
						self.data.transition_data[field_key] = th.zeros((batch_size, max_seq_length, *shape), dtype=dtype, device=self.device)

episode buffer 초기화 과정에서 주목할 점

  • data.transition_data 에 사용될 filled 라는 마스크에 대한 정의
  • 앞서 주어진 obeservation 이나 action size 가 int 였기 때문에 (vshape, ) 으로 dimension up 시킴
  • data.episode_data 에는 (batch_size, *shape) 모양의 초기화를
  • data.transition_data 에는 (batch_size, max_seq_length, *shape) 모양의 초기화를 진행
		# run_sequential continued
		# buffer.scheme is an updated version of scheme with 
		# 'filled', 'action_onehot'
		mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)

basic_controller.py

def __init__(self, scheme, groups, args):
        self.n_agents = args.n_agents
        self.args = args
				# return based on two options
				# 1) obs_last_action // 2) obs_agent_id -> this is id embedding
        input_shape = self._get_input_shape(scheme)
				# all agents share network parameters
				# only create one RNN agent and append id vector
        self._build_agents(input_shape)
        self.agent_output_type = args.agent_output_type
				# class EpsilonGreedyActionSelector is assigned
        self.action_selector = action_REGISTRY[args.action_selector](args)

        self.hidden_states = None

action_selectors.py

def __init__(self, args):
        self.args = args

        self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
                                              decay="linear")
        self.epsilon = self.schedule.eval(0)
  • self.action_selector 에는 다음과 같은 class의 instance 가 들어가며 decay 가 linear 할 경우 우리가 아는 식, decay 가 exp 할 경우 아래와 같은 식으로 감소한다
		# run_sequential continued
		runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)

# episode_runner.py
def setup(self, scheme, groups, preprocess, mac):
        self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                                 preprocess=preprocess, device=self.args.device)
        self.mac = mac
  • 새로운 empty batch를 할당하는데 replay buffer 개념이 아닌 episode_runner에서는 한 episode를 terminate 까지 돌리고 결과를 반환하는 것이 목표이므로 이를 담을 empty buffer를 할당하는 것
  • 또한 mac(Multi-agent Controller)를 할당해 state 가 결정된 경우 다음 action을 얻기 위한 객체를만듦을 알 수 있다
		# run_sequential continued
		# Learner
    learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args)
  • learner 에는 class QLearner 를 할당하고 self.mixer 멤버 변수가 존재
  • self.mixer에는 논문에 소개된 q 값을 mixing 하는 NN이 정의되어있고 QLearner는 이 네트워크를 학습시키기 위한 loss등을 계산하는 함수 존재

이 부분부터 하나의 episode를 돌리는 과정에 대해 설명 (runner.run 에 의해)

# run_sequential continued
while runner.t_env <= args.t_max:
        # Run for a whole episode at a time
        episode_batch = runner.run(test_mode=False)
        buffer.insert_episode_batch(episode_batch)

# episode_runner.py
def run(self, test_mode=False):
        self.reset() # env reset, create new buffer
				terminated = False
        episode_return = 0
        **self.mac.init_hidden(batch_size=self.batch_size)**
# basic_controller.py
def init_hidden(self, batch_size):
				self.agent.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1)  # bav
  • agent.init_hidden은 RNN agent의 h_t 를 (1, hidden_vector) 크기의 영벡터로 반환하는데 이를 다시 batch_size 개의 (self.n_agents, hidden_vector) 크기로 확장시켜 mac의 hidden_state 멤버변수에 저장
# episode_runner.py continued
while not terminated:

            pre_transition_data = {
                "state": [self.env.get_state()],
                "avail_actions": [self.env.get_avail_actions()],
                "obs": [self.env.get_obs()]
            }
  • env 는 pip install 되어있기 때문에

smac/starcraft2.py at master · oxwhirl/smac

위의 github 에서 함수정보를 확인할 수 있다.

  • state : global_state 반환
  • avail_actions: 각 agent 별 가능한 action을 모두 list 에 담아 반환
  • obs : 각 agent의 관찰 결과를 list에 담아 반환
profile
0100101

0개의 댓글