Fix ActivityResultRegistry save/restore

The ActivityResultRegistry is currently restored after the on
ComponentActivity call to super.onCreate(). This means that any
callbacks registered before that could have already been saved waiting
to be restored so we end up improperly registering several duplicate
callbacks.

We should handle this on two fronts, one in ActivityResultRegistry for
cases outside of ComponentActivity and inside of ComponentActivity.

For ComponentActivity we should  save earlier as part of an OnContextAvailableListener to ensure that we restore the callbacks as soon as possible.

For ActivityResultRegistry, we should make sure that even if we register
before a call on onRestoreInstanceState, we don't duplicate keys.
Specifically, we need to restore the old requestCode to ensure that any
launched results that could potentially be dispatched still have a valid
requestCode, otherwise the dispatch would be ignored.

RelNote: "The ActivityResultRegistry callbacks are now properly saved
and restored so callbacks are not duplicated in the savedState."
Test: Added new test
Bug: 191893160

Change-Id: I9781617370ad24f768249df42d2ab148915097cb
(cherry picked from commit 82faca9ddbd48485b8752cc702891e45a00bcd92)
diff --git a/activity/activity/src/androidTest/AndroidManifest.xml b/activity/activity/src/androidTest/AndroidManifest.xml
index 6b8daa2..dc15880 100644
--- a/activity/activity/src/androidTest/AndroidManifest.xml
+++ b/activity/activity/src/androidTest/AndroidManifest.xml
@@ -32,6 +32,9 @@
         <activity android:name="androidx.activity.ResultComponentActivity"/>
         <activity android:name="androidx.activity.ResumeViewModelActivity" />
         <activity android:name="androidx.activity.PassThroughActivity"/>
+        <activity android:name="androidx.activity.RegisterInInitActivity"/>
+        <activity android:name="androidx.activity.RegisterBeforeOnCreateActivity"/>
+        <activity android:name="androidx.activity.FinishActivity"/>
     </application>
 
 </manifest>
diff --git a/activity/activity/src/androidTest/java/androidx/activity/ComponentActivityResultTest.kt b/activity/activity/src/androidTest/java/androidx/activity/ComponentActivityResultTest.kt
index 489dd7eb..22675c7 100644
--- a/activity/activity/src/androidTest/java/androidx/activity/ComponentActivityResultTest.kt
+++ b/activity/activity/src/androidTest/java/androidx/activity/ComponentActivityResultTest.kt
@@ -20,6 +20,7 @@
 import android.content.Intent
 import android.os.Bundle
 import androidx.activity.result.ActivityResult
+import androidx.activity.result.ActivityResultLauncher
 import androidx.activity.result.ActivityResultRegistry
 import androidx.activity.result.contract.ActivityResultContract
 import androidx.activity.result.contract.ActivityResultContracts.StartActivityForResult
@@ -55,6 +56,37 @@
             }
         }
     }
+
+    @Test
+    fun registerBeforeOnCreateTest() {
+        ActivityScenario.launch(RegisterBeforeOnCreateActivity::class.java).use { scenario ->
+            scenario.withActivity {
+                recreate()
+                launcher.launch(Intent(this, FinishActivity::class.java))
+            }
+
+            scenario.withActivity { }
+
+            scenario.withActivity {
+                assertThat(firstLaunchCount).isEqualTo(0)
+                assertThat(secondLaunchCount).isEqualTo(1)
+            }
+        }
+    }
+
+    @Test
+    fun registerInInitTest() {
+        ActivityScenario.launch(RegisterInInitActivity::class.java).use { scenario ->
+            scenario.withActivity {
+                recreate()
+                launcher.launch(Intent(this, FinishActivity::class.java))
+            }
+
+            scenario.withActivity {
+                assertThat(launchCount).isEqualTo(1)
+            }
+        }
+    }
 }
 
 class PassThroughActivity : ComponentActivity() {
@@ -92,3 +124,49 @@
         launcher.launch(Intent())
     }
 }
+
+class RegisterBeforeOnCreateActivity : ComponentActivity() {
+    lateinit var launcher: ActivityResultLauncher<Intent>
+    var firstLaunchCount = 0
+    var secondLaunchCount = 0
+    var recreated = false
+
+    init {
+        addOnContextAvailableListener {
+            launcher = if (!recreated) {
+                registerForActivityResult(StartActivityForResult()) {
+                    firstLaunchCount++
+                }
+            } else {
+                registerForActivityResult(StartActivityForResult()) {
+                    secondLaunchCount++
+                }
+            }
+        }
+    }
+
+    override fun onCreate(savedInstanceState: Bundle?) {
+        if (savedInstanceState != null) {
+            recreated = true
+        }
+        super.onCreate(savedInstanceState)
+    }
+}
+
+class RegisterInInitActivity : ComponentActivity() {
+    var launcher: ActivityResultLauncher<Intent>
+    var launchCount = 0
+
+    init {
+        launcher = registerForActivityResult(StartActivityForResult()) {
+            launchCount++
+        }
+    }
+}
+
+class FinishActivity : ComponentActivity() {
+    override fun onCreate(savedInstanceState: Bundle?) {
+        super.onCreate(savedInstanceState)
+        finish()
+    }
+}
diff --git a/activity/activity/src/androidTest/java/androidx/activity/result/ActivityResultRegistryTest.kt b/activity/activity/src/androidTest/java/androidx/activity/result/ActivityResultRegistryTest.kt
index 5aad19f..3c8c327 100644
--- a/activity/activity/src/androidTest/java/androidx/activity/result/ActivityResultRegistryTest.kt
+++ b/activity/activity/src/androidTest/java/androidx/activity/result/ActivityResultRegistryTest.kt
@@ -360,6 +360,35 @@
     }
 
     @Test
