Skip to content

Commit

Permalink
Revert prev and register count in the state_dict so it is saved and l…
Browse files Browse the repository at this point in the history
…oaded with the module
  • Loading branch information
tasdep committed Jun 21, 2024
1 parent f2b721a commit 48af1c1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
2 changes: 1 addition & 1 deletion rsl_rl/modules/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, shape, eps=1e-2, until=None):
self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0))
self.register_buffer("_var", torch.ones(shape).unsqueeze(0))
self.register_buffer("_std", torch.ones(shape).unsqueeze(0))
self.count = 0
self.register_buffer("count", torch.tensor(0, dtype=torch.long))

@property
def mean(self):
Expand Down
9 changes: 3 additions & 6 deletions rsl_rl/runners/on_policy_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
self.save_interval = self.cfg["save_interval"]
self.empirical_normalization = self.cfg["empirical_normalization"]
if self.empirical_normalization:
if train_cfg.get("resume") == True:
until = 0
else:
until = 1.0e8
self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=until).to(self.device)
self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=until).to(self.device)
self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
else:
self.obs_normalizer = torch.nn.Identity() # no normalization
self.critic_obs_normalizer = torch.nn.Identity() # no normalization
Expand Down Expand Up @@ -264,6 +260,7 @@ def save(self, path, infos=None):
if self.empirical_normalization:
saved_dict["obs_norm_state_dict"] = self.obs_normalizer.state_dict()
saved_dict["critic_obs_norm_state_dict"] = self.critic_obs_normalizer.state_dict()

torch.save(saved_dict, path)

# Upload model to external logging service
Expand Down

0 comments on commit 48af1c1

Please sign in to comment.