mlpack  3.1.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
one_step_q_learning_worker.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
15 
17 
18 namespace mlpack {
19 namespace rl {
20 
29 template <
30  typename EnvironmentType,
31  typename NetworkType,
32  typename UpdaterType,
33  typename PolicyType
34 >
35 class OneStepQLearningWorker
36 {
37  public:
38  using StateType = typename EnvironmentType::State;
39  using ActionType = typename EnvironmentType::Action;
40  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;
41 
52  const UpdaterType& updater,
53  const EnvironmentType& environment,
54  const TrainingConfig& config,
55  bool deterministic):
56  updater(updater),
57  #if ENS_VERSION_MAJOR >= 2
58  updatePolicy(NULL),
59  #endif
60  environment(environment),
61  config(config),
62  deterministic(deterministic),
63  pending(config.UpdateInterval())
64  { Reset(); }
65 
72  updater(other.updater),
73  #if ENS_VERSION_MAJOR >= 2
74  updatePolicy(NULL),
75  #endif
76  environment(other.environment),
77  config(other.config),
78  deterministic(other.deterministic),
79  steps(other.steps),
80  episodeReturn(other.episodeReturn),
81  pending(other.pending),
82  pendingIndex(other.pendingIndex),
83  network(other.network),
84  state(other.state)
85  {
86  #if ENS_VERSION_MAJOR >= 2
87  updatePolicy = new typename UpdaterType::template
88  Policy<arma::mat, arma::mat>(updater,
89  network.Parameters().n_rows,
90  network.Parameters().n_cols);
91  #endif
92 
93  Reset();
94  }
95 
102  updater(std::move(other.updater)),
103  #if ENS_VERSION_MAJOR >= 2
104  updatePolicy(NULL),
105  #endif
106  environment(std::move(other.environment)),
107  config(std::move(other.config)),
108  deterministic(std::move(other.deterministic)),
109  steps(std::move(other.steps)),
110  episodeReturn(std::move(other.episodeReturn)),
111  pending(std::move(other.pending)),
112  pendingIndex(std::move(other.pendingIndex)),
113  network(std::move(other.network)),
114  state(std::move(other.state))
115  {
116  #if ENS_VERSION_MAJOR >= 2
117  other.updatePolicy = NULL;
118 
119  updatePolicy = new typename UpdaterType::template
120  Policy<arma::mat, arma::mat>(updater,
121  network.Parameters().n_rows,
122  network.Parameters().n_cols);
123  #endif
124  }
125 
132  {
133  if (&other == this)
134  return *this;
135 
136  #if ENS_VERSION_MAJOR >= 2
137  delete updatePolicy;
138  #endif
139 
140  updater = other.updater;
141  environment = other.environment;
142  config = other.config;
143  deterministic = other.deterministic;
144  steps = other.steps;
145  episodeReturn = other.episodeReturn;
146  pending = other.pending;
147  pendingIndex = other.pendingIndex;
148  network = other.network;
149  state = other.state;
150 
151  #if ENS_VERSION_MAJOR >= 2
152  updatePolicy = new typename UpdaterType::template
153  Policy<arma::mat, arma::mat>(updater,
154  network.Parameters().n_rows,
155  network.Parameters().n_cols);
156  #endif
157 
158  Reset();
159 
160  return *this;
161  }
162 
169  {
170  if (&other == this)
171  return *this;
172 
173  #if ENS_VERSION_MAJOR >= 2
174  delete updatePolicy;
175  #endif
176 
177  updater = std::move(other.updater);
178  environment = std::move(other.environment);
179  config = std::move(other.config);
180  deterministic = std::move(other.deterministic);
181  steps = std::move(other.steps);
182  episodeReturn = std::move(other.episodeReturn);
183  pending = std::move(other.pending);
184  pendingIndex = std::move(other.pendingIndex);
185  network = std::move(other.network);
186  state = std::move(other.state);
187 
188  #if ENS_VERSION_MAJOR >= 2
189  other.updatePolicy = NULL;
190 
191  updatePolicy = new typename UpdaterType::template
192  Policy<arma::mat, arma::mat>(updater,
193  network.Parameters().n_rows,
194  network.Parameters().n_cols);
195  #endif
196 
197  return *this;
198  }
199 
204  {
205  #if ENS_VERSION_MAJOR >= 2
206  delete updatePolicy;
207  #endif
208  }
209 
214  void Initialize(NetworkType& learningNetwork)
215  {
216  #if ENS_VERSION_MAJOR == 1
217  updater.Initialize(learningNetwork.Parameters().n_rows,
218  learningNetwork.Parameters().n_cols);
219  #else
220  delete updatePolicy;
221 
222  updatePolicy = new typename UpdaterType::template
223  Policy<arma::mat, arma::mat>(updater,
224  learningNetwork.Parameters().n_rows,
225  learningNetwork.Parameters().n_cols);
226  #endif
227 
228  // Build local network.
229  network = learningNetwork;
230  }
231 
243  bool Step(NetworkType& learningNetwork,
244  NetworkType& targetNetwork,
245  size_t& totalSteps,
246  PolicyType& policy,
247  double& totalReward)
248  {
249  // Interact with the environment.
250  arma::colvec actionValue;
251  network.Predict(state.Encode(), actionValue);
252  ActionType action = policy.Sample(actionValue, deterministic);
253  StateType nextState;
254  double reward = environment.Sample(state, action, nextState);
255  bool terminal = environment.IsTerminal(nextState);
256 
257  episodeReturn += reward;
258  steps++;
259 
260  terminal = terminal || steps >= config.StepLimit();
261  if (deterministic)
262  {
263  if (terminal)
264  {
265  totalReward = episodeReturn;
266  Reset();
267  // Sync with latest learning network.
268  network = learningNetwork;
269  return true;
270  }
271  state = nextState;
272  return false;
273  }
274 
275  #pragma omp atomic
276  totalSteps++;
277 
278  pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
279  pendingIndex++;
280 
281  if (terminal || pendingIndex >= config.UpdateInterval())
282  {
283  // Initialize the gradient storage.
284  arma::mat totalGradients(learningNetwork.Parameters().n_rows,
285  learningNetwork.Parameters().n_cols, arma::fill::zeros);
286  for (size_t i = 0; i < pending.size(); ++i)
287  {
288  TransitionType &transition = pending[i];
289 
290  // Compute the target state-action value.
291  arma::colvec actionValue;
292  #pragma omp critical
293  {
294  targetNetwork.Predict(
295  std::get<3>(transition).Encode(), actionValue);
296  };
297  double targetActionValue = actionValue.max();
298  if (terminal && i == pending.size() - 1)
299  targetActionValue = 0;
300  targetActionValue = std::get<2>(transition) +
301  config.Discount() * targetActionValue;
302 
303  // Compute the training target for current state.
304  network.Forward(std::get<0>(transition).Encode(), actionValue);
305  actionValue[std::get<1>(transition)] = targetActionValue;
306 
307  // Compute gradient.
308  arma::mat gradients;
309  network.Backward(actionValue, gradients);
310 
311  // Accumulate gradients.
312  totalGradients += gradients;
313  }
314 
315  // Clamp the accumulated gradients.
316  totalGradients.transform(
317  [&](double gradient)
318  { return std::min(std::max(gradient, -config.GradientLimit()),
319  config.GradientLimit()); });
320 
321  // Perform async update of the global network.
322  #if ENS_VERSION_MAJOR == 1
323  updater.Update(learningNetwork.Parameters(), config.StepSize(),
324  totalGradients);
325  #else
326  updatePolicy->Update(learningNetwork.Parameters(),
327  config.StepSize(), totalGradients);
328  #endif
329 
330  // Sync the local network with the global network.
331  network = learningNetwork;
332 
333  pendingIndex = 0;
334  }
335 
336  // Update global target network.
337  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
338  {
339  #pragma omp critical
340  { targetNetwork = learningNetwork; }
341  }
342 
343  policy.Anneal();
344 
345  if (terminal)
346  {
347  totalReward = episodeReturn;
348  Reset();
349  return true;
350  }
351  state = nextState;
352  return false;
353  }
354 
355  private:
359  void Reset()
360  {
361  steps = 0;
362  episodeReturn = 0;
363  pendingIndex = 0;
364  state = environment.InitialSample();
365  }
366 
368  UpdaterType updater;
369  #if ENS_VERSION_MAJOR >= 2
370  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
371  #endif
372 
374  EnvironmentType environment;
375 
377  TrainingConfig config;
378 
380  bool deterministic;
381 
383  size_t steps;
384 
386  double episodeReturn;
387 
389  std::vector<TransitionType> pending;
390 
392  size_t pendingIndex;
393 
395  NetworkType network;
396 
398  StateType state;
399 };
400 
401 } // namespace rl
402 } // namespace mlpack
403 
404 #endif
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
std::tuple< StateType, ActionType, double, StateType > TransitionType
OneStepQLearningWorker & operator=(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
OneStepQLearningWorker(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
OneStepQLearningWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct one step Q-Learning worker with the given parameters and environment.
double Discount() const
Get the discount rate for future reward.
Forward declaration of OneStepQLearningWorker.
double GradientLimit() const
Get the limit of update gradient.
size_t StepLimit() const
Get the maximum steps of each episode.
typename EnvironmentType::Action ActionType
OneStepQLearningWorker(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
OneStepQLearningWorker & operator=(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
size_t UpdateInterval() const
Get the update interval.
double StepSize() const
Get the step size of the optimizer.
typename EnvironmentType::State StateType
size_t TargetNetworkSyncInterval() const
Get the interval for syncing target network.