xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/utils/xplane_utils_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/utils/xplane_utils.h"
17 
18 #include <cstdint>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
30 #include "tensorflow/core/profiler/utils/math_utils.h"
31 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
32 #include "tensorflow/core/profiler/utils/xplane_builder.h"
33 #include "tensorflow/core/profiler/utils/xplane_schema.h"
34 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
35 
36 namespace tensorflow {
37 namespace profiler {
38 namespace {
39 
40 using ::testing::Property;
41 using ::testing::SizeIs;
42 using ::testing::UnorderedElementsAre;
43 
44 #if defined(PLATFORM_GOOGLE)
45 using ::testing::EqualsProto;
46 using ::testing::proto::IgnoringRepeatedFieldOrdering;
47 #endif
48 
CreateEvent(int64_t offset_ps,int64_t duration_ps)49 XEvent CreateEvent(int64_t offset_ps, int64_t duration_ps) {
50   XEvent event;
51   event.set_offset_ps(offset_ps);
52   event.set_duration_ps(duration_ps);
53   return event;
54 }
55 
// Planes added by name can be found again by name, and removing one plane
// keeps the pointers to the surviving planes valid.
TEST(XPlaneUtilsTest, AddAndRemovePlanes) {
  XSpace space;

  XPlane* const first = FindOrAddMutablePlaneWithName(&space, "p1");
  EXPECT_EQ(first, FindPlaneWithName(space, "p1"));
  XPlane* const second = FindOrAddMutablePlaneWithName(&space, "p2");
  EXPECT_EQ(second, FindPlaneWithName(space, "p2"));
  XPlane* const third = FindOrAddMutablePlaneWithName(&space, "p3");
  EXPECT_EQ(third, FindPlaneWithName(space, "p3"));

  // Drop the middle plane first: the outer planes must still be reachable
  // through the pointers obtained before the removal.
  RemovePlane(&space, second);
  EXPECT_EQ(space.planes_size(), 2);
  EXPECT_EQ(first, FindPlaneWithName(space, "p1"));
  EXPECT_EQ(third, FindPlaneWithName(space, "p3"));

  RemovePlane(&space, first);
  EXPECT_EQ(space.planes_size(), 1);
  EXPECT_EQ(third, FindPlaneWithName(space, "p3"));

  RemovePlane(&space, third);
  EXPECT_EQ(space.planes_size(), 0);
}
80 
// RemoveEmptyPlanes drops planes that have no lines and keeps the others in
// their original order (a plane whose lines are themselves empty is kept).
TEST(XPlaneUtilsTest, RemoveEmptyPlanes) {
  XSpace space;
  // Running on a space without planes is a no-op.
  RemoveEmptyPlanes(&space);
  EXPECT_EQ(space.planes_size(), 0);

  auto add_plane = [&space](const char* name) {
    XPlane* plane = space.add_planes();
    plane->set_name(name);
    return plane;
  };

  XPlane* with_two_lines = add_plane("p1");
  with_two_lines->add_lines()->set_name("p1l1");
  with_two_lines->add_lines()->set_name("p1l2");
  add_plane("p2");  // No lines: expected to be removed.
  add_plane("p3")->add_lines()->set_name("p3l1");
  add_plane("p4");  // No lines: expected to be removed.

  RemoveEmptyPlanes(&space);
  ASSERT_EQ(space.planes_size(), 2);
  EXPECT_EQ(space.planes(0).name(), "p1");
  EXPECT_EQ(space.planes(1).name(), "p3");
}
106 
// RemoveEmptyLines drops lines that have no events and keeps the others in
// their original order.
TEST(XPlaneUtilsTest, RemoveEmptyLines) {
  XPlane plane;
  // Running on a plane without lines is a no-op.
  RemoveEmptyLines(&plane);
  EXPECT_EQ(plane.lines_size(), 0);

  // Appends a line called `name` holding `num_events` default events.
  auto add_line = [&plane](const char* name, int num_events) {
    XLine* line = plane.add_lines();
    line->set_name(name);
    for (int i = 0; i < num_events; ++i) line->add_events();
  };
  add_line("l1", 2);
  add_line("l2", 0);  // Expected to be removed.
  add_line("l3", 1);
  add_line("l4", 0);  // Expected to be removed.

  RemoveEmptyLines(&plane);
  ASSERT_EQ(plane.lines_size(), 2);
  EXPECT_EQ(plane.lines(0).name(), "l1");
  EXPECT_EQ(plane.lines(1).name(), "l3");
}
132 
// Removing a line by pointer erases only that line and preserves both the
// order and the addresses of the remaining lines.
TEST(XPlaneUtilsTest, RemoveLine) {
  XPlane plane;
  const XLine* const first = plane.add_lines();
  const XLine* const middle = plane.add_lines();
  const XLine* const last = plane.add_lines();

  RemoveLine(&plane, middle);

  ASSERT_EQ(plane.lines_size(), 2);
  EXPECT_EQ(&plane.lines(0), first);
  EXPECT_EQ(&plane.lines(1), last);
}
143 
// Removing a non-contiguous set of events in one call keeps the remaining
// events in order and at their original addresses.
TEST(XPlaneUtilsTest, RemoveEvents) {
  XLine line;
  const XEvent* const e1 = line.add_events();
  const XEvent* const e2 = line.add_events();
  const XEvent* const e3 = line.add_events();
  const XEvent* const e4 = line.add_events();

  RemoveEvents(&line, {e1, e3});

  ASSERT_EQ(line.events_size(), 2);
  EXPECT_EQ(&line.events(0), e2);
  EXPECT_EQ(&line.events(1), e4);
}
155 
// SortXPlane orders a line's events by ascending offset; among events with the
// same offset, the longer one comes first.
TEST(XPlaneUtilsTest, SortXPlaneTest) {
  XPlane plane;
  XLine* line = plane.add_lines();
  *line->add_events() = CreateEvent(200, 100);
  *line->add_events() = CreateEvent(100, 100);
  *line->add_events() = CreateEvent(120, 50);
  *line->add_events() = CreateEvent(120, 30);

  SortXPlane(&plane);

  // {offset_ps, duration_ps} of each event after sorting, in order.
  const std::pair<int64_t, int64_t> kExpected[] = {
      {100, 100}, {120, 50}, {120, 30}, {200, 100}};

  ASSERT_EQ(plane.lines_size(), 1);
  const XLine& sorted = plane.lines(0);
  ASSERT_EQ(sorted.events_size(), 4);
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(sorted.events(i).offset_ps(), kExpected[i].first)
        << "event " << i;
    EXPECT_EQ(sorted.events(i).duration_ps(), kExpected[i].second)
        << "event " << i;
  }
}
175 
176 namespace {
177 
CreateXLine(XPlaneBuilder * plane,absl::string_view name,absl::string_view display,int64_t id,int64_t timestamp_ns)178 XLineBuilder CreateXLine(XPlaneBuilder* plane, absl::string_view name,
179                          absl::string_view display, int64_t id,
180                          int64_t timestamp_ns) {
181   XLineBuilder line = plane->GetOrCreateLine(id);
182   line.SetName(name);
183   line.SetTimestampNs(timestamp_ns);
184   line.SetDisplayNameIfEmpty(display);
185   return line;
186 }
187 
CreateXEvent(XPlaneBuilder * plane,XLineBuilder line,absl::string_view event_name,absl::optional<absl::string_view> display,int64_t offset_ns,int64_t duration_ns)188 XEventBuilder CreateXEvent(XPlaneBuilder* plane, XLineBuilder line,
189                            absl::string_view event_name,
190                            absl::optional<absl::string_view> display,
191                            int64_t offset_ns, int64_t duration_ns) {
192   XEventMetadata* event_metadata = plane->GetOrCreateEventMetadata(event_name);
193   if (display) event_metadata->set_display_name(std::string(*display));
194   XEventBuilder event = line.AddEvent(*event_metadata);
195   event.SetOffsetNs(offset_ns);
196   event.SetDurationNs(duration_ns);
197   return event;
198 }
199 
200 template <typename T, typename V>
CreateXStats(XPlaneBuilder * plane,T * stats_owner,absl::string_view stats_name,V stats_value)201 void CreateXStats(XPlaneBuilder* plane, T* stats_owner,
202                   absl::string_view stats_name, V stats_value) {
203   stats_owner->AddStatValue(*plane->GetOrCreateStatMetadata(stats_name),
204                             stats_value);
205 }
206 
CheckXLine(const XLine & line,absl::string_view name,absl::string_view display,int64_t start_time_ns,int64_t events_size)207 void CheckXLine(const XLine& line, absl::string_view name,
208                 absl::string_view display, int64_t start_time_ns,
209                 int64_t events_size) {
210   EXPECT_EQ(line.name(), name);
211   EXPECT_EQ(line.display_name(), display);
212   EXPECT_EQ(line.timestamp_ns(), start_time_ns);
213   EXPECT_EQ(line.events_size(), events_size);
214 }
215 
CheckXEvent(const XEvent & event,const XPlane & plane,absl::string_view name,absl::string_view display,int64_t offset_ns,int64_t duration_ns,int64_t stats_size)216 void CheckXEvent(const XEvent& event, const XPlane& plane,
217                  absl::string_view name, absl::string_view display,
218                  int64_t offset_ns, int64_t duration_ns, int64_t stats_size) {
219   const XEventMetadata& event_metadata =
220       plane.event_metadata().at(event.metadata_id());
221   EXPECT_EQ(event_metadata.name(), name);
222   EXPECT_EQ(event_metadata.display_name(), display);
223   EXPECT_EQ(event.offset_ps(), NanoToPico(offset_ns));
224   EXPECT_EQ(event.duration_ps(), NanoToPico(duration_ns));
225   EXPECT_EQ(event.stats_size(), stats_size);
226 }
227 }  // namespace
228 
TEST(XPlaneUtilsTest,MergeXPlaneTest)229 TEST(XPlaneUtilsTest, MergeXPlaneTest) {
230   XPlane src_plane, dst_plane;
231   constexpr int64_t kLineIdOnlyInSrcPlane = 1LL;
232   constexpr int64_t kLineIdOnlyInDstPlane = 2LL;
233   constexpr int64_t kLineIdInBothPlanes = 3LL;   // src start ts < dst start ts
234   constexpr int64_t kLineIdInBothPlanes2 = 4LL;  // src start ts > dst start ts
235 
236   {  // Populate the source plane.
237     XPlaneBuilder src(&src_plane);
238     CreateXStats(&src, &src, "plane_stat1", 1);    // only in source.
239     CreateXStats(&src, &src, "plane_stat3", 3.0);  // shared by source/dest.
240 
241     auto l1 = CreateXLine(&src, "l1", "d1", kLineIdOnlyInSrcPlane, 100);
242     auto e1 = CreateXEvent(&src, l1, "event1", "display1", 1, 2);
243     CreateXStats(&src, &e1, "event_stat1", 2.0);
244     auto e2 = CreateXEvent(&src, l1, "event2", absl::nullopt, 3, 4);
245     CreateXStats(&src, &e2, "event_stat2", 3);
246 
247     auto l2 = CreateXLine(&src, "l2", "d2", kLineIdInBothPlanes, 200);
248     auto e3 = CreateXEvent(&src, l2, "event3", absl::nullopt, 5, 7);
249     CreateXStats(&src, &e3, "event_stat3", 2.0);
250     auto e4 = CreateXEvent(&src, l2, "event4", absl::nullopt, 6, 8);
251     CreateXStats(&src, &e4, "event_stat4", 3);
252     CreateXStats(&src, &e4, "event_stat5", 3);
253 
254     auto l5 = CreateXLine(&src, "l5", "d5", kLineIdInBothPlanes2, 700);
255     CreateXEvent(&src, l5, "event51", absl::nullopt, 9, 10);
256     CreateXEvent(&src, l5, "event52", absl::nullopt, 11, 12);
257   }
258 
259   {  // Populate the destination plane.
260     XPlaneBuilder dst(&dst_plane);
261     CreateXStats(&dst, &dst, "plane_stat2", 2);  // only in dest
262     CreateXStats(&dst, &dst, "plane_stat3", 4);  // shared but different.
263 
264     auto l3 = CreateXLine(&dst, "l3", "d3", kLineIdOnlyInDstPlane, 300);
265     auto e5 = CreateXEvent(&dst, l3, "event5", absl::nullopt, 11, 2);
266     CreateXStats(&dst, &e5, "event_stat6", 2.0);
267     auto e6 = CreateXEvent(&dst, l3, "event6", absl::nullopt, 13, 4);
268     CreateXStats(&dst, &e6, "event_stat7", 3);
269 
270     auto l2 = CreateXLine(&dst, "l4", "d4", kLineIdInBothPlanes, 400);
271     auto e7 = CreateXEvent(&dst, l2, "event7", absl::nullopt, 15, 7);
272     CreateXStats(&dst, &e7, "event_stat8", 2.0);
273     auto e8 = CreateXEvent(&dst, l2, "event8", "display8", 16, 8);
274     CreateXStats(&dst, &e8, "event_stat9", 3);
275 
276     auto l6 = CreateXLine(&dst, "l6", "d6", kLineIdInBothPlanes2, 300);
277     CreateXEvent(&dst, l6, "event61", absl::nullopt, 21, 10);
278     CreateXEvent(&dst, l6, "event62", absl::nullopt, 22, 12);
279   }
280 
281   MergePlanes(src_plane, &dst_plane);
282 
283   XPlaneVisitor plane(&dst_plane);
284   EXPECT_EQ(dst_plane.lines_size(), 4);
285 
286   // Check plane level stats,
287   EXPECT_EQ(dst_plane.stats_size(), 3);
288   absl::flat_hash_map<absl::string_view, absl::string_view> plane_stats;
289   plane.ForEachStat([&](const tensorflow::profiler::XStatVisitor& stat) {
290     if (stat.Name() == "plane_stat1") {
291       EXPECT_EQ(stat.IntValue(), 1);
292     } else if (stat.Name() == "plane_stat2") {
293       EXPECT_EQ(stat.IntValue(), 2);
294     } else if (stat.Name() == "plane_stat3") {
295       // XStat in src_plane overrides the counter-part in dst_plane.
296       EXPECT_EQ(stat.DoubleValue(), 3.0);
297     } else {
298       EXPECT_TRUE(false);
299     }
300   });
301 
302   // 3 plane level stats, 9 event level stats.
303   EXPECT_EQ(dst_plane.stat_metadata_size(), 12);
304 
305   {  // Old lines are untouched.
306     const XLine& line = dst_plane.lines(0);
307     CheckXLine(line, "l3", "d3", 300, 2);
308     CheckXEvent(line.events(0), dst_plane, "event5", "", 11, 2, 1);
309     CheckXEvent(line.events(1), dst_plane, "event6", "", 13, 4, 1);
310   }
311   {  // Lines with the same id are merged.
312     // src plane start timestamp > dst plane start timestamp
313     const XLine& line = dst_plane.lines(1);
314     // NOTE: use minimum start time of src/dst.
315     CheckXLine(line, "l4", "d4", 200, 4);
316     CheckXEvent(line.events(0), dst_plane, "event7", "", 215, 7, 1);
317     CheckXEvent(line.events(1), dst_plane, "event8", "display8", 216, 8, 1);
318     CheckXEvent(line.events(2), dst_plane, "event3", "", 5, 7, 1);
319     CheckXEvent(line.events(3), dst_plane, "event4", "", 6, 8, 2);
320   }
321   {  // Lines with the same id are merged.
322     // src plane start timestamp < dst plane start timestamp
323     const XLine& line = dst_plane.lines(2);
324     CheckXLine(line, "l6", "d6", 300, 4);
325     CheckXEvent(line.events(0), dst_plane, "event61", "", 21, 10, 0);
326     CheckXEvent(line.events(1), dst_plane, "event62", "", 22, 12, 0);
327     CheckXEvent(line.events(2), dst_plane, "event51", "", 409, 10, 0);
328     CheckXEvent(line.events(3), dst_plane, "event52", "", 411, 12, 0);
329   }
330   {  // Lines only in source are "copied".
331     const XLine& line = dst_plane.lines(3);
332     CheckXLine(line, "l1", "d1", 100, 2);
333     CheckXEvent(line.events(0), dst_plane, "event1", "display1", 1, 2, 1);
334     CheckXEvent(line.events(1), dst_plane, "event2", "", 3, 4, 1);
335   }
336 }
337 
// FindPlanesWithPrefix returns the four "test-prefix:*" planes and not the
// plane whose name only shares a shorter prefix.
TEST(XPlaneUtilsTest, FindPlanesWithPrefix) {
  XSpace xspace;
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:0");
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:1");
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:2");
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:3");
  XPlane* p4 = FindOrAddMutablePlaneWithName(&xspace, "test-do-not-include:0");

  std::vector<const XPlane*> xplanes =
      FindPlanesWithPrefix(xspace, "test-prefix");
  // SizeIs avoids the signed/unsigned comparison of ASSERT_EQ(4, size()).
  ASSERT_THAT(xplanes, SizeIs(4));
  for (const XPlane* plane : xplanes) {
    ASSERT_NE(p4, plane);
  }
}
353 
// FindMutablePlanesWithPrefix returns the four "test-prefix:*" planes and not
// the plane whose name only shares a shorter prefix.
TEST(XplaneUtilsTest, FindMutablePlanesWithPrefix) {
  XSpace xspace;
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:0");
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:1");
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:2");
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:3");
  XPlane* p4 = FindOrAddMutablePlaneWithName(&xspace, "test-do-not-include:0");

  std::vector<XPlane*> xplanes =
      FindMutablePlanesWithPrefix(&xspace, "test-prefix");
  // SizeIs avoids the signed/unsigned comparison of ASSERT_EQ(4, size()).
  ASSERT_THAT(xplanes, SizeIs(4));
  for (XPlane* plane : xplanes) {
    ASSERT_NE(p4, plane);
  }
}
369 
// FindPlanes returns exactly the planes accepted by the predicate.
TEST(XplaneUtilsTest, FindPlanesWithPredicate) {
  XSpace xspace;
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:0");
  XPlane* p1 = FindOrAddMutablePlaneWithName(&xspace, "test-prefix:1");

  std::vector<const XPlane*> xplanes = FindPlanes(
      xspace,
      [](const XPlane& xplane) { return xplane.name() == "test-prefix:1"; });
  // SizeIs avoids the signed/unsigned comparison of ASSERT_EQ(1, size()).
  ASSERT_THAT(xplanes, SizeIs(1));
  ASSERT_EQ(p1, xplanes[0]);
}
381 
// FindMutablePlanes returns exactly the planes accepted by the predicate.
TEST(XplaneUtilsTest, FindMutablePlanesWithPredicate) {
  XSpace xspace;
  FindOrAddMutablePlaneWithName(&xspace, "test-prefix:0");
  XPlane* p1 = FindOrAddMutablePlaneWithName(&xspace, "test-prefix:1");

  std::vector<XPlane*> xplanes = FindMutablePlanes(
      &xspace, [](XPlane& xplane) { return xplane.name() == "test-prefix:1"; });
  // SizeIs avoids the signed/unsigned comparison of ASSERT_EQ(1, size()).
  ASSERT_THAT(xplanes, SizeIs(1));
  ASSERT_EQ(p1, xplanes[0]);
}
392 
// AggregateXPlane collapses all events on a line that share event metadata
// into one aggregated event. Each aggregated event carries the total duration,
// the occurrence count and, where applicable, min/self duration stats (see the
// expected proto below).
TEST(XplaneUtilsTest, TestAggregateXPlanes) {
  XPlane xplane;
  XPlaneBuilder builder(&xplane);
  XEventMetadata* event_metadata1 = builder.GetOrCreateEventMetadata(1);
  event_metadata1->set_name("EventMetadata1");
  XEventMetadata* event_metadata2 = builder.GetOrCreateEventMetadata(2);
  event_metadata2->set_name("EventMetadata2");
  XEventMetadata* event_metadata3 = builder.GetOrCreateEventMetadata(3);
  event_metadata3->set_name("EventMetadata3");
  XEventMetadata* event_metadata4 = builder.GetOrCreateEventMetadata(4);
  event_metadata4->set_name("EventMetadata4");

  // Event layout on the line, as [offset, offset + duration) in ns:
  //   event1 (md1) [0, 5)   contains event3 (md3) [0, 2)
  //   event2 (md2) [5, 10)
  //   event4 (md2) [10, 15)
  //   event5 (md4) [15, 21) contains event6 (md1) [15, 19),
  //                         which contains event7 (md3) [15, 18)
  XLineBuilder line = builder.GetOrCreateLine(1);
  line.SetName(kTensorFlowOpLineName);
  XEventBuilder event1 = line.AddEvent(*event_metadata1);
  event1.SetOffsetNs(0);
  event1.SetDurationNs(5);
  XEventBuilder event3 = line.AddEvent(*event_metadata3);
  event3.SetOffsetNs(0);
  event3.SetDurationNs(2);
  XEventBuilder event2 = line.AddEvent(*event_metadata2);
  event2.SetOffsetNs(5);
  event2.SetDurationNs(5);
  XEventBuilder event4 = line.AddEvent(*event_metadata2);
  event4.SetOffsetNs(10);
  event4.SetDurationNs(5);
  XEventBuilder event5 = line.AddEvent(*event_metadata4);
  event5.SetOffsetNs(15);
  event5.SetDurationNs(6);
  XEventBuilder event6 = line.AddEvent(*event_metadata1);
  event6.SetOffsetNs(15);
  event6.SetDurationNs(4);
  XEventBuilder event7 = line.AddEvent(*event_metadata3);
  event7.SetOffsetNs(15);
  event7.SetDurationNs(3);

  XPlane aggregated_xplane;
  AggregateXPlane(xplane, aggregated_xplane);

  // Portable checks that also run in OSS builds, where the full proto
  // comparison below is compiled out. The values mirror the expected proto.
  ASSERT_EQ(aggregated_xplane.lines_size(), 1);
  const XLine& aggregated_line = aggregated_xplane.lines(0);
  EXPECT_EQ(aggregated_line.name(), kTensorFlowOpLineName);
  EXPECT_EQ(aggregated_line.events_size(), 4);
  EXPECT_EQ(aggregated_xplane.event_metadata_size(), 4);
  EXPECT_EQ(aggregated_xplane.stat_metadata_size(), 2);
  using Totals = std::pair<int64_t, int64_t>;  // {duration_ps, occurrences}
  absl::flat_hash_map<int64_t, Totals> totals_by_metadata_id;
  for (const XEvent& event : aggregated_line.events()) {
    totals_by_metadata_id[event.metadata_id()] =
        Totals(event.duration_ps(), event.num_occurrences());
  }
  EXPECT_EQ(totals_by_metadata_id[1], Totals(9000, 2));
  EXPECT_EQ(totals_by_metadata_id[2], Totals(10000, 2));
  EXPECT_EQ(totals_by_metadata_id[3], Totals(5000, 2));
  EXPECT_EQ(totals_by_metadata_id[4], Totals(6000, 1));

// Protobuf matchers are unavailable in OSS (b/169705709)
#if defined(PLATFORM_GOOGLE)
  ASSERT_THAT(aggregated_xplane,
              IgnoringRepeatedFieldOrdering(EqualsProto(
                  R"pb(lines {
                         id: 1
                         name: "TensorFlow Ops"
                         events {
                           metadata_id: 1
                           duration_ps: 9000
                           stats { metadata_id: 1 int64_value: 4000 }
                           stats { metadata_id: 2 int64_value: 4000 }
                           num_occurrences: 2
                         }
                         events {
                           metadata_id: 3
                           duration_ps: 5000
                           stats { metadata_id: 1 int64_value: 2000 }
                           num_occurrences: 2
                         }
                         events {
                           metadata_id: 2
                           duration_ps: 10000
                           stats { metadata_id: 1 int64_value: 5000 }
                           num_occurrences: 2
                         }
                         events {
                           metadata_id: 4
                           duration_ps: 6000
                           stats { metadata_id: 2 int64_value: 2000 }
                           num_occurrences: 1
                         }
                       }
                       event_metadata {
                         key: 1
                         value { id: 1 name: "EventMetadata1" }
                       }
                       event_metadata {
                         key: 2
                         value { id: 2 name: "EventMetadata2" }
                       }
                       event_metadata {
                         key: 3
                         value { id: 3 name: "EventMetadata3" }
                       }
                       event_metadata {
                         key: 4
                         value { id: 4 name: "EventMetadata4" }
                       }
                       stat_metadata {
                         key: 1
                         value { id: 1 name: "min_duration_ps" }
                       }
                       stat_metadata {
                         key: 2
                         value { id: 2 name: "self_duration_ps" }
                       }
                  )pb")));
#endif
}
492 
// Aggregating a line made only of zero-duration (instant) events must not
// fail. Each of the two distinct event metadata still yields one aggregated
// event.
// NOTE(review): the suite name "XPlanuUtilsTest" looks like a typo of
// XPlaneUtilsTest; it is kept unchanged so existing test filters still match.
TEST(XPlanuUtilsTest, TestInstantEventDoesNotFail) {
  XPlane xplane;
  XPlaneBuilder xplane_builder(&xplane);
  XEventMetadata* first_metadata = xplane_builder.GetOrCreateEventMetadata(1);
  XEventMetadata* second_metadata = xplane_builder.GetOrCreateEventMetadata(2);

  XLineBuilder line = xplane_builder.GetOrCreateLine(1);
  line.SetName(kTensorFlowOpLineName);
  for (XEventMetadata* metadata : {first_metadata, second_metadata}) {
    XEventBuilder instant = line.AddEvent(*metadata);
    instant.SetOffsetNs(1);
    instant.SetDurationNs(0);
  }

  XPlane aggregated_xplane;
  AggregateXPlane(xplane, aggregated_xplane);

  EXPECT_THAT(aggregated_xplane.lines(),
              UnorderedElementsAre(Property(&XLine::events, SizeIs(2))));
}
515 
// A string-valued tf_op stat attached to event metadata must survive
// aggregation and be readable from the aggregated plane's event metadata.
TEST(XplaneutilsTest, TestEventMetadataStatsAreCopied) {
  XPlane xplane;
  XPlaneBuilder xplane_builder(&xplane);
  XEventMetadata* event_metadata = xplane_builder.GetOrCreateEventMetadata(1);

  XStatsBuilder<XEventMetadata> metadata_stats(event_metadata, &xplane_builder);
  const XStatMetadata* tf_op_stat_metadata =
      xplane_builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp));
  metadata_stats.AddStatValue(*tf_op_stat_metadata, "TestFunction");

  XLineBuilder line = xplane_builder.GetOrCreateLine(1);
  line.SetName(kTensorFlowOpLineName);
  XEventBuilder event = line.AddEvent(*event_metadata);
  event.SetOffsetNs(0);
  event.SetDurationNs(0);

  XPlane aggregated_xplane;
  AggregateXPlane(xplane, aggregated_xplane);

  XPlaneVisitor visitor = CreateTfXPlaneVisitor(&aggregated_xplane);
  const XEventMetadata* aggregated_metadata = visitor.GetEventMetadata(1);
  XEventMetadataVisitor metadata_visitor(&visitor, aggregated_metadata);
  std::optional<XStatVisitor> tf_op = metadata_visitor.GetStat(StatType::kTfOp);

  ASSERT_TRUE(tf_op.has_value());
  EXPECT_EQ(tf_op->Name(), "tf_op");
  EXPECT_EQ(tf_op->StrOrRefValue(), "TestFunction");
}
543 
// Same as TestEventMetadataStatsAreCopied, but the tf_op stat is stored as a
// reference to another stat metadata entry (named "TestFunction") rather than
// as an inline string. The referenced name must still be readable after
// aggregation.
TEST(XplaneutilsTest, TestEventMetadataStatsAreCopiedForRefValue) {
  XPlane xplane;
  XPlaneBuilder xplane_builder(&xplane);
  XEventMetadata* event_metadata = xplane_builder.GetOrCreateEventMetadata(1);

  XStatsBuilder<XEventMetadata> metadata_stats(event_metadata, &xplane_builder);
  metadata_stats.AddStatValue(
      *xplane_builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)),
      *xplane_builder.GetOrCreateStatMetadata("TestFunction"));

  XLineBuilder line = xplane_builder.GetOrCreateLine(1);
  line.SetName(kTensorFlowOpLineName);
  XEventBuilder event = line.AddEvent(*event_metadata);
  event.SetOffsetNs(0);
  event.SetDurationNs(0);

  XPlane aggregated_xplane;
  AggregateXPlane(xplane, aggregated_xplane);

  XPlaneVisitor visitor = CreateTfXPlaneVisitor(&aggregated_xplane);
  const XEventMetadata* aggregated_metadata = visitor.GetEventMetadata(1);
  XEventMetadataVisitor metadata_visitor(&visitor, aggregated_metadata);
  std::optional<XStatVisitor> tf_op = metadata_visitor.GetStat(StatType::kTfOp);

  ASSERT_TRUE(tf_op.has_value());
  EXPECT_EQ(tf_op->Name(), "tf_op");
  EXPECT_EQ(tf_op->StrOrRefValue(), "TestFunction");
}
571 
572 }  // namespace
573 }  // namespace profiler
574 }  // namespace tensorflow
575