+    fun testRegisterBeforeRestoreInstanceState() {
+        registry.register("key", StartActivityForResult()) { }
+
+        val savedState = Bundle()
+        registry.onSaveInstanceState(savedState)
+
+        val restoredRegistry = object : ActivityResultRegistry() {
+            override fun <I : Any?, O : Any?> onLaunch(
+                requestCode: Int,
+                contract: ActivityResultContract<I, O>,
+                input: I,
+                options: ActivityOptionsCompat?
+            ) {
+                dispatchResult(requestCode, RESULT_OK, Intent())
+            }
+        }
+
+        restoredRegistry.register("key", StartActivityForResult()) { }
+        restoredRegistry.onRestoreInstanceState(savedState)
+
+        val newSavedState = Bundle()
+        restoredRegistry.onSaveInstanceState(newSavedState)
+
+        val keys = newSavedState.getStringArrayList("KEY_COMPONENT_ACTIVITY_REGISTERED_KEYS")
+
+        assertThat(keys?.size).isEqualTo(1)
+    }
+
+    @Test
     fun testKeepKeyAfterLaunch() {
         var code = 0
         val noDispatchRegistry = object : ActivityResultRegistry() {
@@ -416,4 +445,52 @@
 
         assertThat(callbackExecuted).isTrue()
     }
+
+    @Test
+    fun testSavePendingOnRestore() {
+        var code = 0
+        val noDispatchRegistry = object : ActivityResultRegistry() {
+            override fun <I : Any?, O : Any?> onLaunch(
+                requestCode: Int,
+                contract: ActivityResultContract<I, O>,
+                input: I,
+                options: ActivityOptionsCompat?
+            ) {
+                code = requestCode
+            }
+        }
+
+        val contract = StartActivityForResult()
+        val launcher = noDispatchRegistry.register("key", contract) { }
+
+        launcher.launch(Intent())
+        launcher.unregister()
+
+        noDispatchRegistry.dispatchResult(code, RESULT_OK, Intent())
+
+        val savedState = Bundle()
+        noDispatchRegistry.onSaveInstanceState(savedState)
+
+        val newNoDispatchRegistry = object : ActivityResultRegistry() {
+            override fun <I : Any?, O : Any?> onLaunch(
+                requestCode: Int,
+                contract: ActivityResultContract<I, O>,
+                input: I,
+                options: ActivityOptionsCompat?
+            ) {
+                code = requestCode
+            }
+        }
+
+        var completedLaunch = false
+        newNoDispatchRegistry.register("key", contract) {
+            completedLaunch = true
+        }
+
+        newNoDispatchRegistry.onRestoreInstanceState(savedState)
+
+        newNoDispatchRegistry.dispatchResult(code, RESULT_OK, Intent())
+
+        assertThat(completedLaunch).isTrue()
+    }
 }
\ No newline at end of file
diff --git a/activity/activity/src/main/java/androidx/activity/ComponentActivity.java b/activity/activity/src/main/java/androidx/activity/ComponentActivity.java
index 5ef108c..d567ec74 100644
--- a/activity/activity/src/main/java/androidx/activity/ComponentActivity.java
+++ b/activity/activity/src/main/java/androidx/activity/ComponentActivity.java
@@ -106,6 +106,8 @@
         ViewModelStore viewModelStore;
     }
 
+    private static final String ACTIVITY_RESULT_TAG = "android:support:activity-result";
+
     final ContextAwareHelper mContextAwareHelper = new ContextAwareHelper();
     private final LifecycleRegistry mLifecycleRegistry = new LifecycleRegistry(this);
     @SuppressWarnings("WeakerAccess") /* synthetic access */
@@ -265,6 +267,19 @@
         if (19 <= SDK_INT && SDK_INT <= 23) {
             getLifecycle().addObserver(new ImmLeaksCleaner(this));
         }
