xref: /aosp_15_r20/external/tensorflow/tensorflow/core/platform/cloud/http_request_fake.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
17 #define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
18 
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iterator>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include <curl/curl.h>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 
37 /// Fake HttpRequest for testing.
38 class FakeHttpRequest : public CurlHttpRequest {
39  public:
40   /// Return the response for the given request.
FakeHttpRequest(const string & request,const string & response)41   FakeHttpRequest(const string& request, const string& response)
42       : FakeHttpRequest(request, response, OkStatus(), nullptr, {}, 200) {}
43 
44   /// Return the response with headers for the given request.
FakeHttpRequest(const string & request,const string & response,const std::map<string,string> & response_headers)45   FakeHttpRequest(const string& request, const string& response,
46                   const std::map<string, string>& response_headers)
47       : FakeHttpRequest(request, response, OkStatus(), nullptr,
48                         response_headers, 200) {}
49 
50   /// \brief Return the response for the request and capture the POST body.
51   ///
52   /// Post body is not expected to be a part of the 'request' parameter.
FakeHttpRequest(const string & request,const string & response,string * captured_post_body)53   FakeHttpRequest(const string& request, const string& response,
54                   string* captured_post_body)
55       : FakeHttpRequest(request, response, OkStatus(), captured_post_body, {},
56                         200) {}
57 
58   /// \brief Return the response and the status for the given request.
FakeHttpRequest(const string & request,const string & response,Status response_status,uint64 response_code)59   FakeHttpRequest(const string& request, const string& response,
60                   Status response_status, uint64 response_code)
61       : FakeHttpRequest(request, response, response_status, nullptr, {},
62                         response_code) {}
63 
64   /// \brief Return the response and the status for the given request
65   ///  and capture the POST body.
66   ///
67   /// Post body is not expected to be a part of the 'request' parameter.
FakeHttpRequest(const string & request,const string & response,Status response_status,string * captured_post_body,const std::map<string,string> & response_headers,uint64 response_code)68   FakeHttpRequest(const string& request, const string& response,
69                   Status response_status, string* captured_post_body,
70                   const std::map<string, string>& response_headers,
71                   uint64 response_code)
72       : expected_request_(request),
73         response_(response),
74         response_status_(response_status),
75         captured_post_body_(captured_post_body),
76         response_headers_(response_headers),
77         response_code_(response_code) {}
78 
SetUri(const string & uri)79   void SetUri(const string& uri) override {
80     actual_uri_ += "Uri: " + uri + "\n";
81   }
SetRange(uint64 start,uint64 end)82   void SetRange(uint64 start, uint64 end) override {
83     actual_request_ += strings::StrCat("Range: ", start, "-", end, "\n");
84   }
AddHeader(const string & name,const string & value)85   void AddHeader(const string& name, const string& value) override {
86     actual_request_ += "Header " + name + ": " + value + "\n";
87   }
AddAuthBearerHeader(const string & auth_token)88   void AddAuthBearerHeader(const string& auth_token) override {
89     actual_request_ += "Auth Token: " + auth_token + "\n";
90   }
SetDeleteRequest()91   void SetDeleteRequest() override { actual_request_ += "Delete: yes\n"; }
SetPutFromFile(const string & body_filepath,size_t offset)92   Status SetPutFromFile(const string& body_filepath, size_t offset) override {
93     std::ifstream stream(body_filepath);
94     const string& content = string(std::istreambuf_iterator<char>(stream),
95                                    std::istreambuf_iterator<char>())
96                                 .substr(offset);
97     actual_request_ += "Put body: " + content + "\n";
98     return OkStatus();
99   }
SetPostFromBuffer(const char * buffer,size_t size)100   void SetPostFromBuffer(const char* buffer, size_t size) override {
101     if (captured_post_body_) {
102       *captured_post_body_ = string(buffer, size);
103     } else {
104       actual_request_ +=
105           strings::StrCat("Post body: ", StringPiece(buffer, size), "\n");
106     }
107   }
SetPutEmptyBody()108   void SetPutEmptyBody() override { actual_request_ += "Put: yes\n"; }
SetPostEmptyBody()109   void SetPostEmptyBody() override {
110     if (captured_post_body_) {
111       *captured_post_body_ = "<empty>";
112     } else {
113       actual_request_ += "Post: yes\n";
114     }
115   }
SetResultBuffer(std::vector<char> * buffer)116   void SetResultBuffer(std::vector<char>* buffer) override {
117     buffer->clear();
118     buffer_ = buffer;
119   }
SetResultBufferDirect(char * buffer,size_t size)120   void SetResultBufferDirect(char* buffer, size_t size) override {
121     direct_result_buffer_ = buffer;
122     direct_result_buffer_size_ = size;
123   }
GetResultBufferDirectBytesTransferred()124   size_t GetResultBufferDirectBytesTransferred() override {
125     return direct_result_bytes_transferred_;
126   }
Send()127   Status Send() override {
128     EXPECT_EQ(expected_request_, actual_request())
129         << "Unexpected HTTP request.";
130     if (buffer_) {
131       buffer_->insert(buffer_->begin(), response_.data(),
132                       response_.data() + response_.size());
133     } else if (direct_result_buffer_ != nullptr) {
134       size_t bytes_to_copy =
135           std::min<size_t>(direct_result_buffer_size_, response_.size());
136       memcpy(direct_result_buffer_, response_.data(), bytes_to_copy);
137       direct_result_bytes_transferred_ += bytes_to_copy;
138     }
139     return response_status_;
140   }
141 
142   // This function just does a simple replacing of "/" with "%2F" instead of
143   // full url encoding.
EscapeString(const string & str)144   string EscapeString(const string& str) override {
145     const string victim = "/";
146     const string encoded = "%2F";
147 
148     string copy_str = str;
149     std::string::size_type n = 0;
150     while ((n = copy_str.find(victim, n)) != std::string::npos) {
151       copy_str.replace(n, victim.size(), encoded);
152       n += encoded.size();
153     }
154     return copy_str;
155   }
156 
GetResponseHeader(const string & name)157   string GetResponseHeader(const string& name) const override {
158     const auto header = response_headers_.find(name);
159     return header != response_headers_.end() ? header->second : "";
160   }
161 
GetResponseCode()162   virtual uint64 GetResponseCode() const override { return response_code_; }
163 
SetTimeouts(uint32 connection,uint32 inactivity,uint32 total)164   void SetTimeouts(uint32 connection, uint32 inactivity,
165                    uint32 total) override {
166     actual_request_ += strings::StrCat("Timeouts: ", connection, " ",
167                                        inactivity, " ", total, "\n");
168   }
169 
170  private:
actual_request()171   string actual_request() const {
172     string s;
173     s.append(actual_uri_);
174     s.append(actual_request_);
175     return s;
176   }
177 
178   std::vector<char>* buffer_ = nullptr;
179   char* direct_result_buffer_ = nullptr;
180   size_t direct_result_buffer_size_ = 0;
181   size_t direct_result_bytes_transferred_ = 0;
182   string expected_request_;
183   string actual_uri_;
184   string actual_request_;
185   string response_;
186   Status response_status_;
187   string* captured_post_body_ = nullptr;
188   std::map<string, string> response_headers_;
189   uint64 response_code_ = 0;
190 };
191 
192 /// Fake HttpRequest factory for testing.
193 class FakeHttpRequestFactory : public HttpRequest::Factory {
194  public:
FakeHttpRequestFactory(const std::vector<HttpRequest * > * requests)195   FakeHttpRequestFactory(const std::vector<HttpRequest*>* requests)
196       : requests_(requests) {}
197 
~FakeHttpRequestFactory()198   ~FakeHttpRequestFactory() {
199     EXPECT_EQ(current_index_, requests_->size())
200         << "Not all expected requests were made.";
201   }
202 
Create()203   HttpRequest* Create() override {
204     EXPECT_LT(current_index_, requests_->size())
205         << "Too many calls of HttpRequest factory.";
206     return (*requests_)[current_index_++];
207   }
208 
209  private:
210   const std::vector<HttpRequest*>* requests_;
211   int current_index_ = 0;
212 };
213 
214 }  // namespace tensorflow
215 
216 #endif  // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
217