环境配置:
torch==2.5.1
ray==2.10.0
ray[rllib]==2.10.0
ray[tune]==2.10.0
ray[serve]==2.10.0
numpy==1.23.0
python==3.9.18
训练模型后保存模型,比较简单,这里简单介绍。
import os
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print
## 配置算法
storage_path = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs"
os.makedirs(storage_path,exist_ok=True)
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config.output = storage_path ## 设置过程文件的存储路径
## 构建算法
algo = config.build()
## 训练算法
for i in range(3):
result = algo.train()
print(f"episode_{i}")
## 保存模型
## 方法1: 保存到默认路径下
algo.save() ## 保存到默认路径下, 一般是: ~/ray_result 文件夹下, 或 C:\Users\xxx\ray_results\ 文件夹下
## 上面设置的 config.output 只用于保存一些过程文件,不能决定这里的存储位置
## 方法2: 保存到默认路径下,并返回保存路径
checkpoint_dir = algo.save().checkpoint.path
print(f"Checkpoint saved in directory {checkpoint_dir}")
## 方法3: 保存到指定路径下
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
os.makedirs(checkpoint_dir,exist_ok=True)
algo.save_checkpoint(checkpoint_dir) ## 保存到指定路径下
print(f"saved checkpoint to {checkpoint_dir}")