{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Network.TLS.Handshake.Server.ClientHello13 (
processClientHello13,
) where
import qualified Data.ByteString as B
import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.IO.Encode
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.Session
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Types
-- | Process a TLS 1.3 ClientHello on the server side.
--
-- Validates the ClientHello and decides the handshake parameters:
--
-- * rejects the hello with @IllegalParameter@ if a @pre_shared_key@
--   extension appears anywhere but in the last position;
-- * rejects it with @HandshakeFailure@ if no TLS 1.3 cipher is shared
--   with the client, otherwise lets the 'onCipherChoosing' server hook pick one;
-- * records whether the client sent an @early_data@ extension (0-RTT);
-- * requires a @key_share@ extension (@MissingExtension@ otherwise) and
--   picks a client key share via 'findKeyShare';
-- * computes the PSK / early secret via 'pskAndEarlySecret';
-- * feeds the stored ClientHello into the transcript hash.
--
-- Returns the selected key share (if any), the @(cipher, hash, 0-RTT requested)@
-- triple, and the result of 'pskAndEarlySecret'.
processClientHello13
    :: ServerParams
    -> Context
    -> ClientHello
    -> IO
        ( Maybe KeyShareEntry
        , (Cipher, Hash, Bool)
        , (SecretPair EarlySecret, [ExtensionRaw], Bool, Bool)
        )
processClientHello13 sparams ctx ch@CH{..} = do
    -- RFC 8446 4.2.11: pre_shared_key must be the last extension.
    when pskNotLast $
        throwCore $
            Error_Protocol "extension pre_shared_key must be last" IllegalParameter
    -- Check for an empty intersection before calling onCipherChoosing,
    -- which is handed the filtered list directly.
    when (null ciphersFilteredVersion) $
        throwCore $
            Error_Protocol "no cipher in common with the TLS 1.3 client" HandshakeFailure
    let usedCipher = onCipherChoosing (serverHooks sparams) TLS13 ciphersFilteredVersion
        usedHash = cipherHash usedCipher
        -- True when the client sent an early_data extension (0-RTT attempt).
        rtt0 =
            lookupAndDecode
                EID_EarlyData
                MsgTClientHello
                chExtensions
                False
                (\(EarlyDataIndication _) -> True)
    -- The literal 3 is a hard-coded argument to EarlyDataNotAllowed;
    -- NOTE(review): presumably a bound on early-data records to skip — confirm.
    if rtt0
        then setEstablished ctx (EarlyDataNotAllowed 3)
        else setEstablished ctx NotEstablished
    let require :: IO a
        require =
            throwCore $
                Error_Protocol
                    "key exchange not implemented, expected key_share extension"
                    MissingExtension
        extract (KeyShareClientHello kses) = return kses
        extract _ = require
    keyShares <-
        lookupAndDecodeAndDo EID_KeyShare MsgTClientHello chExtensions require extract
    mshare <- findKeyShare keyShares serverGroups
    let triple = (usedCipher, usedHash, rtt0)
    pskEarlySecret <- pskAndEarlySecret sparams ctx triple ch
    (ich, b) <- fromJust <$> usingHState ctx getClientHello
    updateTranscriptHash12 ctx (ClientHello ich, b)
    return (mshare, triple, pskEarlySecret)
  where
    isPreSharedKey (ExtensionRaw eid _) = eid == EID_PreSharedKey

    -- Total replacement for `any isPreSharedKey (init chExtensions)`:
    -- identical for non-empty lists, but an empty list yields False instead
    -- of crashing with a non-TLS exception from `init`.
    pskNotLast = case reverse chExtensions of
        [] -> False
        (_ : earlier) -> any isPreSharedKey earlier

    ciphersFilteredVersion = intersectCiphers chCiphers serverCiphers

    serverCiphers =
        filter
            (cipherAllowedForVersion TLS13)
            (supportedCiphers $ serverSupported sparams)

    serverGroups = supportedGroups (ctxSupported ctx)
-- | Choose a client key share for the first group in the given group list
-- that the client offered a share for.
--
-- Groups are tried in list order; the server's own list is passed by the
-- caller. NOTE(review): presumably that order is server preference — confirm.
--
-- * Returns 'Nothing' when no offered share matches any group.
-- * Throws @"duplicated key_share"@ (IllegalParameter) if the first matching
--   group has more than one share.
-- * Throws @"broken key_share"@ (IllegalParameter) if the single matching
--   share fails 'checkKeyShareKeyLength'.
findKeyShare :: [KeyShareEntry] -> [Group] -> IO (Maybe KeyShareEntry)
findKeyShare entries = search
  where
    search [] = pure Nothing
    search (grp : rest) =
        case [entry | entry <- entries, keyShareEntryGroup entry == grp] of
            [] -> search rest
            [entry]
                | checkKeyShareKeyLength entry -> pure (Just entry)
                | otherwise ->
                    throwCore $ Error_Protocol "broken key_share" IllegalParameter
            _ -> throwCore $ Error_Protocol "duplicated key_share" IllegalParameter
-- | Decide whether the client's @pre_shared_key@ can be used, and derive the
-- early secret.
--
-- The arguments are the server parameters, the context, the
-- @(cipher, hash, 0-RTT requested)@ triple chosen by the caller, and the
-- ClientHello.
--
-- Only the first offered identity and the first binder are considered.
-- A PSK is accepted only when all of the following hold:
--
-- * the client allows @psk_dhe_ke@;
-- * the session manager knows the identity and has TLS 1.3 ticket info for it;
-- * the ticket is fresh;
-- * its hash (KDF) and SNI match this handshake.
--
-- Otherwise an all-zero PSK of the hash length is used (full handshake).
--
-- Returns:
--
-- * the early secret pair;
-- * the ServerHello extensions to send (the selected @pre_shared_key@ index,
--   or none);
-- * whether a PSK was accepted (its binder was verified);
-- * whether 0-RTT is valid for the resumed session.
--
-- Throws @MissingExtension@ if a PSK is offered without
-- @psk_key_exchange_modes@. Fails via 'decryptError' if the binder does not
-- verify.
pskAndEarlySecret
    :: ServerParams
    -> Context
    -> (Cipher, Hash, Bool)
    -> ClientHello
    -> IO (SecretPair EarlySecret, [ExtensionRaw], Bool, Bool)
pskAndEarlySecret sparams ctx (usedCipher, usedHash, rtt0) CH{..} = do
    (psk, binderInfo, is0RTTvalid) <- choosePSK
    earlyKey <- calculateEarlySecret ctx choice (Left psk)
    let earlySecret = pairBase earlyKey
        authenticated = isJust binderInfo
    preSharedKeyExt <- checkBinder earlySecret binderInfo
    return (earlyKey, preSharedKeyExt, authenticated, is0RTTvalid)
  where
    choice = makeCipherChoice TLS13 usedCipher

    -- Result whenever no PSK is accepted: a zero PSK, no binder to check,
    -- and no 0-RTT.
    noPSK :: (ByteString, Maybe (ByteString, Int, Int), Bool)
    noPSK = (zero, Nothing, False)

    choosePSK =
        lookupAndDecodeAndDo
            EID_PreSharedKey
            MsgTClientHello
            chExtensions
            (return noPSK)
            selectPSK

    selectPSK (PreSharedKeyClientHello (PskIdentity identity obfAge : _) bnds@(bnd : _)) = do
        when (null dhModes) $
            throwCore $
                Error_Protocol "no psk_key_exchange_modes extension" MissingExtension
        if PSK_DHE_KE `elem` dhModes
            then do
                -- Encoded size of the binders vector (RFC 8446 4.2.11): a
                -- 1-byte length per binder plus a 2-byte list length.
                let len = sum (map (\x -> B.length x + 1) bnds) + 2
                    mgr = sharedSessionManager $ serverShared sparams
                -- With 0-RTT the session is looked up with the
                -- resume-only-once variant.
                msdata <-
                    if rtt0
                        then sessionResumeOnlyOnce mgr identity
                        else sessionResume mgr identity
                case msdata of
                    -- The identity is client-controlled, so the session found
                    -- may lack TLS 1.3 ticket info. Treat such a session as an
                    -- unusable PSK instead of crashing on `fromJust`.
                    Just sdata
                        | Just tinfo <- sessionTicketInfo sdata -> do
                            let psk = sessionSecret sdata
                            isFresh <- checkFreshness tinfo obfAge
                            (isPSKvalid, is0RTTvalid) <- checkSessionEquality sdata
                            if isPSKvalid && isFresh
                                then return (psk, Just (bnd, 0 :: Int, len), is0RTTvalid)
                                else return noPSK
                    _ -> return noPSK
            else return noPSK
    selectPSK _ = return noPSK

    checkBinder _ Nothing = return []
    checkBinder earlySecret (Just (binder, n, tlen)) = do
        (_, b) <- fromJust <$> usingHState ctx getClientHello
        let binder' = makePSKBinder earlySecret usedHash tlen $ B.concat b
        unless (binder `sameBytes` binder') $
            decryptError "PSK binder validation failed"
        return [toExtensionRaw $ PreSharedKeyServerHello $ fromIntegral n]

    -- The binder is a MAC supplied by the client. Compare it without stopping
    -- at the first mismatching byte, so timing does not reveal how long a
    -- prefix matched. Only the (public) lengths short-circuit.
    sameBytes x y =
        B.length x == B.length y
            && sum (B.zipWith (\p q -> if p == q then 0 else 1 :: Int) x y) == 0

    -- Returns (PSK usable, 0-RTT usable) for a resumed session.
    checkSessionEquality sdata = do
        msni <- usingState_ ctx getClientSNI
        let isSameSNI = sessionClientSNI sdata == msni
            isSameCipher = sessionCipher sdata == cipherID usedCipher
            ciphers = supportedCiphers $ serverSupported sparams
            scid = sessionCipher sdata
            isSameKDF = case findCipher scid ciphers of
                Nothing -> False
                Just c -> cipherHash c == cipherHash usedCipher
            isSameVersion = TLS13 == sessionVersion sdata
            isPSKvalid = isSameKDF && isSameSNI
            is0RTTvalid = isSameVersion && isSameCipher
        return (isPSKvalid, is0RTTvalid)

    dhModes =
        lookupAndDecode
            EID_PskKeyExchangeModes
            MsgTClientHello
            chExtensions
            []
            (\(PskKeyExchangeModes ms) -> ms)

    hashSize = hashDigestSize usedHash
    zero = B.replicate hashSize 0