/*
Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tensorflow

// #include <stdlib.h>
// #include "tensorflow/c/c_api.h"
// #include "tensorflow/c/eager/c_api.h"
import "C"
import (
	"fmt"
	"runtime"
)

// ContextOptions contains configuration information for an eager execution
// Context. It is the argument accepted by NewContext.
type ContextOptions struct {
	// Config is a binary-serialized representation of the
	// tensorflow.ConfigProto protocol message
	// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
	// An empty Config is not forwarded to the C API.
	Config []byte

	// Async sets the default execution mode. Its value is forwarded to
	// TFE_ContextOptionsSetAsync as 1 when true and 0 when false.
	// NOTE(review): presumably true selects asynchronous execution; confirm
	// against the eager C API documentation.
	Async bool
}

// c converts the ContextOptions to the C API's TFE_ContextOptions.
// A nil receiver yields freshly created default options.
// Caller takes ownership of returned object.
41*b6fb3261SAndroid Build Coastguard Workerfunc (o *ContextOptions) c() (*C.TFE_ContextOptions, error) { 42*b6fb3261SAndroid Build Coastguard Worker opt := C.TFE_NewContextOptions() 43*b6fb3261SAndroid Build Coastguard Worker if o == nil { 44*b6fb3261SAndroid Build Coastguard Worker return opt, nil 45*b6fb3261SAndroid Build Coastguard Worker } 46*b6fb3261SAndroid Build Coastguard Worker 47*b6fb3261SAndroid Build Coastguard Worker if sz := len(o.Config); sz > 0 { 48*b6fb3261SAndroid Build Coastguard Worker status := newStatus() 49*b6fb3261SAndroid Build Coastguard Worker cConfig := C.CBytes(o.Config) 50*b6fb3261SAndroid Build Coastguard Worker C.TFE_ContextOptionsSetConfig(opt, cConfig, C.size_t(sz), status.c) 51*b6fb3261SAndroid Build Coastguard Worker C.free(cConfig) 52*b6fb3261SAndroid Build Coastguard Worker if err := status.Err(); err != nil { 53*b6fb3261SAndroid Build Coastguard Worker C.TFE_DeleteContextOptions(opt) 54*b6fb3261SAndroid Build Coastguard Worker return nil, fmt.Errorf("invalid ContextOptions.Config: %v", err) 55*b6fb3261SAndroid Build Coastguard Worker } 56*b6fb3261SAndroid Build Coastguard Worker } 57*b6fb3261SAndroid Build Coastguard Worker 58*b6fb3261SAndroid Build Coastguard Worker var async uint8 59*b6fb3261SAndroid Build Coastguard Worker if o.Async { 60*b6fb3261SAndroid Build Coastguard Worker async = 1 61*b6fb3261SAndroid Build Coastguard Worker } 62*b6fb3261SAndroid Build Coastguard Worker C.TFE_ContextOptionsSetAsync(opt, C.uchar(async)) 63*b6fb3261SAndroid Build Coastguard Worker 64*b6fb3261SAndroid Build Coastguard Worker return opt, nil 65*b6fb3261SAndroid Build Coastguard Worker} 66*b6fb3261SAndroid Build Coastguard Worker 67*b6fb3261SAndroid Build Coastguard Worker// Context for executing operations eagerly. 68*b6fb3261SAndroid Build Coastguard Worker// 69*b6fb3261SAndroid Build Coastguard Worker// A Context allows operations to be executed immediately. 
It encapsulates 70*b6fb3261SAndroid Build Coastguard Worker// information such as the available devices, resource manager etc. It also 71*b6fb3261SAndroid Build Coastguard Worker// allows the user to configure execution using a ConfigProto, as they can 72*b6fb3261SAndroid Build Coastguard Worker// configure a Session when executing a Graph. 73*b6fb3261SAndroid Build Coastguard Workertype Context struct { 74*b6fb3261SAndroid Build Coastguard Worker c *C.TFE_Context 75*b6fb3261SAndroid Build Coastguard Worker} 76*b6fb3261SAndroid Build Coastguard Worker 77*b6fb3261SAndroid Build Coastguard Worker// NewContext creates a new context for eager execution. 78*b6fb3261SAndroid Build Coastguard Worker// options may be nil to use the default options. 79*b6fb3261SAndroid Build Coastguard Workerfunc NewContext(options *ContextOptions) (*Context, error) { 80*b6fb3261SAndroid Build Coastguard Worker status := newStatus() 81*b6fb3261SAndroid Build Coastguard Worker cOpt, err := options.c() 82*b6fb3261SAndroid Build Coastguard Worker if err != nil { 83*b6fb3261SAndroid Build Coastguard Worker return nil, err 84*b6fb3261SAndroid Build Coastguard Worker } 85*b6fb3261SAndroid Build Coastguard Worker defer C.TFE_DeleteContextOptions(cOpt) 86*b6fb3261SAndroid Build Coastguard Worker cContext := C.TFE_NewContext(cOpt, status.c) 87*b6fb3261SAndroid Build Coastguard Worker if err := status.Err(); err != nil { 88*b6fb3261SAndroid Build Coastguard Worker return nil, err 89*b6fb3261SAndroid Build Coastguard Worker } 90*b6fb3261SAndroid Build Coastguard Worker 91*b6fb3261SAndroid Build Coastguard Worker c := &Context{c: cContext} 92*b6fb3261SAndroid Build Coastguard Worker runtime.SetFinalizer(c, (*Context).finalizer) 93*b6fb3261SAndroid Build Coastguard Worker return c, nil 94*b6fb3261SAndroid Build Coastguard Worker} 95*b6fb3261SAndroid Build Coastguard Worker 96*b6fb3261SAndroid Build Coastguard Workerfunc (c *Context) finalizer() { 97*b6fb3261SAndroid Build Coastguard Worker 
C.TFE_DeleteContext(c.c) 98*b6fb3261SAndroid Build Coastguard Worker} 99*b6fb3261SAndroid Build Coastguard Worker 100*b6fb3261SAndroid Build Coastguard Worker// ListDevices returns the list of devices associated with a Context. 101*b6fb3261SAndroid Build Coastguard Workerfunc (c *Context) ListDevices() ([]Device, error) { 102*b6fb3261SAndroid Build Coastguard Worker status := newStatus() 103*b6fb3261SAndroid Build Coastguard Worker devicesList := C.TFE_ContextListDevices(c.c, status.c) 104*b6fb3261SAndroid Build Coastguard Worker if err := status.Err(); err != nil { 105*b6fb3261SAndroid Build Coastguard Worker return nil, fmt.Errorf("SessionListDevices() failed: %v", err) 106*b6fb3261SAndroid Build Coastguard Worker } 107*b6fb3261SAndroid Build Coastguard Worker defer C.TF_DeleteDeviceList(devicesList) 108*b6fb3261SAndroid Build Coastguard Worker return deviceSliceFromDeviceList(devicesList) 109*b6fb3261SAndroid Build Coastguard Worker} 110