conda env decdiff
code/scripts/train 으로 넘어가는 kwargs
{'RUN.prefix': 'diffuser/default_inv/predict_epsilon_200_1000000.0/dropout_0.25/hopper-medium-expert-v2/100', 'seed': 100, 'returns_condition': True, 'predict_epsilon': True, 'n_diffusion_steps': 200, 'condition_dropout': 0.25, 'diffusion': 'models.GaussianInvDynDiffusion', 'n_train_steps': 1000000.0, 'dataset': 'hopper-medium-expert-v2', 'returns_scale': 400.0, 'RUN.job_counter': 1, 'RUN.job_name': 'predict_epsilon_200_1000000.0/dropout_0.25/hopper-medium-expert-v2/100'}
code/scripts/train.py
dataset_config = utils.Config(
Config.loader,
savepath='dataset_config.pkl',
env=Config.dataset,
horizon=Config.horizon,
normalizer=Config.normalizer,
preprocess_fns=Config.preprocess_fns,
use_padding=Config.use_padding,
max_path_length=Config.max_path_length,
include_returns=Config.include_returns,
returns_scale=Config.returns_scale,
discount=Config.discount,
termination_penalty=Config.termination_penalty,
)
[utils/config ] Config: <class 'diffuser.datasets.sequence.SequenceDataset'>
discount: 0.99
env: hopper-medium-expert-v2
horizon: 100
include_returns: True
max_path_length: 1000
normalizer: CDFNormalizer
preprocess_fns: []
returns_scale: 400.0
termination_penalty: -100
use_padding: True
[ utils/config ] Saved config to: dataset_config.pkl
scripts/train.py
render_config = utils.Config(
Config.renderer,
savepath='render_config.pkl',
env=Config.dataset,
)
[ utils/config ] Imported diffuser.utils:MuJoCoRenderer
[utils/config ] Config: <class 'diffuser.utils.rendering.MuJoCoRenderer'>
env: hopper-medium-expert-v2
[ utils/config ] Saved config to: render_config.pkl
scripts/train.py
dataset = dataset_config()
renderer = render_config()
diffuser/utils/config.py
def __call__(self, *args, **kwargs):
    """Instantiate the wrapped class.

    Call-site args/kwargs are forwarded together with the keyword dict
    captured when the Config was built (`self._dict`); a duplicate key
    between the two raises TypeError, exactly like a normal call.
    The new object is moved to `self._device` when one is configured.
    """
    instance = self._class(*args, **kwargs, **self._dict)
    # `.to(device)` only when a device was set on the config
    return instance.to(self._device) if self._device else instance
dataset 어떻게 불러오나..
diffuser/datasets/sequence.py
class SequenceDataset(torch.utils.data.Dataset):
def __init__(self, env='hopper-medium-replay', horizon=64,
normalizer='LimitsNormalizer', preprocess_fns=[], max_path_length=1000,
max_n_episodes=10000, termination_penalty=0, use_padding=True, discount=0.99, returns_scale=1000, include_returns=False):
**self.preprocess_fn = get_preprocess_fn(preprocess_fns, env)
self.env = env = load_environment(env)**
self.returns_scale = returns_scale
self.horizon = horizon
self.max_path_length = max_path_length
self.discount = discount
**self.discounts = self.discount ** np.arange(self.max_path_length)[:, None]**
self.use_padding = use_padding
self.include_returns = include_returns
itr = sequence_dataset(env, self.preprocess_fn)
이어서 (diffuser/datasets/sequence.py)
**fields = ReplayBuffer(max_n_episodes, max_path_length, termination_penalty)**
for i, episode in enumerate(itr):
fields.add_path(episode)
fields.finalize()
diffuser/datasets/buffer.py
class ReplayBuffer:
    """Fixed-capacity episode storage.

    Only `path_lengths` is allocated up front; every other key is
    allocated lazily the first time an episode containing it is added
    (see `add_path`). `_count` tracks how many episodes are stored.
    """

    def __init__(self, max_n_episodes, max_path_length, termination_penalty):
        self.max_n_episodes = max_n_episodes
        self.max_path_length = max_path_length
        self.termination_penalty = termination_penalty
        # per-episode lengths; other arrays appear lazily via add_path
        self._dict = {'path_lengths': np.zeros(max_n_episodes, dtype=np.int32)}
        self._count = 0
