1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/examples/models/llama3_2_vision/cross_attention/cross_attention_mask.h>
10 
11 #include <gtest/gtest.h>
12 
13 using namespace ::testing;
14 using exec_aten::ScalarType;
15 using exec_aten::Tensor;
16 using exec_aten::TensorImpl;
17 
/**
 * Exercises example::cross_attention_mask on a 12-token prompt in which the
 * image token id (1) appears at positions 0, 1 and 7, paired with three 2x2
 * Int "image" tensors and tile_size = patch_size = 1.
 *
 * For mask k, every row whose index lies in the half-open interval
 * [expected_intervals[k][0], expected_intervals[k][1]) must be all 1s, and
 * every other row must be all 0s.
 */
TEST(CrossAttentxnMaskTest, TestCrossAttentionMask) {
  std::vector<int> tokens = {
      1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415};

  // Initialize image tensors. All three share the same 2x2 row-major layout
  // (sizes / dim_order / strides); only their backing data differs.
  TensorImpl::SizesType sizes[2] = {2, 2};
  TensorImpl::DimOrderType dim_order[2] = {0, 1};
  TensorImpl::StridesType strides[2] = {2, 1};

  int32_t a_data[4] = {1, 2, 3, 4};
  auto a_impl =
      TensorImpl(ScalarType::Int, 2, sizes, a_data, dim_order, strides);
  Tensor a(&a_impl);

  int32_t b_data[4] = {5, 6, 7, 8};
  auto b_impl =
      TensorImpl(ScalarType::Int, 2, sizes, b_data, dim_order, strides);
  Tensor b(&b_impl);

  int32_t c_data[4] = {9, 10, 11, 12};
  auto c_impl =
      TensorImpl(ScalarType::Int, 2, sizes, c_data, dim_order, strides);
  Tensor c(&c_impl);

  std::vector<Tensor> images = {a, b, c};
  // Passed as the `out` argument. NOTE(review): presumably this owns the
  // storage the returned mask tensors point into, so it must stay alive while
  // output_masks is inspected. Confirm against cross_attention_mask.h.
  std::vector<std::vector<int>> mask_data;
  auto output_masks = example::cross_attention_mask(
      tokens,
      images,
      /*tile_size=*/1,
      /*patch_size=*/1,
      /*image_token_id=*/1,
      /*out=*/mask_data);

  // Check contents of the mask. One [start, end) row interval per image.
  std::vector<std::vector<size_t>> expected_intervals = {
      {0, 7}, {1, 7}, {7, 12}};

  // Guard the indexing into expected_intervals below. Without this check, a
  // wrong mask count either silently skips checks or reads out of bounds.
  ASSERT_EQ(output_masks.size(), expected_intervals.size());

  for (size_t mask_idx = 0; mask_idx < output_masks.size(); ++mask_idx) {
    const auto& output_tensor = output_masks[mask_idx];
    const size_t row_begin = expected_intervals[mask_idx][0];
    const size_t row_end = expected_intervals[mask_idx][1];

    // size(0) and strides()[0] are signed types. Convert them once so the
    // unsigned loop comparisons below stay well-defined and warning-free.
    const size_t num_rows = static_cast<size_t>(output_tensor->size(0));
    // The row stride is used as the row width. This assumes a contiguous,
    // row-major mask, matching the unrolled indexing below.
    const size_t row_width = static_cast<size_t>(output_tensor->strides()[0]);
    const int* mask = output_tensor->const_data_ptr<int>();

    for (size_t i = 0; i < num_rows; ++i) {
      const int expected = (i >= row_begin && i < row_end) ? 1 : 0;
      for (size_t j = 0; j < row_width; ++j) {
        EXPECT_EQ(mask[i * row_width + j], expected)
            << "mask " << mask_idx << ", row " << i << ", col " << j;
      }
    }
  }
}
70