1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */
15 
16 package software.amazon.awssdk.services.sns;
17 
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import software.amazon.awssdk.utils.BinaryUtils;
36 
37 /**
38  * Utility for validating signatures on a Simple Notification Service JSON message.
39  */
40 public class SignatureChecker {
41 
42     private static final String NOTIFICATION_TYPE = "Notification";
43     private static final String SUBSCRIBE_TYPE = "SubscriptionConfirmation";
44     private static final String UNSUBSCRIBE_TYPE = "UnsubscribeConfirmation";
45     private static final String TYPE = "Type";
46     private static final String SUBSCRIBE_URL = "SubscribeURL";
47     private static final String MESSAGE = "Message";
48     private static final String TIMESTAMP = "Timestamp";
49     private static final String SIGNATURE_VERSION = "SignatureVersion";
50     private static final String SIGNATURE = "Signature";
51     private static final String MESSAGE_ID = "MessageId";
52     private static final String SUBJECT = "Subject";
53     private static final String TOPIC = "TopicArn";
54     private static final String TOKEN = "Token";
55     private static final Set<String> INTERESTING_FIELDS = new HashSet<>(Arrays.asList(TYPE, SUBSCRIBE_URL, MESSAGE, TIMESTAMP,
56                                                                                       SIGNATURE, SIGNATURE_VERSION, MESSAGE_ID,
57                                                                                       SUBJECT, TOPIC, TOKEN));
58     private Signature sigChecker;
59 
60     /**
61      * Validates the signature on a Simple Notification Service message. No
62      * Amazon-specific dependencies, just plain Java crypto and Jackson for
63      * parsing
64      *
65      * @param message
66      *            A JSON-encoded Simple Notification Service message. Note: the
67      *            JSON may be only one level deep.
68      * @param publicKey
69      *            The Simple Notification Service public key, exactly as you'd
70      *            see it when retrieved from the cert.
71      *
72      * @return True if the message was correctly validated, otherwise false.
73      */
verifyMessageSignature(String message, PublicKey publicKey)74     public boolean verifyMessageSignature(String message, PublicKey publicKey) {
75 
76         // extract the type and signature parameters
77         Map<String, String> parsed = parseJson(message);
78 
79         return verifySignature(parsed, publicKey);
80     }
81 
82     /**
83      * Validates the signature on a Simple Notification Service message. No
84      * Amazon-specific dependencies, just plain Java crypto
85      *
86      * @param parsedMessage
87      *            A map of Simple Notification Service message.
88      * @param publicKey
89      *            The Simple Notification Service public key, exactly as you'd
90      *            see it when retrieved from the cert.
91      *
92      * @return True if the message was correctly validated, otherwise false.
93      */
verifySignature(Map<String, String> parsedMessage, PublicKey publicKey)94     public boolean verifySignature(Map<String, String> parsedMessage, PublicKey publicKey) {
95         boolean valid = false;
96         String version = parsedMessage.get(SIGNATURE_VERSION);
97         if (version.equals("1")) {
98             // construct the canonical signed string
99             String type = parsedMessage.get(TYPE);
100             String signature = parsedMessage.get(SIGNATURE);
101             String signed = "";
102             if (type.equals(NOTIFICATION_TYPE)) {
103                 signed = stringToSign(publishMessageValues(parsedMessage));
104             } else if (type.equals(SUBSCRIBE_TYPE)) {
105                 signed = stringToSign(subscribeMessageValues(parsedMessage));
106             } else if (type.equals(UNSUBSCRIBE_TYPE)) {
107                 signed = stringToSign(subscribeMessageValues(parsedMessage)); // no difference, for now
108             } else {
109                 throw new RuntimeException("Cannot process message of type " + type);
110             }
111             valid = verifySignature(signed, signature, publicKey);
112         }
113         return valid;
114     }
115 
116     /**
117      * Does the actual Java cryptographic verification of the signature. This
118      * method does no handling of the many rare exceptions it is required to
119      * catch.
120      *
121      * This can also be used to verify the signature from the x-amz-sns-signature http header
122      *
123      * @param message
124      *            Exact string that was signed.  In the case of the x-amz-sns-signature header the
125      *            signing string is the entire post body
126      * @param signature
127      *            Base64-encoded signature of the message
128      */
verifySignature(String message, String signature, PublicKey publicKey)129     public boolean verifySignature(String message, String signature, PublicKey publicKey) {
130         boolean result = false;
131         byte[] sigbytes = null;
132         try {
133             sigbytes = BinaryUtils.fromBase64Bytes(signature.getBytes());
134             sigChecker = Signature.getInstance("SHA1withRSA"); //check the signature
135             sigChecker.initVerify(publicKey);
136             sigChecker.update(message.getBytes());
137             result = sigChecker.verify(sigbytes);
138         } catch (NoSuchAlgorithmException e) {
139             // Rare exception: JVM does not support SHA1 with RSA
140         } catch (InvalidKeyException e) {
141             // Rare exception: The private key was incorrectly formatted
142         } catch (SignatureException e) {
143             // Rare exception: Catch-all exception for the signature checker
144         }
145         return result;
146     }
147 
stringToSign(SortedMap<String, String> signables)148     protected String stringToSign(SortedMap<String, String> signables) {
149         // each key and value is followed by a newline
150         StringBuilder sb = new StringBuilder();
151         for (String k : signables.keySet()) {
152             sb.append(k).append("\n");
153             sb.append(signables.get(k)).append("\n");
154         }
155         String result = sb.toString();
156         return result;
157     }
158 
parseJson(String jsonmessage)159     private Map<String, String> parseJson(String jsonmessage) {
160         Map<String, String> parsed = new HashMap<String, String>();
161         JsonFactory jf = new JsonFactory();
162         try {
163             JsonParser parser = jf.createParser(jsonmessage);
164             parser.nextToken(); //shift past the START_OBJECT that begins the JSON
165             while (parser.nextToken() != JsonToken.END_OBJECT) {
166                 String fieldname = parser.getCurrentName();
167                 if (!INTERESTING_FIELDS.contains(fieldname)) {
168                     parser.skipChildren();
169                     continue;
170                 }
171                 parser.nextToken(); // move to value, or START_OBJECT/START_ARRAY
172                 String value;
173                 if (parser.getCurrentToken() == JsonToken.START_ARRAY) {
174                     value = "";
175                     boolean first = true;
176                     while (parser.nextToken() != JsonToken.END_ARRAY) {
177                         if (!first) {
178                             value += ",";
179                         }
180                         first = false;
181                         value += parser.getText();
182                     }
183                 } else {
184                     value = parser.getText();
185                 }
186                 parsed.put(fieldname, value);
187             }
188         } catch (JsonParseException e) {
189             // JSON could not be parsed
190             e.printStackTrace();
191         } catch (IOException e) {
192             // Rare exception
193         }
194         return parsed;
195     }
196 
publishMessageValues(Map<String, String> parsedMessage)197     private TreeMap<String, String> publishMessageValues(Map<String, String> parsedMessage) {
198         TreeMap<String, String> signables = new TreeMap<String, String>();
199         String[] keys = {MESSAGE, MESSAGE_ID, SUBJECT, TYPE, TIMESTAMP, TOPIC};
200         for (String key : keys) {
201             if (parsedMessage.containsKey(key)) {
202                 signables.put(key, parsedMessage.get(key));
203             }
204         }
205         return signables;
206     }
207 
subscribeMessageValues(Map<String, String> parsedMessage)208     private TreeMap<String, String> subscribeMessageValues(Map<String, String> parsedMessage) {
209         TreeMap<String, String> signables = new TreeMap<String, String>();
210         String[] keys = {SUBSCRIBE_URL, MESSAGE, MESSAGE_ID, TYPE, TIMESTAMP, TOKEN, TOPIC};
211         for (String key : keys) {
212             if (parsedMessage.containsKey(key)) {
213                 signables.put(key, parsedMessage.get(key));
214             }
215         }
216         return signables;
217     }
218 }
219