1 package com.android.onboarding.tasks
2 
3 import android.content.Context
4 import android.text.TextUtils
5 import android.util.Log
6 import com.android.onboarding.contracts.annotations.OnboardingNode
7 import com.android.onboarding.tasks.crossApp.CrossProcessTaskManager
8 import com.google.common.util.concurrent.ListenableFuture
9 import java.util.concurrent.ConcurrentHashMap
10 import kotlinx.coroutines.CoroutineScope
11 import kotlinx.coroutines.delay
12 import kotlinx.coroutines.guava.future
13 import kotlinx.coroutines.launch
14 
/**
 * Base class for managing the execution and state of onboarding tasks within the onboarding
 * process. It provides the common implementation for triggering tasks, monitoring their progress,
 * and obtaining their results.
 */
20 abstract class AbstractOnboardingTaskManager(
21   protected val appContext: Context,
22   protected val coroutineScope: CoroutineScope,
23 ) : OnboardingTaskManager {
24 
25   // Mapping between onboarding task contracts and corresponding tasks.
26   private val contractAndTaskMap:
27     ConcurrentHashMap<Class<out OnboardingTaskContract<*, *>>, Class<out OnboardingTask<*, *, *>>>
28   private val taskStateManager = OnboardingTaskStateManager()
29 
30   init {
31     // Initialize the mapping between task contracts and tasks.
32     contractAndTaskMap = initializeContractAndTaskMap()
33   }
34 
  /**
   * The component name assigned to this task manager. The component name must reference an
   * [OnboardingComponents] constant.
   *
   * This value is compared with the component name extracted from a task contract to decide
   * whether a task is run in this process or dispatched to another process.
   */
  abstract val componentName: String
40 
  /**
   * Initializes a mapping between onboarding task contracts and corresponding onboarding tasks.
   * This method must be implemented by subclasses to provide the mappings between specific
   * onboarding task contracts and their corresponding task implementations for their application
   * process.
   *
   * @return A map where the keys are classes implementing [OnboardingTaskContract] and the values
   *   are classes implementing [OnboardingTask]. The mapping specifies which task class is
   *   instantiated for a given contract. Implementing classes should populate this map with their
   *   desired mappings.
   */
  abstract fun initializeContractAndTaskMap():
    ConcurrentHashMap<Class<out OnboardingTaskContract<*, *>>, Class<out OnboardingTask<*, *, *>>>
54 
55   override fun <
56     TaskArgsT,
57     TaskResultT,
58     TaskContractT : OnboardingTaskContract<TaskArgsT, TaskResultT>,
59   > runTask(taskContract: TaskContractT, taskArgs: TaskArgsT): OnboardingTaskToken {
60     val taskToken: OnboardingTaskToken
61 
62     if (isTaskRunInSameProcess(taskContract)) {
63       Log.i(TAG, "Run task: $taskContract in same process.")
64       val task =
65         tryCreateTaskInstance(taskContract::class.java) ?: return OnboardingTaskToken.INVALID
66       taskToken = OnboardingTaskToken(taskContract::class.java.name, taskContract.componentName)
67       // Update the task state as in progress immediately before running the task.
68       taskStateManager.updateTaskState(taskToken, OnboardingTaskState.InProgress<Nothing>())
69       // Run the task asynchronously.
70       coroutineScope.launch { performTask(taskContract, task, taskArgs, taskToken) }
71     } else {
72       Log.i(TAG, "Run task: $taskContract in cross process.")
73       // Cross process triggers task asynchronously.
74       taskToken =
75         CrossProcessTaskManager.getInstance(appContext, taskStateManager)
76           .runTask(taskContract, taskArgs)
77       // Mark the task in progress.
78       taskStateManager.updateTaskState(taskToken, OnboardingTaskState.InProgress<Nothing>())
79     }
80 
81     Log.d(TAG, "Return task token immediately.")
82     return taskToken
83   }
84 
85   override suspend fun <
86     TaskArgsT,
87     TaskResultT,
88     TaskContractT : OnboardingTaskContract<TaskArgsT, TaskResultT>,
runTaskAndGetResultnull89   > runTaskAndGetResult(
90     taskContract: TaskContractT,
91     taskArgs: TaskArgsT,
92   ): OnboardingTaskState<TaskResultT> {
93     val task =
94       tryCreateTaskInstance(taskContract::class.java)
95         ?: return OnboardingTaskState.Failed(ERROR_INSTANTIATING_TASK)
96     val taskToken = OnboardingTaskToken(taskContract::class.java.name, taskComponentName = "")
97     // We have to update the task status as soon as possible to prevent immediate query status
98     // action.
99     taskStateManager.updateTaskState(taskToken, OnboardingTaskState.InProgress<Nothing>())
100 
101     // Execute the task and await its completion.
102     performTask(taskContract, task, taskArgs, taskToken)
103 
104     // Because task state includes different types of results in the list.
105     return getTaskState(taskToken)
106   }
107 
108   @Deprecated("Use new overload function - runTaskAndGetResult().")
109   override suspend fun <
110     TaskArgsT,
111     TaskResultT,
112     TaskContractT : OnboardingTaskContract<TaskArgsT, TaskResultT>,
runTaskAndGetResultnull113   > runTaskAndGetResult(
114     taskContract: TaskContractT,
115     task: OnboardingTask<TaskArgsT, TaskResultT, TaskContractT>,
116     taskArgs: TaskArgsT,
117   ): OnboardingTaskState<TaskResultT> {
118     val taskToken = OnboardingTaskToken(taskContract::class.java.name, taskComponentName = "")
119     // We have to update the task status as soon as possible to prevent immediate query status
120     // action.
121     taskStateManager.updateTaskState(taskToken, OnboardingTaskState.InProgress<Nothing>())
122 
123     // Execute the task and await its completion.
124     performTask(taskContract, task, taskArgs, taskToken)
125 
126     return getTaskState(taskToken)
127   }
128 
129   override fun <
130     TaskArgsT,
131     TaskResultT,
132     TaskContractT : OnboardingTaskContract<TaskArgsT, TaskResultT>,
runTaskAndGetResultAsyncnull133   > runTaskAndGetResultAsync(
134     taskContract: TaskContractT,
135     taskArgs: TaskArgsT,
136   ): ListenableFuture<OnboardingTaskState<TaskResultT>> {
137     return coroutineScope.future { runTaskAndGetResult(taskContract, taskArgs) }
138   }
139 
140   @Deprecated("Use new overload function - runTaskAndGetResultAsync().")
141   override fun <
142     TaskArgsT,
143     TaskResultT,
144     TaskContractT : OnboardingTaskContract<TaskArgsT, TaskResultT>,
runTaskAndGetResultAsyncnull145   > runTaskAndGetResultAsync(
146     taskContract: TaskContractT,
147     task: OnboardingTask<TaskArgsT, TaskResultT, TaskContractT>,
148     taskArgs: TaskArgsT,
149   ): ListenableFuture<OnboardingTaskState<TaskResultT>> {
150     return coroutineScope.future { runTaskAndGetResult(taskContract, task, taskArgs) }
151   }
152 
getTaskStatenull153   override fun <TaskResultT> getTaskState(
154     taskToken: OnboardingTaskToken
155   ): OnboardingTaskState<TaskResultT> {
156     return taskStateManager.getTaskState(taskToken)
157   }
158 
waitForCompletednull159   override suspend fun <TaskResultT> waitForCompleted(
160     taskToken: OnboardingTaskToken
161   ): OnboardingTaskState<TaskResultT> {
162     while (true) {
163       val currentState = getTaskState<TaskResultT>(taskToken)
164       Log.d(TAG, "waitForCompleted#currentState: $currentState")
165       when (currentState) {
166         is OnboardingTaskState.Completed<*>,
167         is OnboardingTaskState.Failed<*> -> return currentState
168         else -> {
169           // Do nothing here as task is in progress.
170         }
171       }
172       // Sleep for a short interval before checking again.
173       Log.d(TAG, "waitForCompleted#sleep... 500 ms")
174       delay(500)
175     }
176   }
177 
waitForCompletedAsyncnull178   override fun <TaskResultT> waitForCompletedAsync(
179     taskToken: OnboardingTaskToken
180   ): ListenableFuture<OnboardingTaskState<TaskResultT>> =
181     coroutineScope.future { waitForCompleted(taskToken) }
182 
getContractAndTaskMapnull183   override fun getContractAndTaskMap():
184     ConcurrentHashMap<Class<out OnboardingTaskContract<*, *>>, Class<out OnboardingTask<*, *, *>>> =
185     contractAndTaskMap
186 
187   private suspend fun <
188     TaskArgsT,
189     TaskResultT,
190     TaskContractT : OnboardingTaskContract<TaskArgsT, TaskResultT>,
191   > performTask(
192     taskContract: TaskContractT,
193     task: OnboardingTask<TaskArgsT, TaskResultT, TaskContractT>,
194     taskArgs: TaskArgsT,
195     taskToken: OnboardingTaskToken,
196   ) {
197     Log.d(TAG, "performTask#start")
198 
199     // Validate all inputs by the defined contract.
200     taskContract.validate(taskArgs)
201 
202     // Execute the task and await its completion.
203     val taskState = task.runTask(taskContract, taskArgs)
204 
205     Log.d(TAG, "performTask#end")
206 
207     // Update the tasksStates map with the actual task result after completion.
208     taskStateManager.updateTaskState(taskToken, taskState)
209   }
210 
211   private fun <
212     TaskArgsT,
213     TaskResultT,
214     TaskContractT : OnboardingTaskContract<TaskArgsT, TaskResultT>,
tryCreateTaskInstancenull215   > tryCreateTaskInstance(
216     contractClass: Class<out TaskContractT>
217   ): OnboardingTask<TaskArgsT, TaskResultT, TaskContractT>? {
218     val taskClass = contractAndTaskMap[contractClass] ?: return null
219 
220     try {
221       val constructor = taskClass.getDeclaredConstructor(Context::class.java)
222       // Create a new instance of the contract class using the constructor
223       @Suppress("UNCHECKED_CAST")
224       return constructor.newInstance(appContext)
225         as? OnboardingTask<TaskArgsT, TaskResultT, TaskContractT>
226     } catch (e: Exception) {
227       Log.w(TAG, "Error instantiating task: $e")
228     }
229     return null
230   }
231 
isTaskRunInSameProcessnull232   private fun isTaskRunInSameProcess(contract: OnboardingTaskContract<*, *>): Boolean {
233     val contractComponentName = OnboardingNode.extractComponentNameFromClass(contract::class.java)
234     return TextUtils.equals(componentName, "DefaultOnboardingTaskManager") ||
235       TextUtils.equals(componentName, contractComponentName)
236   }
237 
238   companion object {
239     private const val TAG: String = "OTMBase"
240   }
241 }
242