xref: /aosp_15_r20/external/tensorflow/tensorflow/go/context.go (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1*b6fb3261SAndroid Build Coastguard Worker/*
2*b6fb3261SAndroid Build Coastguard WorkerCopyright 2018 The TensorFlow Authors. All Rights Reserved.
3*b6fb3261SAndroid Build Coastguard Worker
4*b6fb3261SAndroid Build Coastguard WorkerLicensed under the Apache License, Version 2.0 (the "License");
5*b6fb3261SAndroid Build Coastguard Workeryou may not use this file except in compliance with the License.
6*b6fb3261SAndroid Build Coastguard WorkerYou may obtain a copy of the License at
7*b6fb3261SAndroid Build Coastguard Worker
8*b6fb3261SAndroid Build Coastguard Worker    http://www.apache.org/licenses/LICENSE-2.0
9*b6fb3261SAndroid Build Coastguard Worker
10*b6fb3261SAndroid Build Coastguard WorkerUnless required by applicable law or agreed to in writing, software
11*b6fb3261SAndroid Build Coastguard Workerdistributed under the License is distributed on an "AS IS" BASIS,
12*b6fb3261SAndroid Build Coastguard WorkerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*b6fb3261SAndroid Build Coastguard WorkerSee the License for the specific language governing permissions and
14*b6fb3261SAndroid Build Coastguard Workerlimitations under the License.
15*b6fb3261SAndroid Build Coastguard Worker*/
16*b6fb3261SAndroid Build Coastguard Worker
17*b6fb3261SAndroid Build Coastguard Workerpackage tensorflow
18*b6fb3261SAndroid Build Coastguard Worker
19*b6fb3261SAndroid Build Coastguard Worker// #include <stdlib.h>
20*b6fb3261SAndroid Build Coastguard Worker// #include "tensorflow/c/c_api.h"
21*b6fb3261SAndroid Build Coastguard Worker// #include "tensorflow/c/eager/c_api.h"
22*b6fb3261SAndroid Build Coastguard Workerimport "C"
23*b6fb3261SAndroid Build Coastguard Workerimport (
24*b6fb3261SAndroid Build Coastguard Worker	"fmt"
25*b6fb3261SAndroid Build Coastguard Worker	"runtime"
26*b6fb3261SAndroid Build Coastguard Worker)
27*b6fb3261SAndroid Build Coastguard Worker
// ContextOptions contains configuration information for an eager execution
// Context created by NewContext. A nil *ContextOptions selects the defaults.
type ContextOptions struct {
	// Config is a binary-serialized representation of the
	// tensorflow.ConfigProto protocol message
	// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
	// When Config is empty no configuration is passed to the C API, so the
	// TensorFlow defaults apply.
	Config []byte

	// Async selects the default execution mode of the Context; it is
	// forwarded to the C API's TFE_ContextOptionsSetAsync (true = async).
	Async bool
}
38*b6fb3261SAndroid Build Coastguard Worker
39*b6fb3261SAndroid Build Coastguard Worker// c converts the ContextOptions to the C API's TF_ContextOptions.
40*b6fb3261SAndroid Build Coastguard Worker// Caller takes ownership of returned object.
41*b6fb3261SAndroid Build Coastguard Workerfunc (o *ContextOptions) c() (*C.TFE_ContextOptions, error) {
42*b6fb3261SAndroid Build Coastguard Worker	opt := C.TFE_NewContextOptions()
43*b6fb3261SAndroid Build Coastguard Worker	if o == nil {
44*b6fb3261SAndroid Build Coastguard Worker		return opt, nil
45*b6fb3261SAndroid Build Coastguard Worker	}
46*b6fb3261SAndroid Build Coastguard Worker
47*b6fb3261SAndroid Build Coastguard Worker	if sz := len(o.Config); sz > 0 {
48*b6fb3261SAndroid Build Coastguard Worker		status := newStatus()
49*b6fb3261SAndroid Build Coastguard Worker		cConfig := C.CBytes(o.Config)
50*b6fb3261SAndroid Build Coastguard Worker		C.TFE_ContextOptionsSetConfig(opt, cConfig, C.size_t(sz), status.c)
51*b6fb3261SAndroid Build Coastguard Worker		C.free(cConfig)
52*b6fb3261SAndroid Build Coastguard Worker		if err := status.Err(); err != nil {
53*b6fb3261SAndroid Build Coastguard Worker			C.TFE_DeleteContextOptions(opt)
54*b6fb3261SAndroid Build Coastguard Worker			return nil, fmt.Errorf("invalid ContextOptions.Config: %v", err)
55*b6fb3261SAndroid Build Coastguard Worker		}
56*b6fb3261SAndroid Build Coastguard Worker	}
57*b6fb3261SAndroid Build Coastguard Worker
58*b6fb3261SAndroid Build Coastguard Worker	var async uint8
59*b6fb3261SAndroid Build Coastguard Worker	if o.Async {
60*b6fb3261SAndroid Build Coastguard Worker		async = 1
61*b6fb3261SAndroid Build Coastguard Worker	}
62*b6fb3261SAndroid Build Coastguard Worker	C.TFE_ContextOptionsSetAsync(opt, C.uchar(async))
63*b6fb3261SAndroid Build Coastguard Worker
64*b6fb3261SAndroid Build Coastguard Worker	return opt, nil
65*b6fb3261SAndroid Build Coastguard Worker}
66*b6fb3261SAndroid Build Coastguard Worker
67*b6fb3261SAndroid Build Coastguard Worker// Context for executing operations eagerly.
68*b6fb3261SAndroid Build Coastguard Worker//
69*b6fb3261SAndroid Build Coastguard Worker// A Context allows operations to be executed immediately. It encapsulates
70*b6fb3261SAndroid Build Coastguard Worker// information such as the available devices, resource manager etc. It also
71*b6fb3261SAndroid Build Coastguard Worker// allows the user to configure execution using a ConfigProto, as they can
72*b6fb3261SAndroid Build Coastguard Worker// configure a Session when executing a Graph.
73*b6fb3261SAndroid Build Coastguard Workertype Context struct {
74*b6fb3261SAndroid Build Coastguard Worker	c *C.TFE_Context
75*b6fb3261SAndroid Build Coastguard Worker}
76*b6fb3261SAndroid Build Coastguard Worker
77*b6fb3261SAndroid Build Coastguard Worker// NewContext creates a new context for eager execution.
78*b6fb3261SAndroid Build Coastguard Worker// options may be nil to use the default options.
79*b6fb3261SAndroid Build Coastguard Workerfunc NewContext(options *ContextOptions) (*Context, error) {
80*b6fb3261SAndroid Build Coastguard Worker	status := newStatus()
81*b6fb3261SAndroid Build Coastguard Worker	cOpt, err := options.c()
82*b6fb3261SAndroid Build Coastguard Worker	if err != nil {
83*b6fb3261SAndroid Build Coastguard Worker		return nil, err
84*b6fb3261SAndroid Build Coastguard Worker	}
85*b6fb3261SAndroid Build Coastguard Worker	defer C.TFE_DeleteContextOptions(cOpt)
86*b6fb3261SAndroid Build Coastguard Worker	cContext := C.TFE_NewContext(cOpt, status.c)
87*b6fb3261SAndroid Build Coastguard Worker	if err := status.Err(); err != nil {
88*b6fb3261SAndroid Build Coastguard Worker		return nil, err
89*b6fb3261SAndroid Build Coastguard Worker	}
90*b6fb3261SAndroid Build Coastguard Worker
91*b6fb3261SAndroid Build Coastguard Worker	c := &Context{c: cContext}
92*b6fb3261SAndroid Build Coastguard Worker	runtime.SetFinalizer(c, (*Context).finalizer)
93*b6fb3261SAndroid Build Coastguard Worker	return c, nil
94*b6fb3261SAndroid Build Coastguard Worker}
95*b6fb3261SAndroid Build Coastguard Worker
96*b6fb3261SAndroid Build Coastguard Workerfunc (c *Context) finalizer() {
97*b6fb3261SAndroid Build Coastguard Worker	C.TFE_DeleteContext(c.c)
98*b6fb3261SAndroid Build Coastguard Worker}
99*b6fb3261SAndroid Build Coastguard Worker
100*b6fb3261SAndroid Build Coastguard Worker// ListDevices returns the list of devices associated with a Context.
101*b6fb3261SAndroid Build Coastguard Workerfunc (c *Context) ListDevices() ([]Device, error) {
102*b6fb3261SAndroid Build Coastguard Worker	status := newStatus()
103*b6fb3261SAndroid Build Coastguard Worker	devicesList := C.TFE_ContextListDevices(c.c, status.c)
104*b6fb3261SAndroid Build Coastguard Worker	if err := status.Err(); err != nil {
105*b6fb3261SAndroid Build Coastguard Worker		return nil, fmt.Errorf("SessionListDevices() failed: %v", err)
106*b6fb3261SAndroid Build Coastguard Worker	}
107*b6fb3261SAndroid Build Coastguard Worker	defer C.TF_DeleteDeviceList(devicesList)
108*b6fb3261SAndroid Build Coastguard Worker	return deviceSliceFromDeviceList(devicesList)
109*b6fb3261SAndroid Build Coastguard Worker}
110