1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_ 17 #define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_ 18 19 #include <algorithm> 20 #include <fstream> 21 #include <string> 22 #include <vector> 23 24 #include <curl/curl.h> 25 #include "tensorflow/core/lib/core/status_test_util.h" 26 #include "tensorflow/core/platform/cloud/curl_http_request.h" 27 #include "tensorflow/core/platform/errors.h" 28 #include "tensorflow/core/platform/macros.h" 29 #include "tensorflow/core/platform/protobuf.h" 30 #include "tensorflow/core/platform/status.h" 31 #include "tensorflow/core/platform/stringpiece.h" 32 #include "tensorflow/core/platform/test.h" 33 #include "tensorflow/core/platform/types.h" 34 35 namespace tensorflow { 36 37 /// Fake HttpRequest for testing. 38 class FakeHttpRequest : public CurlHttpRequest { 39 public: 40 /// Return the response for the given request. FakeHttpRequest(const string & request,const string & response)41 FakeHttpRequest(const string& request, const string& response) 42 : FakeHttpRequest(request, response, OkStatus(), nullptr, {}, 200) {} 43 44 /// Return the response with headers for the given request. FakeHttpRequest(const string & request,const string & response,const std::map<string,string> & response_headers)45 FakeHttpRequest(const string& request, const string& response, 46 const std::map<string, string>& response_headers) 47 : FakeHttpRequest(request, response, OkStatus(), nullptr, 48 response_headers, 200) {} 49 50 /// \brief Return the response for the request and capture the POST body. 51 /// 52 /// Post body is not expected to be a part of the 'request' parameter. FakeHttpRequest(const string & request,const string & response,string * captured_post_body)53 FakeHttpRequest(const string& request, const string& response, 54 string* captured_post_body) 55 : FakeHttpRequest(request, response, OkStatus(), captured_post_body, {}, 56 200) {} 57 58 /// \brief Return the response and the status for the given request. FakeHttpRequest(const string & request,const string & response,Status response_status,uint64 response_code)59 FakeHttpRequest(const string& request, const string& response, 60 Status response_status, uint64 response_code) 61 : FakeHttpRequest(request, response, response_status, nullptr, {}, 62 response_code) {} 63 64 /// \brief Return the response and the status for the given request 65 /// and capture the POST body. 66 /// 67 /// Post body is not expected to be a part of the 'request' parameter. FakeHttpRequest(const string & request,const string & response,Status response_status,string * captured_post_body,const std::map<string,string> & response_headers,uint64 response_code)68 FakeHttpRequest(const string& request, const string& response, 69 Status response_status, string* captured_post_body, 70 const std::map<string, string>& response_headers, 71 uint64 response_code) 72 : expected_request_(request), 73 response_(response), 74 response_status_(response_status), 75 captured_post_body_(captured_post_body), 76 response_headers_(response_headers), 77 response_code_(response_code) {} 78 SetUri(const string & uri)79 void SetUri(const string& uri) override { 80 actual_uri_ += "Uri: " + uri + "\n"; 81 } SetRange(uint64 start,uint64 end)82 void SetRange(uint64 start, uint64 end) override { 83 actual_request_ += strings::StrCat("Range: ", start, "-", end, "\n"); 84 } AddHeader(const string & name,const string & value)85 void AddHeader(const string& name, const string& value) override { 86 actual_request_ += "Header " + name + ": " + value + "\n"; 87 } AddAuthBearerHeader(const string & auth_token)88 void AddAuthBearerHeader(const string& auth_token) override { 89 actual_request_ += "Auth Token: " + auth_token + "\n"; 90 } SetDeleteRequest()91 void SetDeleteRequest() override { actual_request_ += "Delete: yes\n"; } SetPutFromFile(const string & body_filepath,size_t offset)92 Status SetPutFromFile(const string& body_filepath, size_t offset) override { 93 std::ifstream stream(body_filepath); 94 const string& content = string(std::istreambuf_iterator<char>(stream), 95 std::istreambuf_iterator<char>()) 96 .substr(offset); 97 actual_request_ += "Put body: " + content + "\n"; 98 return OkStatus(); 99 } SetPostFromBuffer(const char * buffer,size_t size)100 void SetPostFromBuffer(const char* buffer, size_t size) override { 101 if (captured_post_body_) { 102 *captured_post_body_ = string(buffer, size); 103 } else { 104 actual_request_ += 105 strings::StrCat("Post body: ", StringPiece(buffer, size), "\n"); 106 } 107 } SetPutEmptyBody()108 void SetPutEmptyBody() override { actual_request_ += "Put: yes\n"; } SetPostEmptyBody()109 void SetPostEmptyBody() override { 110 if (captured_post_body_) { 111 *captured_post_body_ = "<empty>"; 112 } else { 113 actual_request_ += "Post: yes\n"; 114 } 115 } SetResultBuffer(std::vector<char> * buffer)116 void SetResultBuffer(std::vector<char>* buffer) override { 117 buffer->clear(); 118 buffer_ = buffer; 119 } SetResultBufferDirect(char * buffer,size_t size)120 void SetResultBufferDirect(char* buffer, size_t size) override { 121 direct_result_buffer_ = buffer; 122 direct_result_buffer_size_ = size; 123 } GetResultBufferDirectBytesTransferred()124 size_t GetResultBufferDirectBytesTransferred() override { 125 return direct_result_bytes_transferred_; 126 } Send()127 Status Send() override { 128 EXPECT_EQ(expected_request_, actual_request()) 129 << "Unexpected HTTP request."; 130 if (buffer_) { 131 buffer_->insert(buffer_->begin(), response_.data(), 132 response_.data() + response_.size()); 133 } else if (direct_result_buffer_ != nullptr) { 134 size_t bytes_to_copy = 135 std::min<size_t>(direct_result_buffer_size_, response_.size()); 136 memcpy(direct_result_buffer_, response_.data(), bytes_to_copy); 137 direct_result_bytes_transferred_ += bytes_to_copy; 138 } 139 return response_status_; 140 } 141 142 // This function just does a simple replacing of "/" with "%2F" instead of 143 // full url encoding. EscapeString(const string & str)144 string EscapeString(const string& str) override { 145 const string victim = "/"; 146 const string encoded = "%2F"; 147 148 string copy_str = str; 149 std::string::size_type n = 0; 150 while ((n = copy_str.find(victim, n)) != std::string::npos) { 151 copy_str.replace(n, victim.size(), encoded); 152 n += encoded.size(); 153 } 154 return copy_str; 155 } 156 GetResponseHeader(const string & name)157 string GetResponseHeader(const string& name) const override { 158 const auto header = response_headers_.find(name); 159 return header != response_headers_.end() ? header->second : ""; 160 } 161 GetResponseCode()162 virtual uint64 GetResponseCode() const override { return response_code_; } 163 SetTimeouts(uint32 connection,uint32 inactivity,uint32 total)164 void SetTimeouts(uint32 connection, uint32 inactivity, 165 uint32 total) override { 166 actual_request_ += strings::StrCat("Timeouts: ", connection, " ", 167 inactivity, " ", total, "\n"); 168 } 169 170 private: actual_request()171 string actual_request() const { 172 string s; 173 s.append(actual_uri_); 174 s.append(actual_request_); 175 return s; 176 } 177 178 std::vector<char>* buffer_ = nullptr; 179 char* direct_result_buffer_ = nullptr; 180 size_t direct_result_buffer_size_ = 0; 181 size_t direct_result_bytes_transferred_ = 0; 182 string expected_request_; 183 string actual_uri_; 184 string actual_request_; 185 string response_; 186 Status response_status_; 187 string* captured_post_body_ = nullptr; 188 std::map<string, string> response_headers_; 189 uint64 response_code_ = 0; 190 }; 191 192 /// Fake HttpRequest factory for testing. 193 class FakeHttpRequestFactory : public HttpRequest::Factory { 194 public: FakeHttpRequestFactory(const std::vector<HttpRequest * > * requests)195 FakeHttpRequestFactory(const std::vector<HttpRequest*>* requests) 196 : requests_(requests) {} 197 ~FakeHttpRequestFactory()198 ~FakeHttpRequestFactory() { 199 EXPECT_EQ(current_index_, requests_->size()) 200 << "Not all expected requests were made."; 201 } 202 Create()203 HttpRequest* Create() override { 204 EXPECT_LT(current_index_, requests_->size()) 205 << "Too many calls of HttpRequest factory."; 206 return (*requests_)[current_index_++]; 207 } 208 209 private: 210 const std::vector<HttpRequest*>* requests_; 211 int current_index_ = 0; 212 }; 213 214 } // namespace tensorflow 215 216 #endif // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_ 217