Commit 94781489 authored by josef azzam's avatar josef azzam
Browse files

New Python Notebook for StableBL Testing and fixed plotting

parent f8bcc63c
Pipeline #60029 passed with stages
in 2 minutes and 42 seconds
......@@ -13,7 +13,7 @@ class JuPong2D_PPO2_Plot:
def __init__(self, env, output):
self.env_name = env
self.output = f"{output}/{self.env_name}"
self.folders = [f.path for f in os.scandir(self.output) if f.is_dir() and not f.name.startswith(".")]
self.folders = [f.path.replace("\\", "/") for f in os.scandir(self.output) if f.is_dir() and not f.name.startswith(".")]
def plot_ball_speed(self):
......@@ -47,10 +47,10 @@ class JuPong2D_PPO2_Plot:
scale_factors = np.mean(scale_factor_matrix, axis=0)
std_rewards_up = mean_rewards + std_vals
std_rewards_down = mean_rewards - std_vals
plt.vlines(scale_factors, std_rewards_down, std_rewards_up, zorder = -1, linestyle="dashed")
plt.vlines(scale_factors, std_rewards_down, std_rewards_up, zorder = -1, linestyle="dashed", colors="black")
plt.scatter(scale_factors, mean_rewards, label = bs_factor, s = 60)
plt.xlabel("Faktor der Ball-Geschwindigkeit", fontsize=20)
plt.ylabel("Mittl. Return-Wert mit\n Standardabweichung", fontsize=20)
plt.xlabel("Factor of the ball-speed", fontsize=20)
plt.ylabel("Mean Return-Values\n with std. deviation", fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend(loc = "lower right", fontsize=12)
......@@ -64,7 +64,6 @@ class JuPong2D_PPO2_Plot:
return-values with their standard deviation.
"""
pl_folders = [folder for folder in self.folders if folder.split("/")[-1].split("_")[-2] == "PaddleLength"]
fig = plt.figure()
for pl_folder in pl_folders:
pl_factor = float(pl_folder.split("_")[-1])
......@@ -89,10 +88,10 @@ class JuPong2D_PPO2_Plot:
scale_factors = np.mean(scale_factor_matrix, axis=0)
std_rewards_up = mean_rewards + std_vals
std_rewards_down = mean_rewards - std_vals
plt.vlines(scale_factors, std_rewards_down, std_rewards_up, zorder = -1, linestyle="dashed")
plt.vlines(scale_factors, std_rewards_down, std_rewards_up, zorder = -1, linestyle="dashed", colors="black")
plt.scatter(scale_factors, mean_rewards, label = pl_factor, s = 60)
plt.xlabel("Faktor der Paddle-Länge", fontsize=20)
plt.ylabel("Mittl. Return-Wert mit\n Standardabweichung", fontsize=20)
plt.xlabel("Factor of the paddle-length", fontsize=20)
plt.ylabel("Mean Return-Values\n with std. deviation", fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend(loc = "lower right", fontsize=12)
......@@ -131,10 +130,10 @@ class JuPong2D_PPO2_Plot:
scale_factors = np.mean(scale_factor_matrix, axis=0)
std_rewards_up = mean_rewards + std_vals
std_rewards_down = mean_rewards - std_vals
plt.vlines(scale_factors, std_rewards_down, std_rewards_up, zorder = -1, linestyle="dashed")
plt.vlines(scale_factors, std_rewards_down, std_rewards_up, zorder = -1, linestyle="dashed", colors="black")
plt.scatter(scale_factors, mean_rewards, label = ps_factor, s = 60)
plt.xlabel("Faktor der Paddle-Geschwindigkeit", fontsize=20)
plt.ylabel("Mittl. Return-Wert mit\n Standardabweichung", fontsize=20)
plt.xlabel("Factor of the paddle-speed", fontsize=20)
plt.ylabel("Mean Return-Values\n with std. deviation", fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend(loc = "lower right", fontsize=12)
......
......@@ -248,7 +248,7 @@ class JuPong2D_PPO2_Play:
self.return_arr[ind] = np.mean(return_val_arr)
print(self.return_arr)
with open(self.save_file, 'w') as my_file:
with open(self.save_file, 'w', newline='') as my_file:
writer = csv.writer(my_file)
writer.writerow(self.scale_factor_arr)
writer.writerow(self.return_arr)
......
%% Cell type:markdown id: tags:
# **Jupyter Notebook to test the library Stable Baselines in the Gym-Environment JuPong2D**
%% Cell type:code id: tags:
``` bash
``` python
# Checking current directory, it should be the GitLab-Repository
pwd
```
%% Cell type:code id: tags:
``` bash
``` python
# Activating the virtal environment, which should be installed before running the notebook
source lib/pong_deeprl/bin/activate
```
%% Cell type:code id: tags:
``` bash
``` python
# Lists all possible command options of the Stable Baselines training and testing script
deep_stablebl --help
```
%% Cell type:code id: tags:
``` bash
``` python
# Starts a simple training of the Gym-Environment JuPong2D
deep_stablebl --pl 0.5 --session 1 --tts 10000 --train-steps 3 results/stablebl_results
```
%% Cell type:code id: tags:
``` bash
``` python
# Testing of the trained model by playing in different Gym-Environments, which depend on the parameters
deep_stablebl --play --pl 0.5 --session 1 --play-steps 2 results/stablebl_results
```
%% Cell type:code id: tags:
``` bash
``` python
# Plots the training results in a PDF-file
plot_stablebl results/stablebl_results
```
......
%% Cell type:code id: tags:
``` python
from src.deeprl_lib.stablebaselines.jupong2d_ppo2 import *
from src.deeprl_lib.stablebaselines.jupong2d_plot_ppo2_data import *
def main():
results_folder = "results/stablebl_results"
env_name = "jupong2d-headless-0.4-v3"
session = 3
play_steps = 2
train_steps = 2
tts = 10000
train_runner = JuPong2D_PPO2(env_name, results_folder, train_steps, tts, session, paddle_length_factor=0.5)
train_runner.start_training()
play_runner = JuPong2D_PPO2_Play(env_name, results_folder, session, play_steps, paddle_length_factor=0.5)
play_runner.create_data()
ploter = JuPong2D_PPO2_Plot(env_name, results_folder)
ploter.plot_paddle_length()
ploter.plot_paddle_speed()
ploter.plot_ball_speed()
main()
```
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment