xref: /aosp_15_r20/external/tink/testing/go/streaming_aead_service.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15///////////////////////////////////////////////////////////////////////////////
16
17package services
18
19import (
20	"bytes"
21	"context"
22	"fmt"
23
24	"io"
25
26	"github.com/google/tink/go/streamingaead"
27	pb "github.com/google/tink/testing/go/protos/testing_api_go_grpc"
28)
29
// decryptChunkSize is the length, in bytes, of the scratch buffer that
// Decrypt hands to each Read call on the decrypting reader.
// NOTE(review): presumably kept tiny so that decryption is exercised across
// many short reads — confirm.
const decryptChunkSize = 2
33
34// StreamingAEADService implements the StreamingAead testing service.
35type StreamingAEADService struct {
36	pb.StreamingAeadServer
37}
38
39func (s *StreamingAEADService) Create(ctx context.Context, req *pb.CreationRequest) (*pb.CreationResponse, error) {
40	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
41	if err != nil {
42		return &pb.CreationResponse{Err: err.Error()}, nil
43	}
44	_, err = streamingaead.New(handle)
45	if err != nil {
46		return &pb.CreationResponse{Err: err.Error()}, nil
47	}
48	return &pb.CreationResponse{}, nil
49}
50
51func (s *StreamingAEADService) Encrypt(ctx context.Context, req *pb.StreamingAeadEncryptRequest) (*pb.StreamingAeadEncryptResponse, error) {
52	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
53	if err != nil {
54		return &pb.StreamingAeadEncryptResponse{
55			Result: &pb.StreamingAeadEncryptResponse_Err{err.Error()}}, nil
56	}
57	cipher, err := streamingaead.New(handle)
58	if err != nil {
59		return &pb.StreamingAeadEncryptResponse{
60			Result: &pb.StreamingAeadEncryptResponse_Err{err.Error()}}, nil
61	}
62	ciphertextBuf := &bytes.Buffer{}
63	w, err := cipher.NewEncryptingWriter(ciphertextBuf, req.AssociatedData)
64	if err != nil {
65		errMsg := fmt.Sprintf("cannot create an encrypt writer: %v", err)
66		return &pb.StreamingAeadEncryptResponse{
67			Result: &pb.StreamingAeadEncryptResponse_Err{errMsg}}, nil
68	}
69	n, err := w.Write(req.Plaintext)
70	if err != nil {
71		errMsg := fmt.Sprintf("error writing to an encrypt writer: %v", err)
72		return &pb.StreamingAeadEncryptResponse{
73			Result: &pb.StreamingAeadEncryptResponse_Err{errMsg}}, nil
74	}
75	if n != len(req.Plaintext) {
76		errMsg := fmt.Sprintf("unexpected number of bytes written. Got=%d;want=%d", n, len(req.Plaintext))
77		return &pb.StreamingAeadEncryptResponse{
78			Result: &pb.StreamingAeadEncryptResponse_Err{errMsg}}, nil
79	}
80	if err := w.Close(); err != nil {
81		errMsg := fmt.Sprintf("error closing writer: %v", err)
82		return &pb.StreamingAeadEncryptResponse{
83			Result: &pb.StreamingAeadEncryptResponse_Err{errMsg}}, nil
84	}
85	return &pb.StreamingAeadEncryptResponse{
86		Result: &pb.StreamingAeadEncryptResponse_Ciphertext{ciphertextBuf.Bytes()}}, nil
87}
88
89func (s *StreamingAEADService) Decrypt(ctx context.Context, req *pb.StreamingAeadDecryptRequest) (*pb.StreamingAeadDecryptResponse, error) {
90	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
91	if err != nil {
92		return &pb.StreamingAeadDecryptResponse{
93			Result: &pb.StreamingAeadDecryptResponse_Err{err.Error()}}, nil
94	}
95	cipher, err := streamingaead.New(handle)
96	if err != nil {
97		return &pb.StreamingAeadDecryptResponse{
98			Result: &pb.StreamingAeadDecryptResponse_Err{err.Error()}}, nil
99	}
100	r, err := cipher.NewDecryptingReader(bytes.NewBuffer(req.Ciphertext), req.AssociatedData)
101	if err != nil {
102		errMsg := fmt.Sprintf("cannot create an encrypt reader: %v", err)
103		return &pb.StreamingAeadDecryptResponse{
104			Result: &pb.StreamingAeadDecryptResponse_Err{errMsg}}, nil
105	}
106	plaintextBuf := &bytes.Buffer{}
107	var (
108		chunk = make([]byte, decryptChunkSize)
109		eof   = false
110	)
111	for !eof {
112		n, err := r.Read(chunk)
113		if err != nil && err != io.EOF {
114			errMsg := fmt.Sprintf("error reading chunk: %v", err)
115			return &pb.StreamingAeadDecryptResponse{
116				Result: &pb.StreamingAeadDecryptResponse_Err{errMsg}}, nil
117		}
118		eof = err == io.EOF
119		plaintextBuf.Write(chunk[:n])
120	}
121	return &pb.StreamingAeadDecryptResponse{
122		Result: &pb.StreamingAeadDecryptResponse_Plaintext{plaintextBuf.Bytes()}}, nil
123}
124