Source code for tianshou.core.value_function.action_value

from __future__ import absolute_import
import tensorflow as tf

from .base import ValueFunctionBase
from ..utils import identify_dependent_variables

__all__ = [
    'ActionValue',
    'DQN',
]


class ActionValue(ValueFunctionBase):
    """
    Class for action values Q(s, a). The input of the value network is states and actions,
    and the output of the value network is directly the Q-value of the input (state, action) pairs.

    :param network_callable: A Python callable returning (action head, value head). When called,
        it builds the tf graph and returns a Tensor of the value on the value head.
    :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s
        in Q(s, a) in the network graph.
    :param action_placeholder: A :class:`tf.placeholder`. The action placeholder for a
        in Q(s, a) in the network graph.
    :param has_old_net: A bool defaulting to ``False``. If true, this class will create another graph
        with another set of :class:`tf.Variable` s to be the "old net". The "old net" could be
        the target networks as in DQN and DDPG, or just an old net to help optimization as in PPO.
    """
    def __init__(self, network_callable, observation_placeholder, action_placeholder, has_old_net=False):
        self.observation_placeholder = observation_placeholder
        self.action_placeholder = action_placeholder
        self.managed_placeholders = {'observation': observation_placeholder, 'action': action_placeholder}
        self.has_old_net = has_old_net

        network_scope = 'network'
        net_old_scope = 'net_old'

        with tf.variable_scope(network_scope, reuse=tf.AUTO_REUSE):
            value_tensor = network_callable()[1]
            assert value_tensor is not None

        super(ActionValue, self).__init__(value_tensor, observation_placeholder=observation_placeholder)

        weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self.network_weights = identify_dependent_variables(self.value_tensor, weights)
        self._trainable_variables = [var for var in self.network_weights
                                     if var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]

        # deal with target network
        if has_old_net:
            with tf.variable_scope(net_old_scope, reuse=tf.AUTO_REUSE):
                self.value_tensor_old = tf.squeeze(network_callable()[1])

            old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=net_old_scope)

            # re-filter to rule out some edge cases
            old_weights = [var for var in old_weights if var.name[:len(net_old_scope)] == net_old_scope]

            self.network_old_weights = identify_dependent_variables(self.value_tensor_old, old_weights)
            assert len(self.network_weights) == len(self.network_old_weights)

            self.sync_weights_ops = [tf.assign(variable_old, variable)
                                     for (variable_old, variable) in
                                     zip(self.network_old_weights, self.network_weights)]
        else:
            self.sync_weights_ops = None

    @property
    def trainable_variables(self):
        """
        The trainable variables of the value network in a Python **set**. It contains only the
        :class:`tf.Variable` s that affect the value.
        """
        return set(self._trainable_variables)

    def eval_value(self, observation, action, my_feed_dict={}):
        """
        Evaluate value in minibatch using the current network.

        :param observation: An array-like, of shape (batch_size,) + observation_shape.
        :param action: An array-like, of shape (batch_size,) + action_shape.
        :param my_feed_dict: Optional. A dict defaulting to empty. Specifies placeholders such as
            dropout and batch_norm except observation and action.

        :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
        """
        sess = tf.get_default_session()
        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
        feed_dict.update(my_feed_dict)
        return sess.run(self.value_tensor, feed_dict=feed_dict)

    def eval_value_old(self, observation, action, my_feed_dict={}):
        """
        Evaluate value in minibatch using the old net.

        :param observation: An array-like, of shape (batch_size,) + observation_shape.
        :param action: An array-like, of shape (batch_size,) + action_shape.
        :param my_feed_dict: Optional. A dict defaulting to empty. Specifies placeholders such as
            dropout and batch_norm except observation and action.

        :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
        """
        sess = tf.get_default_session()
        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
        feed_dict.update(my_feed_dict)
        return sess.run(self.value_tensor_old, feed_dict=feed_dict)

    def sync_weights(self):
        """
        Sync the variables of the "old net" to be the same as the current network.
        """
        if self.sync_weights_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.sync_weights_ops)
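
# Usage sketch (not part of the library source): a minimal, hypothetical example of wiring an
# ActionValue head for continuous actions. The network builder `my_network`, the placeholder
# shapes, and the session handling below are illustrative assumptions, not the library's API.
#
#     obs_ph = tf.placeholder(tf.float32, shape=(None, 4))
#     act_ph = tf.placeholder(tf.float32, shape=(None, 2))
#
#     def my_network():
#         # concatenate state and action, output a single Q-value per (s, a) pair
#         net = tf.concat([obs_ph, act_ph], axis=1)
#         net = tf.layers.dense(net, 64, activation=tf.nn.relu)
#         q_value = tf.layers.dense(net, 1)
#         return None, q_value  # (action head, value head)
#
#     q_func = ActionValue(my_network, obs_ph, act_ph, has_old_net=True)
#
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         q_func.sync_weights()                   # old net <- current net
#         q = q_func.eval_value(observations, actions)  # numpy array, shape (batch_size,)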


class DQN(ValueFunctionBase):
    """
    Class for the special action value function DQN. Instead of feeding s and a to the network
    to get a value, DQN feeds s to the network and gets at the last layer Q(s, \*) for all actions
    under this state. Still, like :class:`ActionValue`, this class builds the Q(s, a) value Tensor.
    It can only be used with discrete (and finite) action spaces.

    :param network_callable: A Python callable returning (action head, value head). When called,
        it builds the tf graph and returns a Tensor of Q(s, \*) on the value head.
    :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s
        in Q(s, \*) in the network graph.
    :param has_old_net: A bool defaulting to ``False``. If true, this class will create another graph
        with another set of :class:`tf.Variable` s to be the "old net". The "old net" could be
        the target networks as in DQN and DDPG, or just an old net to help optimization as in PPO.
    """
    def __init__(self, network_callable, observation_placeholder, has_old_net=False):
        self.observation_placeholder = observation_placeholder
        self.action_placeholder = action_placeholder = tf.placeholder(
            tf.int32, shape=(None,), name='action_value.DQN/action_placeholder')
        self.managed_placeholders = {'observation': observation_placeholder, 'action': action_placeholder}
        self.has_old_net = has_old_net

        network_scope = 'network'
        net_old_scope = 'net_old'

        with tf.variable_scope(network_scope, reuse=tf.AUTO_REUSE):
            value_tensor = network_callable()[1]
            assert value_tensor is not None
            self._value_tensor_all_actions = value_tensor
            self.num_actions = value_tensor.shape.as_list()[-1]

            # select Q(s, a) from Q(s, *) by pairing each row's batch index with its action
            batch_size = tf.shape(value_tensor)[0]
            batch_dim_index = tf.range(batch_size)
            indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
            canonical_value_tensor = tf.gather_nd(value_tensor, indices)

        super(DQN, self).__init__(canonical_value_tensor, observation_placeholder=observation_placeholder)

        weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self.network_weights = identify_dependent_variables(self.value_tensor, weights)
        self._trainable_variables = [var for var in self.network_weights
                                     if var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]

        # deal with target network
        if has_old_net:
            with tf.variable_scope(net_old_scope, reuse=tf.AUTO_REUSE):
                value_tensor = network_callable()[1]
                self.value_tensor_all_actions_old = value_tensor

                # same Q(s, a) selection as above, but on the old net's Q(s, *)
                batch_size = tf.shape(value_tensor)[0]
                batch_dim_index = tf.range(batch_size)
                indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
                canonical_value_tensor = tf.gather_nd(value_tensor, indices)

                self.value_tensor_old = canonical_value_tensor

            old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=net_old_scope)

            # re-filter to rule out some edge cases
            old_weights = [var for var in old_weights if var.name[:len(net_old_scope)] == net_old_scope]

            self.network_old_weights = identify_dependent_variables(self.value_tensor_old, old_weights)
            assert len(self.network_weights) == len(self.network_old_weights)

            self.sync_weights_ops = [tf.assign(variable_old, variable)
                                     for (variable_old, variable) in
                                     zip(self.network_old_weights, self.network_weights)]
        else:
            self.sync_weights_ops = None

    @property
    def trainable_variables(self):
        """
        The trainable variables of the value network in a Python **set**. It contains only the
        :class:`tf.Variable` s that affect the value.
        """
        return set(self._trainable_variables)

    def eval_value(self, observation, action, my_feed_dict={}):
        """
        Evaluate value Q(s, a) in minibatch using the current network.

        :param observation: An array-like, of shape (batch_size,) + observation_shape.
        :param action: An array-like, of shape (batch_size,) + action_shape.
        :param my_feed_dict: Optional. A dict defaulting to empty. Specifies placeholders such as
            dropout and batch_norm except observation and action.

        :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
        """
        sess = tf.get_default_session()
        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
        feed_dict.update(my_feed_dict)
        return sess.run(self.value_tensor, feed_dict=feed_dict)

    def eval_value_old(self, observation, action, my_feed_dict={}):
        """
        Evaluate value Q(s, a) in minibatch using the old net.

        :param observation: An array-like, of shape (batch_size,) + observation_shape.
        :param action: An array-like, of shape (batch_size,) + action_shape.
        :param my_feed_dict: Optional. A dict defaulting to empty. Specifies placeholders such as
            dropout and batch_norm except observation and action.

        :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
        """
        sess = tf.get_default_session()
        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
        feed_dict.update(my_feed_dict)
        return sess.run(self.value_tensor_old, feed_dict=feed_dict)

    @property
    def value_tensor_all_actions(self):
        """The Tensor for Q(s, \*)"""
        return self._value_tensor_all_actions

    def eval_value_all_actions(self, observation, my_feed_dict={}):
        """
        Evaluate values Q(s, \*) in minibatch using the current network.

        :param observation: An array-like, of shape (batch_size,) + observation_shape.
        :param my_feed_dict: Optional. A dict defaulting to empty. Specifies placeholders such as
            dropout and batch_norm except observation and action.

        :return: A numpy array of shape (batch_size, num_actions). The corresponding action values
            for each observation.
        """
        sess = tf.get_default_session()
        feed_dict = {self.observation_placeholder: observation}
        feed_dict.update(my_feed_dict)
        return sess.run(self._value_tensor_all_actions, feed_dict=feed_dict)

    def eval_value_all_actions_old(self, observation, my_feed_dict={}):
        """
        Evaluate values Q(s, \*) in minibatch using the old net.

        :param observation: An array-like, of shape (batch_size,) + observation_shape.
        :param my_feed_dict: Optional. A dict defaulting to empty. Specifies placeholders such as
            dropout and batch_norm except observation and action.

        :return: A numpy array of shape (batch_size, num_actions). The corresponding action values
            for each observation.
        """
        sess = tf.get_default_session()
        feed_dict = {self.observation_placeholder: observation}
        feed_dict.update(my_feed_dict)
        return sess.run(self.value_tensor_all_actions_old, feed_dict=feed_dict)

    def sync_weights(self):
        """
        Sync the variables of the "old net" to be the same as the current network.
        """
        if self.sync_weights_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.sync_weights_ops)
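

# Usage sketch (not part of the library source): a minimal, hypothetical example of a DQN value
# head over a discrete action space. The network builder `my_q_network`, the observation shape,
# and the greedy-action selection below are illustrative assumptions, not the library's API.
#
#     obs_ph = tf.placeholder(tf.float32, shape=(None, 4))
#
#     def my_q_network():
#         net = tf.layers.dense(obs_ph, 64, activation=tf.nn.relu)
#         q_all = tf.layers.dense(net, 3)  # Q(s, *) for 3 discrete actions
#         return None, q_all  # (action head, value head)
#
#     dqn = DQN(my_q_network, obs_ph, has_old_net=True)
#
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         dqn.sync_weights()                                    # target net <- current net
#         q_all = dqn.eval_value_all_actions(observations)      # shape (batch_size, 3)
#         greedy_actions = q_all.argmax(axis=1)
#         q_sa = dqn.eval_value(observations, greedy_actions)   # shape (batch_size,)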