diffuser/datasets/sequence.py
fields = ReplayBuffer(max_n_episodes, max_path_length, termination_penalty)
**for i, episode in enumerate(itr):**
fields.add_path(episode)
fields.finalize()
여기서 sequence_dataset 이 실행됨
diffuser/datasets/d4rl.py
def sequence_dataset(env, preprocess_fn):
    """
    Returns an iterator through trajectories.
    Args:
        env: An OfflineEnv object.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        An iterator through dictionaries with keys:
            observations
            actions
            rewards
            terminals
    """
    dataset = get_dataset(env)
    dataset = preprocess_fn(dataset)

    # total number of transitions in the flat dataset
    N = dataset['rewards'].shape[0]
    data_ = collections.defaultdict(list)

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatability.
    use_timeouts = 'timeouts' in dataset

    episode_step = 0
    for i in range(N):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            # fall back to counting steps against the env's episode cap
            final_timestep = (episode_step == env._max_episode_steps - 1)

        # accumulate this transition under every key (skip metadata entries)
        for k in dataset:
            if 'metadata' in k: continue
            data_[k].append(dataset[k][i])

        if done_bool or final_timestep:
            # episode boundary: emit everything accumulated so far
            episode_step = 0
            episode_data = {}
            for k in data_:
                episode_data[k] = np.array(data_[k])
            if 'maze2d' in env.name:
                episode_data = process_maze2d_episode(episode_data)
            yield episode_data
            data_ = collections.defaultdict(list)

        episode_step += 1
Keys present in each yielded episode dict: next_observations, observations, rewards, terminals, timeouts
diffuser/datasets/sequence.py
fields = ReplayBuffer(max_n_episodes, max_path_length, termination_penalty)
for i, episode in enumerate(itr):
**fields.add_path(episode)**
fields.finalize()
diffuser/datasets/buffer.py
def add_path(self, path):
    """Copy one episode dict into the preallocated buffers.

    `path` maps keys (observations, actions, rewards, terminals,
    timeouts, ...) to per-step arrays sharing the same leading length.
    """
    path_length = len(path['observations'])
    assert path_length <= self.max_path_length

    # a terminal flag, if present, may only appear on the final step
    if path['terminals'].any():
        assert (path['terminals'][-1] == True) and (not path['terminals'][:-1].any())

    ## if first path added, set keys based on contents
    self._add_keys(path)

    ## add tracked keys in path
    for key in self.keys:
        array = atleast_2d(path[key])
        # lazily allocate the storage array the first time a key is seen
        if key not in self._dict: self._allocate(key, array)
        self._dict[key][self._count, :path_length] = array

    ## penalize early termination
    if path['terminals'].any() and self.termination_penalty is not None:
        assert not path['timeouts'].any(), 'Penalized a timeout episode for early termination'
        self._dict['rewards'][self._count, path_length - 1] += self.termination_penalty

    ## record path length
    self._dict['path_lengths'][self._count] = path_length

    ## increment path counter
    self._count += 1
fields.finalize() 설명
diffuser/datasets/buffer.py
def _add_attributes(self):
    '''
        can access fields with `buffer.observations`
        instead of `buffer['observations']`
    '''
    # mirror every stored array as a plain instance attribute
    for key, val in self._dict.items():
        setattr(self, key, val)

def finalize(self):
    """Trim unused episode slots and expose arrays as attributes."""
    ## remove extra slots
    for key in self.keys + ['path_lengths']:
        # keep only the `_count` episodes actually added
        self._dict[key] = self._dict[key][:self._count]
    self._add_attributes()
diffuser/datasets/sequence.py 이어서
**self.normalizer = DatasetNormalizer(fields, normalizer, path_lengths=fields['path_lengths'])**
self.indices = self.make_indices(fields.path_lengths, horizon)
self.observation_dim = fields.observations.shape[-1]
self.action_dim = fields.actions.shape[-1]
self.fields = fields
self.n_episodes = fields.n_episodes
self.path_lengths = fields.path_lengths
self.normalize()
diffuser/datasets/normalization.py
class DatasetNormalizer:
    """Builds one normalizer instance per dataset key.

    The dataset is first flattened over valid path lengths; keys whose
    values the normalizer cannot handle are skipped with a log line.
    """

    def __init__(self, dataset, normalizer, path_lengths=None):
        dataset = flatten(dataset, path_lengths)

        self.observation_dim = dataset['observations'].shape[1]
        self.action_dim = dataset['actions'].shape[1]

        if type(normalizer) == str:
            # NOTE(review): eval on a config-supplied class name; fine for
            # trusted configs, but never pass untrusted input here.
            normalizer = eval(normalizer)

        self.normalizers = {}
        for key, val in dataset.items():
            try:
                self.normalizers[key] = normalizer(val)
            except Exception:
                # was a bare `except:`, which also swallowed SystemExit and
                # KeyboardInterrupt; narrow so only normalizer failures skip
                print(f'[ utils/normalization ] Skipping {key} | {normalizer}')
