Commit 2923c69

Kami-code committed
[#5] add evaluate_policy.py
1 parent 1b4a834 commit 2923c69

File tree: 2 files changed (+87 −1 lines)


README.md

Lines changed: 11 additions & 1 deletion
@@ -74,6 +74,16 @@ python examples/visualize_policy.py --task_name=laptop --checkpoint_path assets/

`use_test_set`: flag to determine evaluating with seen or unseen instances

+### Example for Evaluating Policy
+
+```bash
+python examples/evaluate_policy.py --task_name=laptop --checkpoint_path assets/rl_checkpoints/laptop.zip --eval_per_instance 10
+```
+
+`task_name`: name of environment [`faucet`, `laptop`, `bucket`, `toilet`]
+
+`use_test_set`: flag to determine evaluating with seen or unseen instances
+
### Example for Training RL Agent

```bash
@@ -97,7 +107,7 @@ python3 examples/train.py --n 100 --workers 10 --iter 5000 --lr 0.0001 &&

`extractor_name`: different PointNet architectures [`smallpn`, `meduimpn`, `largepn`]

-`pretrain_path`: path to downloaded pretrained model. [Default: `None`
+`pretrain_path`: path to downloaded pretrained model. [Default: `None`]

## Bibtex
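Note: the new evaluation example runs on seen instances by default. Going by the argparse setup in `examples/evaluate_policy.py` below, appending `--use_test_set` should switch evaluation to the unseen split:

```bash
# --use_test_set selects TRAIN_CONFIG[task_name]['unseen'] in evaluate_policy.py
python examples/evaluate_policy.py --task_name=laptop --checkpoint_path assets/rl_checkpoints/laptop.zip --eval_per_instance 10 --use_test_set
```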

examples/evaluate_policy.py

Lines changed: 76 additions & 0 deletions
```python
import os
import sys

# make the repo root importable when running this script from examples/
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import argparse
from dexart.env.task_setting import TRAIN_CONFIG, RANDOM_CONFIG
import numpy as np
from dexart.env.create_env import create_env
from stable_baselines3 import PPO
from examples.train import get_3d_policy_kwargs
from tqdm import tqdm

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', type=str, required=True)
    parser.add_argument('--checkpoint_path', type=str, required=True)
    parser.add_argument('--eval_per_instance', type=int, default=10)
    parser.add_argument('--use_test_set', dest='use_test_set', action='store_true', default=False)
    args = parser.parse_args()
    task_name = args.task_name
    use_test_set = args.use_test_set
    checkpoint_path = args.checkpoint_path

    # evaluate on held-out (unseen) or training (seen) object instances
    if use_test_set:
        indices = TRAIN_CONFIG[task_name]['unseen']
        print(f"using unseen instances {indices}")
    else:
        indices = TRAIN_CONFIG[task_name]['seen']
        print(f"using seen instances {indices}")

    # per-task randomization ranges for the object pose
    rand_pos = RANDOM_CONFIG[task_name]['rand_pos']
    rand_degree = RANDOM_CONFIG[task_name]['rand_degree']
    env = create_env(task_name=task_name,
                     use_visual_obs=True,
                     use_gui=True,
                     is_eval=True,
                     pc_noise=True,
                     pc_seg=True,
                     index=indices,
                     img_type='robot',
                     rand_pos=rand_pos,
                     rand_degree=rand_degree)

    # load the PPO checkpoint; note that check_obs_space and force_load are
    # not arguments of vanilla SB3's PPO.load
    policy = PPO.load(checkpoint_path, env, 'cuda:0',
                      policy_kwargs=get_3d_policy_kwargs(extractor_name='smallpn'),
                      check_obs_space=False, force_load=True)

    eval_instances = len(env.instance_list)
    eval_per_instance = args.eval_per_instance
    success_list = list()
    reward_list = list()

    with tqdm(total=eval_per_instance * eval_instances) as pbar:
        for _ in range(eval_per_instance):
            for _ in range(eval_instances):
                obs = env.reset()
                eval_success = False
                reward_sum = 0
                for j in range(env.horizon):
                    # SB3's predict expects a leading batch dimension on the observation
                    if isinstance(obs, dict):
                        for key, value in obs.items():
                            obs[key] = value[np.newaxis, :]
                    else:
                        obs = obs[np.newaxis, :]
                    action = policy.predict(observation=obs, deterministic=True)[0]
                    obs, reward, done, _ = env.step(action)
                    reward_sum += reward
                    if env.is_eval_done:
                        eval_success = True
                    if done:
                        break
                reward_list.append(reward_sum)
                success_list.append(int(eval_success))
                pbar.update(1)
    print('reward_mean = ', np.mean(reward_list), 'success rate = ', np.mean(success_list))
```
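A note on the `np.newaxis` handling in the loop above: the environment returns a single unbatched observation (a dict of arrays, since the env is created with `use_visual_obs=True`), while SB3's `predict` expects a leading batch dimension. A minimal standalone sketch of the same batching pattern, with hypothetical keys and shapes:

```python
import numpy as np

# hypothetical single observation: point cloud + proprioceptive state
obs = {"pc": np.zeros((512, 3), dtype=np.float32),
       "state": np.zeros(32, dtype=np.float32)}

# prepend a batch dimension to every entry before calling policy.predict(...)
batched = {key: value[np.newaxis, :] for key, value in obs.items()}
assert batched["pc"].shape == (1, 512, 3)
assert batched["state"].shape == (1, 32)
```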

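The final print averages episode returns and binary success flags over all `eval_per_instance * eval_instances` episodes. A tiny worked example with made-up numbers:

```python
import numpy as np

# hypothetical results from 2 instances x 3 rollouts each
success_list = [1, 0, 1, 1, 0, 0]                # 3 successes in 6 episodes
reward_list = [12.3, 4.1, 10.8, 11.5, 9.9, 3.7]

print('reward_mean = ', np.mean(reward_list), 'success rate = ', np.mean(success_list))
# reward_mean ≈ 8.72, success rate = 0.5
```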
0 commit comments