+        getSavedStateRegistry().registerSavedStateProvider(ACTIVITY_RESULT_TAG,
+                () -> {
+                    Bundle outState = new Bundle();
+                    mActivityResultRegistry.onSaveInstanceState(outState);
+                    return outState;
+                });
+        addOnContextAvailableListener(context -> {
+            Bundle savedInstanceState = getSavedStateRegistry()
+                    .consumeRestoredStateForKey(ACTIVITY_RESULT_TAG);
+            if (savedInstanceState != null) {
+                mActivityResultRegistry.onRestoreInstanceState(savedInstanceState);
+            }
+        });
     }
 
     /**
@@ -296,7 +311,6 @@
         mSavedStateRegistryController.performRestore(savedInstanceState);
         mContextAwareHelper.dispatchOnContextAvailable(this);
         super.onCreate(savedInstanceState);
-        mActivityResultRegistry.onRestoreInstanceState(savedInstanceState);
         ReportFragment.injectIfNeededIn(this);
         if (mContentLayoutId != 0) {
             setContentView(mContentLayoutId);
@@ -312,7 +326,6 @@
         }
         super.onSaveInstanceState(outState);
         mSavedStateRegistryController.performSave(outState);
-        mActivityResultRegistry.onSaveInstanceState(outState);
     }
 
     /**
diff --git a/activity/activity/src/main/java/androidx/activity/result/ActivityResultRegistry.java b/activity/activity/src/main/java/androidx/activity/result/ActivityResultRegistry.java
index afa2c453..d7fb56a 100644
--- a/activity/activity/src/main/java/androidx/activity/result/ActivityResultRegistry.java
+++ b/activity/activity/src/main/java/androidx/activity/result/ActivityResultRegistry.java
@@ -67,7 +67,7 @@
     private Random mRandom = new Random();
 
     private final Map<Integer, String> mRcToKey = new HashMap<>();
-    private final Map<String, Integer> mKeyToRc = new HashMap<>();
+    final Map<String, Integer> mKeyToRc = new HashMap<>();
     private final Map<String, LifecycleContainer> mKeyToLifecycleContainers = new HashMap<>();
     ArrayList<String> mLaunchedKeys = new ArrayList<>();
 
@@ -163,7 +163,8 @@
             @Override
             public void launch(I input, @Nullable ActivityOptionsCompat options) {
                 mLaunchedKeys.add(key);
-                onLaunch(requestCode, contract, input, options);
+                Integer innerCode = mKeyToRc.get(key);
+                onLaunch((innerCode != null) ? innerCode : requestCode, contract, input, options);
             }
 
             @Override
@@ -221,7 +222,8 @@
             @Override
             public void launch(I input, @Nullable ActivityOptionsCompat options) {
                 mLaunchedKeys.add(key);
-                onLaunch(requestCode, contract, input, options);
+                Integer innerCode = mKeyToRc.get(key);
+                onLaunch((innerCode != null) ? innerCode : requestCode, contract, input, options);
             }
 
             @Override
@@ -277,9 +279,9 @@
      */
     public final void onSaveInstanceState(@NonNull Bundle outState) {
         outState.putIntegerArrayList(KEY_COMPONENT_ACTIVITY_REGISTERED_RCS,
-                new ArrayList<>(mRcToKey.keySet()));
+                new ArrayList<>(mKeyToRc.values()));
         outState.putStringArrayList(KEY_COMPONENT_ACTIVITY_REGISTERED_KEYS,
-                new ArrayList<>(mRcToKey.values()));
+                new ArrayList<>(mKeyToRc.keySet()));
         outState.putStringArrayList(KEY_COMPONENT_ACTIVITY_LAUNCHED_KEYS,
                 new ArrayList<>(mLaunchedKeys));
         outState.putBundle(KEY_COMPONENT_ACTIVITY_PENDING_RESULTS,
@@ -303,15 +305,28 @@
         if (keys == null || rcs == null) {
             return;
         }
-        int numKeys = keys.size();
-        for (int i = 0; i < numKeys; i++) {
-            bindRcKey(rcs.get(i), keys.get(i));
-        }
         mLaunchedKeys =
                 savedInstanceState.getStringArrayList(KEY_COMPONENT_ACTIVITY_LAUNCHED_KEYS);
         mRandom = (Random) savedInstanceState.getSerializable(KEY_COMPONENT_ACTIVITY_RANDOM_OBJECT);
         mPendingResults.putAll(
                 savedInstanceState.getBundle(KEY_COMPONENT_ACTIVITY_PENDING_RESULTS));
+        for (int i = 0; i < keys.size(); i++) {
+            String key = keys.get(i);
+            // Developers may have already registered with this same key by the time we restore
+            // state, which caused us to generate a new requestCode that doesn't match what we're
+            // about to restore. Clear out the new requestCode to ensure that we use the
+            // previously saved requestCode.
+            if (mKeyToRc.containsKey(key)) {
+                Integer newRequestCode = mKeyToRc.remove(key);
+                // On the chance that developers have already called launch() with this new
+                // requestCode, keep the mapping around temporarily to ensure the result is
+                // properly delivered to both the new requestCode and the restored requestCode
+                if (!mPendingResults.containsKey(key)) {
+                    mRcToKey.remove(newRequestCode);
+                }
+            }
+            bindRcKey(rcs.get(i), keys.get(i));
+        }
     }
 
     /**