diffuser/datasets/sequence.py 이어서
self.normalizer = DatasetNormalizer(fields, normalizer, path_lengths=fields['path_lengths'])
**self.indices = self.make_indices(fields.path_lengths, horizon)**
self.observation_dim = fields.observations.shape[-1]
self.action_dim = fields.actions.shape[-1]
self.fields = fields
self.n_episodes = fields.n_episodes
self.path_lengths = fields.path_lengths
**self.normalize()**
def normalize(self, keys=['observations', 'actions']):
'''
normalize fields that will be predicted by the diffusion model
'''
for key in keys:
array = self.fields[key].reshape(self.n_episodes*self.max_path_length, -1)
normed = self.normalizer(array, key)
self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1)
만들어진 field 의 log
[0] 차원은 episode 개수 , [1] 차원은 각 에피소드 당 최대 가능한 길이
actions: (3213, 1000, 3)
infos/action_log_probs: (3213, 1000, 1)
infos/qpos: (3213, 1000, 6)
infos/qvel: (3213, 1000, 6)
next_observations: (3213, 1000, 11)
observations: (3213, 1000, 11)
rewards: (3213, 1000, 1)
terminals: (3213, 1000, 1)
timeouts: (3213, 1000, 1)
normed_observations: (3213, 1000, 11)
normed_actions: (3213, 1000, 3)
scripts/train.py
dataset = dataset_config()
**renderer = render_config()**
if Config.diffusion == 'models.GaussianInvDynDiffusion':
model_config = utils.Config(
Config.model,
savepath='model_config.pkl',
horizon=Config.horizon,
transition_dim=observation_dim,
cond_dim=observation_dim,
dim_mults=Config.dim_mults,
returns_condition=Config.returns_condition,
dim=Config.dim,
condition_dropout=Config.condition_dropout,
calc_energy=Config.calc_energy,
device=Config.device,
)
diffusion_config = utils.Config(
Config.diffusion,
savepath='diffusion_config.pkl',
horizon=Config.horizon,
observation_dim=observation_dim,
action_dim=action_dim,
n_timesteps=Config.n_diffusion_steps,
loss_type=Config.loss_type,
clip_denoised=Config.clip_denoised,
predict_epsilon=Config.predict_epsilon,
hidden_dim=Config.hidden_dim,
ar_inv=Config.ar_inv,
train_only_inv=Config.train_only_inv,
## loss weighting
action_weight=Config.action_weight,
loss_weights=Config.loss_weights,
loss_discount=Config.loss_discount,
returns_condition=Config.returns_condition,
condition_guidance_w=Config.condition_guidance_w,
device=Config.device,
)
model log
Config: <class 'diffuser.models.temporal.TemporalUnet'>
calc_energy: False
cond_dim: 11
condition_dropout: 0.25
dim: 128
dim_mults: (1, 4, 8)
horizon: 100
returns_condition: True
transition_dim: 11
diffusion log
<class 'diffuser.models.diffusion.GaussianInvDynDiffusion'>
action_dim: 3
action_weight: 10
ar_inv: False
clip_denoised: True
condition_guidance_w: 1.2
hidden_dim: 256
horizon: 100
loss_discount: 1
loss_type: l2
loss_weights: None
n_timesteps: 200
observation_dim: 11
predict_epsilon: True
returns_condition: True
train_only_inv: False
trainer log
Config: <class 'diffuser.utils.training.Trainer'>
bucket: /home/aajay/weights/
ema_decay: 0.995
gradient_accumulate_every: 2
label_freq: 200000
log_freq: 1000
n_reference: 8
sample_freq: 10000
save_checkpoints: False
save_freq: 10000
save_parallel: False
train_batch_size: 32
train_device: cuda
train_lr: 0.0002
**model = model_config()**
diffusion = diffusion_config(model)
trainer = trainer_config(diffusion, dataset, renderer)
diffuser/models/temporal.py
class TemporalUnet(nn.Module):
def __init__(
self,
horizon,
transition_dim,
cond_dim,
dim=128,
dim_mults=(1, 2, 4, 8),
returns_condition=False,
condition_dropout=0.1,
calc_energy=False,
kernel_size=5,
):
super().__init__()
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
print(f'[ models/temporal ] Channel dimensions: {in_out}')
if calc_energy:
mish = False
act_fn = nn.SiLU()
else:
**mish = True
act_fn = nn.Mish()**
이어서
self.time_dim = dim
self.returns_dim = dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
act_fn,
nn.Linear(dim * 4, dim),
)
# -----------참고--------------
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding for scalar timesteps.

    Maps a batch of scalars x (shape [B]) to [B, dim] by concatenating
    sin and cos of x scaled by a geometric series of frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # frequency k is 10000^(-k / (half - 1)) for k = 0 .. half-1
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
# -----------참고--------------
self.returns_condition = returns_condition
self.condition_dropout = condition_dropout
self.calc_energy = calc_energy
**if self.returns_condition:
self.returns_mlp = nn.Sequential(
nn.Linear(1, dim),
act_fn,
nn.Linear(dim, dim * 4),
act_fn,
nn.Linear(dim * 4, dim),
)
self.mask_dist = Bernoulli(probs=1-self.condition_dropout)
embed_dim = 2*dim**
else:
embed_dim = dim
이어서 하기 전에
class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, mish=True, n_groups=8):
        super().__init__()

        # mish=False selects SiLU (used when the energy-model variant is on)
        if mish:
            act_fn = nn.Mish()
        else:
            act_fn = nn.SiLU()

        self.block = nn.Sequential(
            # same-length padding keeps the horizon dimension unchanged
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            # NOTE(review): GroupNorm accepts 3-D input directly; the 4-D
            # round-trip looks redundant -- confirm why it was kept.
            Rearrange('batch channels horizon -> batch channels 1 horizon'),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange('batch channels 1 horizon -> batch channels horizon'),
            act_fn,
        )

    def forward(self, x):
        return self.block(x)
