xref: /aosp_15_r20/external/tink/go/hybrid/hybrid_decrypt_factory.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package hybrid
18
19import (
20	"fmt"
21
22	"github.com/google/tink/go/core/cryptofmt"
23	"github.com/google/tink/go/core/primitiveset"
24	"github.com/google/tink/go/internal/internalregistry"
25	"github.com/google/tink/go/internal/monitoringutil"
26	"github.com/google/tink/go/keyset"
27	"github.com/google/tink/go/monitoring"
28	"github.com/google/tink/go/tink"
29)
30
31// NewHybridDecrypt returns an HybridDecrypt primitive from the given keyset handle.
32func NewHybridDecrypt(handle *keyset.Handle) (tink.HybridDecrypt, error) {
33	ps, err := handle.Primitives()
34	if err != nil {
35		return nil, fmt.Errorf("hybrid_factory: cannot obtain primitive set: %s", err)
36	}
37	return newWrappedHybridDecrypt(ps)
38}
39
// wrappedHybridDecrypt is a HybridDecrypt implementation that uses the underlying primitive set
// for decryption.
//
// Instances are built by newWrappedHybridDecrypt, which checks that every
// primitive in ps is a tink.HybridDecrypt before the wrapper is returned.
type wrappedHybridDecrypt struct {
	// ps holds the candidate decryption primitives. Decrypt looks them up by
	// ciphertext prefix first and then falls back to the raw entries.
	ps *primitiveset.PrimitiveSet
	// logger receives a Log call (key ID, ciphertext length) for each successful
	// decryption and a LogFailure call when no primitive succeeds.
	logger monitoring.Logger
}

// compile time assertion that wrappedHybridDecrypt implements the HybridDecrypt interface.
var _ tink.HybridDecrypt = (*wrappedHybridDecrypt)(nil)
49
50func newWrappedHybridDecrypt(ps *primitiveset.PrimitiveSet) (*wrappedHybridDecrypt, error) {
51	if err := isHybridDecrypt(ps.Primary.Primitive); err != nil {
52		return nil, err
53	}
54
55	for _, primitives := range ps.Entries {
56		for _, p := range primitives {
57			if err := isHybridDecrypt(p.Primitive); err != nil {
58				return nil, err
59			}
60		}
61	}
62	logger, err := createDecryptLogger(ps)
63	if err != nil {
64		return nil, err
65	}
66	return &wrappedHybridDecrypt{
67		ps:     ps,
68		logger: logger,
69	}, nil
70}
71
72func createDecryptLogger(ps *primitiveset.PrimitiveSet) (monitoring.Logger, error) {
73	if len(ps.Annotations) == 0 {
74		return &monitoringutil.DoNothingLogger{}, nil
75	}
76	keysetInfo, err := monitoringutil.KeysetInfoFromPrimitiveSet(ps)
77	if err != nil {
78		return nil, err
79	}
80	return internalregistry.GetMonitoringClient().NewLogger(&monitoring.Context{
81		KeysetInfo:  keysetInfo,
82		Primitive:   "hybrid_decrypt",
83		APIFunction: "decrypt",
84	})
85}
86
87// Decrypt decrypts the given ciphertext, verifying the integrity of contextInfo.
88// It returns the corresponding plaintext if the ciphertext is authenticated.
89func (a *wrappedHybridDecrypt) Decrypt(ciphertext, contextInfo []byte) ([]byte, error) {
90	// try non-raw keys
91	prefixSize := cryptofmt.NonRawPrefixSize
92	if len(ciphertext) > prefixSize {
93		prefix := ciphertext[:prefixSize]
94		ctNoPrefix := ciphertext[prefixSize:]
95		entries, err := a.ps.EntriesForPrefix(string(prefix))
96		if err == nil {
97			for i := 0; i < len(entries); i++ {
98				p := entries[i].Primitive.(tink.HybridDecrypt) // verified in newWrappedHybridDecrypt
99				pt, err := p.Decrypt(ctNoPrefix, contextInfo)
100				if err == nil {
101					a.logger.Log(entries[i].KeyID, len(ctNoPrefix))
102					return pt, nil
103				}
104			}
105		}
106	}
107
108	// try raw keys
109	entries, err := a.ps.RawEntries()
110	if err == nil {
111		for i := 0; i < len(entries); i++ {
112			p := entries[i].Primitive.(tink.HybridDecrypt) // verified in newWrappedHybridDecrypt
113			pt, err := p.Decrypt(ciphertext, contextInfo)
114			if err == nil {
115				a.logger.Log(entries[i].KeyID, len(ciphertext))
116				return pt, nil
117			}
118		}
119	}
120
121	// nothing worked
122	a.logger.LogFailure()
123	return nil, fmt.Errorf("hybrid_factory: decryption failed")
124}
125
126// Asserts `p` implements tink.HybridDecrypt and not tink.AEAD. The latter check
127// is required as implementations of tink.AEAD also satisfy tink.HybridDecrypt.
128func isHybridDecrypt(p any) error {
129	if _, ok := p.(tink.AEAD); ok {
130		return fmt.Errorf("hybrid_factory: tink.AEAD is not tink.HybridDecrypt")
131	}
132	if _, ok := p.(tink.HybridDecrypt); !ok {
133		return fmt.Errorf("hybrid_factory: not tink.HybridDecrypt")
134	}
135	return nil
136}
137