
How does TensorFlow set the device for each Operation?

Robin Dong, posted on October 12, 2018, 11:41
Tag: machine learning | tensorflow

In TensorFlow, we only need the snippet below to assign a device to an Operation:

with tf.device('/GPU:0'):
  ...
  result = tf.matmul(a, b)

How is this implemented? Let’s take a look.

Python has a mechanism called a ‘context manager’. For example, we can use one to wrap a few lines of code:

from contextlib import contextmanager

@contextmanager
def tag(name):
  print("[%s]" % name)
  yield
  print("[/%s]" % name)
  
with tag("robin"):
  print("what")
  print("is")
  print("nature's")

The result of running this script is:

[robin]
what
is
nature's
[/robin]

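The generator function ‘tag()’, turned into a context manager by ‘@contextmanager’, acts as a wrapper: the code before ‘yield’ runs when the ‘with’ block is entered, and the code after ‘yield’ runs when the block finishes.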
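To make the mechanism concrete, here is a sketch of roughly what ‘@contextmanager’ builds for us: a hand-written class (the name ‘Tag’ is chosen only for illustration) whose ‘__enter__’ plays the role of the code before ‘yield’ and whose ‘__exit__’ plays the role of the code after it. It is equivalent on the normal path; one difference is that ‘__exit__’ also runs when the block raises, while the generator version only does so if ‘yield’ is wrapped in ‘try/finally’.

```python
class Tag:
  """Hand-written counterpart of the @contextmanager-based tag() above."""

  def __init__(self, name):
    self.name = name

  def __enter__(self):
    # Runs when the with-block is entered (like the code before `yield`).
    print("[%s]" % self.name)
    return None

  def __exit__(self, exc_type, exc_value, traceback):
    # Runs when the with-block is left (like the code after `yield`).
    print("[/%s]" % self.name)
    # Returning False lets any exception raised in the block propagate.
    return False


with Tag("robin"):
  print("what")
  print("is")
  print("nature's")
```

Running it prints the same five lines as the ‘tag()’ version.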

TensorFlow uses the same principle: ‘tf.device()’ returns a context manager, picked according to whether eager execution is enabled:

@tf_export("device")                                                       
def device(device_name_or_function):
...
  if context.executing_eagerly():                                          
    # TODO(agarwal): support device functions in EAGER mode.
    if callable(device_name_or_function):
      raise RuntimeError(
          "tf.device does not support functions when eager execution "
          "is enabled.")
    return context.device(device_name_or_function)
  else:
    return get_default_graph().device(device_name_or_function)
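The dispatch above can be mimicked in a few lines of plain Python. This is a toy sketch, not TensorFlow code: ‘eager_device’, ‘graph_device’ and the explicit ‘executing_eagerly’ argument are hypothetical stand-ins for ‘context.device’, ‘Graph.device’ and ‘context.executing_eagerly()’. It shows that ‘device()’ itself is an ordinary function; the ‘with’ statement works because it returns some other object’s context manager.

```python
from contextlib import contextmanager

# Hypothetical stand-ins for the two real back ends; they only record
# which path was taken.
calls = []

@contextmanager
def eager_device(name):
  calls.append(("eager", name))
  yield

@contextmanager
def graph_device(name_or_function):
  calls.append(("graph", name_or_function))
  yield

def device(device_name_or_function, executing_eagerly):
  """Toy version of tf.device(): choose a context manager by mode."""
  if executing_eagerly:
    # Mirrors the check above: device functions are rejected eagerly.
    if callable(device_name_or_function):
      raise RuntimeError("device functions are not supported eagerly")
    return eager_device(device_name_or_function)
  return graph_device(device_name_or_function)

with device("/GPU:0", executing_eagerly=False):
  pass
with device("/CPU:0", executing_eagerly=True):
  pass
print(calls)  # [('graph', '/GPU:0'), ('eager', '/CPU:0')]
```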

In graph mode, ‘tf.device()’ therefore delegates to the ‘device()’ method of class ‘Graph’. Its implementation:

@tf_export("Graph")
class Graph(object):
...
  @tf_contextlib.contextmanager
  def device(self, device_name_or_function):
    ...
    self._add_device_to_stack(device_name_or_function, offset=2)
    try:
      yield
    finally:
      self._device_function_stack.pop_obj()

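The key line is ‘self._add_device_to_stack()’. Entering the ‘device’ context pushes the device name (or device function) onto the graph’s ‘_device_function_stack’, a plain Python-level stack, and leaving the context pops it again; the ‘finally’ clause guarantees the pop even if the block raises. When the developer creates an Operation inside the context, the graph reads the device from this stack and assigns it to that Operation.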
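The push/pop behaviour can be sketched with a toy class. ‘ToyGraph’, ‘_device_stack’ and ‘current_device()’ are hypothetical names, and the stack holds plain strings instead of TensorFlow’s device-spec objects; the point is only how nested ‘with’ blocks keep the innermost device on top of the stack.

```python
from contextlib import contextmanager

class ToyGraph:
  """Hypothetical, heavily simplified stand-in for tf.Graph."""

  def __init__(self):
    self._device_stack = []  # innermost device scope is the last element

  @contextmanager
  def device(self, device_name):
    self._device_stack.append(device_name)  # push on entry
    try:
      yield
    finally:
      # Pop even if the with-block raises, so no stale entry is left behind.
      self._device_stack.pop()

  def current_device(self):
    return self._device_stack[-1] if self._device_stack else ""


g = ToyGraph()
with g.device("/CPU:0"):
  with g.device("/GPU:0"):
    print(g.current_device())  # /GPU:0
  print(g.current_device())    # /CPU:0
print(repr(g.current_device()))  # ''
```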
Let’s check the code path that creates an Operation:

@tf_export("Graph")
class Graph(object):
...
  def create_op(
      self,
      op_type,                                                             
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,                                                         
      compute_shapes=True,                                                 
      compute_device=True):
    ...
    with self._mutation_lock():
      ret = Operation(
          node_def,
          self,
          inputs=inputs,
          output_types=dtypes,
          control_inputs=control_inputs,
          input_types=input_types,
          original_op=self._default_original_op,
          op_def=op_def)
      self._create_op_helper(ret, compute_device=compute_device)
    return ret

  def _create_op_helper(self, op, compute_device=True):
    ...
    if compute_device:
      self._apply_device_functions(op)

  def _apply_device_functions(self, op):
    ...
    for device_spec in self._device_function_stack.peek_objs():
      if device_spec.function is None:
        break
      op._set_device(device_spec.function(op))
    op._device_code_locations = self._snapshot_device_function_stack_metadata()

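‘self._device_function_stack.peek_objs()’ is where the graph peeks at the entries on the stack: each entry’s device function is called with the new Operation, and the result is assigned with ‘op._set_device()’. An entry whose function is None (pushed by ‘tf.device(None)’) stops the loop, so device scopes from the enclosing contexts are ignored.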
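The whole path, from entering a scope to placing a new op, can be sketched as below. Everything here is hypothetical and simplified: ‘ToyOp’ and ‘ToyGraph’ are made-up classes, plain names are wrapped into functions that return the name, and the stack is walked newest-to-oldest on the assumption that ‘peek_objs()’ yields entries in that order. As far as I can tell from the TF 1.x sources, real string scopes are merged field by field (job, device type, index); the sketch replaces that with “the innermost non-empty result wins”. ‘matmul_on_gpu’ follows the style of the device-function example in the ‘Graph.device’ docstring.

```python
from contextlib import contextmanager

class ToyOp:
  """Hypothetical minimal operation: a type and an assigned device."""
  def __init__(self, op_type):
    self.type = op_type
    self.device = ""

class ToyGraph:
  """Hypothetical simplification of tf.Graph's device-function stack."""

  def __init__(self):
    self._device_function_stack = []  # oldest first, newest last

  @contextmanager
  def device(self, device_name_or_function):
    # None means "ignore enclosing scopes"; a name becomes a constant function.
    if device_name_or_function is None or callable(device_name_or_function):
      fn = device_name_or_function
    else:
      name = device_name_or_function
      def fn(op):
        return name
    self._device_function_stack.append(fn)
    try:
      yield
    finally:
      self._device_function_stack.pop()

  def create_op(self, op_type):
    op = ToyOp(op_type)
    # Walk the stack newest-to-oldest, as assumed for peek_objs().
    for fn in reversed(self._device_function_stack):
      if fn is None:
        break  # a device(None) scope hides all enclosing scopes
      if not op.device:  # simplification: innermost non-empty result wins
        op.device = fn(op)
    return op

def matmul_on_gpu(op):
  return "/GPU:0" if op.type == "MatMul" else "/CPU:0"

g = ToyGraph()
with g.device("/CPU:1"):
  with g.device(matmul_on_gpu):
    print(g.create_op("MatMul").device)  # /GPU:0
    print(g.create_op("Add").device)     # /CPU:0
    with g.device(None):
      print(repr(g.create_op("Add").device))  # ''
```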

Original link: http://donghao.org/2018/10/12/how-tensorflow-set-device-for-each-operation/