class ResidualTemporalBlock(nn.Module):
    """Two Conv1dBlocks with an additive time embedding and a 1x1-conv
    (or identity) residual connection."""

    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5, mish=True):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(inp_channels, out_channels, kernel_size, mish),
            Conv1dBlock(out_channels, out_channels, kernel_size, mish),
        ])

        if mish:
            act_fn = nn.Mish()
        else:
            act_fn = nn.SiLU()

        # project the time embedding to out_channels, broadcast over horizon
        self.time_mlp = nn.Sequential(
            act_fn,
            nn.Linear(embed_dim, out_channels),
            Rearrange('batch t -> batch t 1'),
        )

        # match channel counts on the skip path only when they differ
        self.residual_conv = **nn.Conv1d(inp_channels, out_channels, 1)** \
            if inp_channels != out_channels else nn.Identity()

    def forward(self, x, t):
        '''
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]
            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        # time embedding is added after the first conv block
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
print(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, embed_dim=embed_dim, horizon=horizon, kernel_size=kernel_size, mish=mish),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=embed_dim, horizon=horizon, kernel_size=kernel_size, mish=mish),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
if not is_last:
horizon = horizon // 2
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=embed_dim, horizon=horizon, kernel_size=kernel_size, mish=mish)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=embed_dim, horizon=horizon, kernel_size=kernel_size, mish=mish)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=embed_dim, horizon=horizon, kernel_size=kernel_size, mish=mish),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=embed_dim, horizon=horizon, kernel_size=kernel_size, mish=mish),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
if not is_last:
horizon = horizon * 2
self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=kernel_size, mish=mish),
nn.Conv1d(dim, transition_dim, 1),
)
train.py 에서 model=TemporalUnet() init 끝나고
diffusion init 부분
class GaussianInvDynDiffusion(nn.Module):
def __init__(self, model, horizon, observation_dim, action_dim, n_timesteps=1000,
loss_type='l1', clip_denoised=False, predict_epsilon=True, hidden_dim=256,
action_weight=1.0, loss_discount=1.0, loss_weights=None, returns_condition=False,
condition_guidance_w=0.1, ar_inv=False, train_only_inv=False):
super().__init__()
self.horizon = horizon
self.observation_dim = observation_dim
self.action_dim = action_dim
self.transition_dim = observation_dim + action_dim
self.model = model
self.ar_inv = ar_inv
self.train_only_inv = train_only_inv
if self.ar_inv:
self.inv_model = ARInvModel(hidden_dim=hidden_dim, observation_dim=observation_dim, action_dim=action_dim)
**else:
self.inv_model = nn.Sequential(
nn.Linear(2 * self.observation_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, self.action_dim),
)**
self.returns_condition = returns_condition
self.condition_guidance_w = condition_guidance_w
이어서
betas = cosine_beta_schedule(n_timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
self.n_timesteps = int(n_timesteps)
self.clip_denoised = clip_denoised
self.predict_epsilon = predict_epsilon
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
self.register_buffer('posterior_variance', posterior_variance)
## log calculation clipped because the posterior variance
## is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped',
torch.log(torch.clamp(posterior_variance, min=1e-20)))
self.register_buffer('posterior_mean_coef1',
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2',
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
**loss_weights = self.get_loss_weights(loss_discount)**
self.loss_fn = Losses['state_l2'](loss_weights)
def get_loss_weights(self, discount):
    '''
        sets loss coefficients for trajectory

        action_weight   : float
            coefficient on first action loss
        discount   : float
            multiplies t^th timestep of trajectory loss by discount**t
        weights_dict    : dict
            { i: c } multiplies dimension i of observation loss by c
    '''
    # NOTE(review): hard-codes action_weight to 1, ignoring the configured
    # value (10 in the logs). Presumably intentional for the inverse-dynamics
    # variant, which never diffuses actions -- confirm.
    self.action_weight = 1
    dim_weights = torch.ones(self.observation_dim, dtype=torch.float32)

    ## decay loss with trajectory timestep: discount**t
    discounts = discount ** torch.arange(self.horizon, dtype=torch.float)
    # normalize so the weights average to 1 across the horizon
    discounts = discounts / discounts.mean()
    loss_weights = torch.einsum('h,t->ht', discounts, dim_weights)
    # Cause things are conditioned on t=0
    **if self.predict_epsilon:
        loss_weights[0, :] = 0**

    return loss_weights
다시 돌아와서
self.loss_fn = Losses['state_l2'](loss_weights)
이렇게 선언되어있었는데
class WeightedStateLoss(nn.Module):
    """Elementwise loss reweighted by a fixed tensor.

    `weights` is registered as a buffer so it follows the module across
    devices; subclasses supply `_loss` (elementwise, no reduction).
    """

    def __init__(self, weights):
        super().__init__()
        self.register_buffer('weights', weights)

    def forward(self, pred, targ):
        """pred, targ : [ batch_size x horizon x transition_dim ] tensors."""
        elementwise = self._loss(pred, targ)
        weighted_loss = (elementwise * self.weights).mean()
        # the info dict mirrors the interface of the action-weighted losses
        return weighted_loss, {'a0_loss': weighted_loss}


class WeightedStateL2(WeightedStateLoss):

    def _loss(self, pred, targ):
        # unreduced MSE so per-element weights can be applied afterwards
        return F.mse_loss(pred, targ, reduction='none')
scripts/train.py 의 trainer 로드 과정
utils/training.py
class Trainer(object):
    """Step-based training loop wrapper: holds the diffusion model, an
    EMA copy for evaluation, infinite dataloaders, the optimizer, and
    logging / saving / sampling frequencies."""
    def __init__(
        self,
        diffusion_model,
        dataset,
        renderer,
        ema_decay=0.995,
        train_batch_size=32,
        train_lr=2e-5,
        gradient_accumulate_every=2,
        step_start_ema=2000,
        update_ema_every=10,
        log_freq=100,
        sample_freq=1000,
        save_freq=1000,
        label_freq=100000,
        save_parallel=False,
        n_reference=8,
        bucket=None,
        train_device='cuda',
        save_checkpoints=False,
    ):
        super().__init__()

        self.model = diffusion_model
        # EMA copy of the model, used for evaluation/sampling
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every
        self.save_checkpoints = save_checkpoints

        self.step_start_ema = step_start_ema
        self.log_freq = log_freq
        self.sample_freq = sample_freq
        self.save_freq = save_freq
        self.label_freq = label_freq
        self.save_parallel = save_parallel

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.dataset = dataset
        # cycle() makes the loaders infinite for step-count-based training
        self.dataloader = cycle(torch.utils.data.DataLoader(
            self.dataset, batch_size=train_batch_size, num_workers=0, shuffle=True, pin_memory=True
        ))
        # batch size 1 loader for visualization samples
        self.dataloader_vis = cycle(torch.utils.data.DataLoader(
            self.dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True
        ))
        self.renderer = renderer
        self.optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr)

        self.bucket = bucket
        self.n_reference = n_reference

        # sync EMA weights with the live model before step 0
        self.reset_parameters()
        self.step = 0

        self.device = train_device
scripts/train.py
# model, diffusion, trainer를 init하고나서
batch = utils.batchify(dataset[0], Config.device)
loss, _ = diffusion.loss(*batch)
loss.backward()
def __getitem__(self, idx, eps=1e-4):
    """Return one training sample: a horizon-length normalized window.

    self.indices[idx] is an (episode, start, end) triple.
    """
    path_ind, start, end = self.indices[idx]

    observations = self.fields.normed_observations[path_ind, start:end]
    actions = self.fields.normed_actions[path_ind, start:end]

    conditions = self.get_conditions(observations)
    # actions first, then observations: columns [0:action_dim] are actions
    trajectories = np.concatenate([actions, observations], axis=-1)

    **if self.include_returns:
        # discounted return-to-go from `start` to the episode's end,
        # scaled by returns_scale for use as a conditioning signal
        rewards = self.fields.rewards[path_ind, start:]
        discounts = self.discounts[:len(rewards)]
        returns = (discounts * rewards).sum()
        returns = np.array([returns/self.returns_scale], dtype=np.float32)
        batch = RewardBatch(trajectories, conditions, returns)**
    else:
        batch = Batch(trajectories, conditions)

    return batch
utils.arrays.py
def batchify(batch, device):
scripts/train.py
batch = utils.batchify(dataset[0], Config.device)
loss, _ = diffusion.loss(*batch)
diffuser/models/diffusion.py
def loss(self, x, cond, returns=None):
    """Joint loss: diffusion loss on states + inverse-dynamics MSE on actions."""
    batch_size = len(x)
    # sample one diffusion timestep per trajectory in the batch
    t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
    # diffuse only the observation columns; actions go to the inv model
    **diffuse_loss, info = self.p_losses(x[:, :, self.action_dim:], cond, t, returns)**
    # Calculating inv loss
    x_t = x[:, :-1, self.action_dim:]
    a_t = x[:, :-1, :self.action_dim]
    x_t_1 = x[:, 1:, self.action_dim:]
    # (s_t, s_{t+1}) pairs -> predict a_t
    x_comb_t = torch.cat([x_t, x_t_1], dim=-1)
    x_comb_t = x_comb_t.reshape(-1, 2 * self.observation_dim)
    a_t = a_t.reshape(-1, self.action_dim)
    if self.ar_inv:
        inv_loss = self.inv_model.calc_loss(x_comb_t, a_t)
    else:
        pred_a_t = self.inv_model(x_comb_t)
        inv_loss = F.mse_loss(pred_a_t, a_t)
    # equal-weight average of the two objectives
    loss = (1 / 2) * (diffuse_loss + inv_loss)

    return loss, info
def p_losses(self, x_start, cond, t, returns=None):
    """Single-step denoising loss at timesteps t."""
    noise = torch.randn_like(x_start)

    # forward-diffuse x_start to timestep t, then pin conditioned states
    **x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)**
    x_noisy = apply_conditioning(x_noisy, cond, 0)

    x_recon = self.model(x_noisy, cond, t, returns)

    if not self.predict_epsilon:
        x_recon = apply_conditioning(x_recon, cond, 0)

    assert noise.shape == x_recon.shape

    if self.predict_epsilon:
        # the model predicts the added noise
        loss, info = self.loss_fn(x_recon, noise)
    else:
        # the model predicts x0 directly
        loss, info = self.loss_fn(x_recon, x_start)

    return loss, info
def q_sample(self, x_start, t, noise=None):
    """Forward diffusion: sample x_t ~ q(x_t | x_0) in closed form."""
    if noise is None:
        noise = torch.randn_like(x_start)

    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    sample = (
        extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )

    return sample
def extract(a, t, x_shape):
    """Gather per-batch coefficients a[t] and reshape for broadcasting.

    a       : 1-D tensor of per-timestep coefficients
    t       : [batch] long tensor of timestep indices
    x_shape : shape of the tensor the result will multiply

    Returns shape [batch, 1, ..., 1] with len(x_shape)-1 trailing ones.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
def p_losses(self, x_start, cond, t, returns=None):
noise = torch.randn_like(x_start)
**x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)**
**x_noisy = apply_conditioning(x_noisy, cond, 0)**
x_recon = self.model(x_noisy, cond, t, returns)
if not self.predict_epsilon:
x_recon = apply_conditioning(x_recon, cond, 0)
assert noise.shape == x_recon.shape
if self.predict_epsilon:
loss, info = self.loss_fn(x_recon, noise)
else:
loss, info = self.loss_fn(x_recon, x_start)
return loss, info
# models/helpers.py
def apply_conditioning(x, conditions, action_dim):
    """Overwrite the state dimensions of x at each conditioned timestep.

    x          : [batch x horizon x transition] tensor, mutated in place
    conditions : dict mapping timestep -> state values
    action_dim : columns before this index are left untouched

    Returns the (mutated) x for chaining.
    """
    for timestep, state in conditions.items():
        x[:, timestep, action_dim:] = state.clone()
    return x
self.model 이 있으므로 Temporal UNet 에 대한 forward 과정을 살펴보자
models/temporal.py
def forward(self, x, cond, time, returns=None, use_dropout=True, force_dropout=False):
    '''
        x : [ batch x horizon x transition ]
        returns : [batch x horizon]
    '''
    if self.calc_energy:
        x_inp = x

    # convolutions run over the horizon axis: [B, H, T] -> [B, T, H]
    x = einops.rearrange(x, 'b h t -> b t h')

    t = self.time_mlp(time)

    if self.returns_condition:
        assert returns is not None
        returns_embed = self.returns_mlp(returns)
        if use_dropout:
            # classifier-free guidance training: randomly zero the
            # returns embedding with prob. condition_dropout
            **mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
            returns_embed = mask*returns_embed**
        if force_dropout:
            # unconditional branch used for guidance at sampling time
            returns_embed = 0*returns_embed
        t = torch.cat([t, returns_embed], dim=-1)
이어서
    h = []

    # encoder: residual blocks + downsampling, saving skip features in h
    for resnet, resnet2, downsample in self.downs:
        x = resnet(x, t)
        x = resnet2(x, t)
        h.append(x)
        x = downsample(x)

    x = self.mid_block1(x, t)
    x = self.mid_block2(x, t)

    # decoder: concat skip features, residual blocks + upsampling
    for resnet, resnet2, upsample in self.ups:
        x = torch.cat((x, h.pop()), dim=1)
        x = resnet(x, t)
        x = resnet2(x, t)
        x = upsample(x)

    x = self.final_conv(x)

    # back to [B, H, T]
    x = einops.rearrange(x, 'b t h -> b h t')

    if self.calc_energy:
        # Energy function
        energy = ((x - x_inp)**2).mean()
        grad = torch.autograd.grad(outputs=energy, inputs=x_inp, create_graph=True)
        return grad[0]
    else:
        return x
def p_losses(self, x_start, cond, t, returns=None):
noise = torch.randn_like(x_start)
**x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)**
x_noisy = apply_conditioning(x_noisy, cond, 0)
x_recon = self.model(x_noisy, cond, t, returns)
if not self.predict_epsilon:
x_recon = apply_conditioning(x_recon, cond, 0)
assert noise.shape == x_recon.shape
if self.predict_epsilon:
**loss, info = self.loss_fn(x_recon, noise)**
else:
loss, info = self.loss_fn(x_recon, x_start)
return loss, info
models/diffusion.py
def loss(self, x, cond, returns=None):
    """Training loss; optionally trains only the inverse-dynamics model."""
    if self.train_only_inv:
        # Calculating inv loss
        x_t = x[:, :-1, self.action_dim:]
        a_t = x[:, :-1, :self.action_dim]
        x_t_1 = x[:, 1:, self.action_dim:]
        # (s_t, s_{t+1}) pairs -> predict a_t
        x_comb_t = torch.cat([x_t, x_t_1], dim=-1)
        x_comb_t = x_comb_t.reshape(-1, 2 * self.observation_dim)
        a_t = a_t.reshape(-1, self.action_dim)
        if self.ar_inv:
            loss = self.inv_model.calc_loss(x_comb_t, a_t)
            info = {'a0_loss':loss}
        else:
            pred_a_t = self.inv_model(x_comb_t)
            loss = F.mse_loss(pred_a_t, a_t)
            info = {'a0_loss': loss}
    else:
        batch_size = len(x)
        # one random diffusion timestep per sample
        t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
        # diffusion loss on observation columns only
        diffuse_loss, info = self.p_losses(x[:, :, self.action_dim:], cond, t, returns)
        # Calculating inv loss
        x_t = x[:, :-1, self.action_dim:]
        a_t = x[:, :-1, :self.action_dim]
        x_t_1 = x[:, 1:, self.action_dim:]
        x_comb_t = torch.cat([x_t, x_t_1], dim=-1)
        x_comb_t = x_comb_t.reshape(-1, 2 * self.observation_dim)
        a_t = a_t.reshape(-1, self.action_dim)
        if self.ar_inv:
            inv_loss = self.inv_model.calc_loss(x_comb_t, a_t)
        else:
            **pred_a_t = self.inv_model(x_comb_t)
            inv_loss = F.mse_loss(pred_a_t, a_t)**
        # equal-weight average of the two objectives
        loss = (1 / 2) * (diffuse_loss + inv_loss)

    return loss, info
이제 scripts/train.py 의 trainer.train은 어떻게 동작하는지
위에서 dimension test한 것과 거의 동일하게 진행
eval 부분
넘겨지는 kwargs
{'RUN.prefix': 'diffuser/default_inv/predict_epsilon_200_1000000.0/dropout_0.25/hopper-medium-expert-v2/100', 'seed': 100, 'returns_condition': True, 'predict_epsilon': True, 'n_diffusion_steps': 200, 'condition_dropout': 0.25, 'diffusion': 'models.GaussianInvDynDiffusion', 'n_train_steps': 1000000.0, 'dataset': 'hopper-medium-expert-v2', 'returns_scale': 400.0, 'RUN.job_counter': 1, 'RUN.job_name': 'predict_epsilon_200_1000000.0/dropout_0.25/hopper-medium-expert-v2/100'}
dataset 과 config 동일하게 init
Config: <class 'diffuser.datasets.sequence.SequenceDataset'>
env: hopper-medium-expert-v2
horizon: 100
include_returns: True
max_path_length: 1000
normalizer: CDFNormalizer
preprocess_fns: []
returns_scale: 400.0
use_padding: True
diffuser config
Config: <class 'diffuser.models.diffusion.GaussianInvDynDiffusion'>
action_dim: 3
action_weight: 10
clip_denoised: True
condition_guidance_w: 1.2
hidden_dim: 256
horizon: 100
loss_discount: 1
loss_type: l2
loss_weights: None
n_timesteps: 200
observation_dim: 11
predict_epsilon: True
returns_condition: True
trainer config
Config: <class 'diffuser.utils.training.Trainer'>
bucket: /home/aajay/weights/
ema_decay: 0.995
gradient_accumulate_every: 2
label_freq: 200000
log_freq: 1000
n_reference: 8
sample_freq: 10000
save_freq: 10000
save_parallel: False
train_batch_size: 32
train_device: cuda
train_lr: 0.0002
model, diffusion, trainer load 동일
evaluate_inv_parallel.py
model = model_config()
diffusion = diffusion_config(model)
trainer = trainer_config(diffusion, dataset, renderer)
logger.print(utils.report_parameters(model), color='green')
trainer.step = state_dict['step']
trainer.model.load_state_dict(state_dict['model'])
trainer.ema_model.load_state_dict(state_dict['ema'])
num_eval = 10
device = Config.device
env_list = [gym.make(Config.dataset) for _ in range(num_eval)]
dones = [0 for _ in range(num_eval)]
episode_rewards = [0 for _ in range(num_eval)]
assert trainer.ema_model.condition_guidance_w == Config.condition_guidance_w
returns = to_device(Config.test_ret * torch.ones(num_eval, 1), device)
t = 0
obs_list = [env.reset()[None] for env in env_list]
obs = np.concatenate(obs_list, axis=0)
recorded_obs = [deepcopy(obs[:, None])]
while sum(dones) < num_eval:
    # condition the planner on the current (normalized) observations at t=0
    obs = dataset.normalizer.normalize(obs, 'observations')
    conditions = {0: to_torch(obs, device=device)}
    **samples = trainer.ema_model.conditional_sample(conditions, returns=returns)**
    # first two planned states -> inverse-dynamics model -> first action
    obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)
    obs_comb = obs_comb.reshape(-1, 2*observation_dim)
    action = trainer.ema_model.inv_model(obs_comb)

    samples = to_np(samples)
    action = to_np(action)
    # actions are predicted in normalized space; undo before stepping envs
    action = dataset.normalizer.unnormalize(action, 'actions')

    if t == 0:
        # render the first planned trajectory batch for inspection
        normed_observations = samples[:, :, :]
        observations = dataset.normalizer.unnormalize(normed_observations, 'observations')
        savepath = os.path.join('images', 'sample-planned.png')
        renderer.composite(savepath, observations)
diffuser/models/diffusion.py
@torch.no_grad()
def conditional_sample(self, cond, returns=None, horizon=None, *args, **kwargs):
    '''
        conditions : [ (time, state), ... ]
    '''
    device = self.betas.device
    # cond[0] holds the batch of current states used as the t=0 condition
    batch_size = len(cond[0])
    horizon = horizon or self.horizon
    # only observations are diffused; actions come from the inverse model
    shape = (batch_size, horizon, self.observation_dim)

    return self.p_sample_loop(shape, cond, returns, *args, **kwargs)
@torch.no_grad()
def p_sample_loop(self, shape, cond, returns=None, verbose=True, return_diffusion=False):
    """Run the full reverse-diffusion chain from noise to a plan."""
    device = self.betas.device

    batch_size = shape[0]
    # NOTE(review): initial noise scaled by 0.5 (reduced variance);
    # matches the 0.5 factor in p_sample -- presumably a sampling trick.
    x = 0.5*torch.randn(shape, device=device)
    x = apply_conditioning(x, cond, 0)

    if return_diffusion: diffusion = [x]  # false

    progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent()
    for i in reversed(range(0, self.n_timesteps)):
        timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
        **x = self.p_sample(x, cond, timesteps, returns)**
        # re-pin the conditioned states after every denoising step
        x = apply_conditioning(x, cond, 0)

        progress.update({'t': i})

        if return_diffusion: diffusion.append(x)

    progress.close()

    if return_diffusion:
        return x, torch.stack(diffusion, dim=1)
    else:
        return x
@torch.no_grad()
def p_sample(self, x, cond, t, returns=None):
    """One reverse-diffusion step: sample x_{t-1} ~ p(x_{t-1} | x_t)."""
    b, *_, device = *x.shape, x.device
    **model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t, returns=returns)**
    # NOTE(review): noise scaled by 0.5, i.e. lower-variance sampling than
    # standard DDPM; matches the 0.5 init in p_sample_loop -- confirm intent.
    noise = 0.5*torch.randn_like(x)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
def p_mean_variance(self, x, cond, t, returns=None):
    """Posterior mean/variance for one reverse-diffusion step.

    With returns conditioning, applies classifier-free guidance:
    eps = eps_uncond + w * (eps_cond - eps_uncond).
    """
    if self.returns_condition:
        # epsilon could be epsilon or x0 itself
        epsilon_cond = self.model(x, cond, t, returns, use_dropout=False)
        epsilon_uncond = self.model(x, cond, t, returns, force_dropout=True)
        epsilon = epsilon_uncond + self.condition_guidance_w*(epsilon_cond - epsilon_uncond)
    else:
        epsilon = self.model(x, cond, t)

    t = t.detach().to(torch.int64)
    x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon)

    if self.clip_denoised:
        x_recon.clamp_(-1., 1.)
    else:
        # BUG FIX: was `assert RuntimeError()`, which always passes (a
        # RuntimeError instance is truthy) and silently skipped clipping.
        # The intent is clearly to forbid clip_denoised=False here.
        raise RuntimeError('p_mean_variance requires clip_denoised=True')

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
        x_start=x_recon, x_t=x, t=t)
    return model_mean, posterior_variance, posterior_log_variance
def q_posterior(self, x_start, x_t, t):
    """Closed-form Gaussian posterior q(x_{t-1} | x_t, x_0)."""
    # mean is a fixed linear combination of x_0 and x_t (DDPM eq.)
    posterior_mean = (
        extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
        extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )
    posterior_variance = extract(self.posterior_variance, t, x_t.shape)
    # clipped log-variance: the variance is 0 at the start of the chain
    posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
evaluate_inv_parallel.py
while sum(dones) < num_eval:
obs = dataset.normalizer.normalize(obs, 'observations')
conditions = {0: to_torch(obs, device=device)}
**samples = trainer.ema_model.conditional_sample(conditions, returns=returns)**
obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)
obs_comb = obs_comb.reshape(-1, 2*observation_dim)
action = trainer.ema_model.inv_model(obs_comb)
samples = to_np(samples)
action = to_np(action)
action = dataset.normalizer.unnormalize(action, 'actions')
if t == 0:
normed_observations = samples[:, :, :]
observations = dataset.normalizer.unnormalize(normed_observations, 'observations')
savepath = os.path.join('images', 'sample-planned.png')
renderer.composite(savepath, observations)