{-# LANGUAGE FlexibleContexts #-}

module Network.TLS.IO.Decode (
    decodePacket12,
    decodePacket13,
) where

import Control.Concurrent.MVar
import Control.Monad.State.Strict
import qualified Data.ByteString as BS

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.ErrT
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Packet13
import Network.TLS.Record
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Util
import Network.TLS.Wire

-- | Decode a plaintext record received on a TLS 1.2 (or earlier) connection.
--
-- * Application data is returned unchanged in 'AppData'.
-- * Alert records are decoded with 'decodeAlerts' and wrapped in 'Alert'.
-- * A ChangeCipherSpec record that decodes successfully first makes the
--   context switch to its pending receive record state (via
--   'switchRxEncryption') and then yields 'ChangeCipherSpec'.
-- * A handshake record may carry several handshake messages, or only the
--   beginning of one. Every complete message is returned together with the
--   raw bytes it was decoded from. An incomplete trailing message is saved in
--   'stHandshakeRecordCont12' (the continuation plus the bytes seen so far)
--   and resumed when the next handshake record arrives. Decoding errors
--   raised inside 'usingState' come back as 'Left'.
--
-- Any other record type yields @Left (Error_Packet_Parsing ...)@.
decodePacket12 :: Context -> Record Plaintext -> IO (Either TLSError Packet)
decodePacket12 ctx (Record ptype ver fragment) =
    case ptype of
        ProtocolType_AppData -> return $ Right $ AppData payload
        ProtocolType_Alert -> return $ Alert <$> decodeAlerts payload
        ProtocolType_ChangeCipherSpec ->
            either (return . Left) (const onChangeCipherSpec) $
                decodeChangeCipherSpec payload
        ProtocolType_Handshake -> onHandshake
        _ -> return $ Left (Error_Packet_Parsing "unknown protocol type")
  where
    payload = fragmentGetBytes fragment

    -- The ChangeCipherSpec body is well formed: activate the pending
    -- receive state before reporting the packet.
    onChangeCipherSpec = do
        switchRxEncryption ctx
        return $ Right ChangeCipherSpec

    onHandshake = do
        -- The handshake decoder needs the record version and, when a cipher
        -- is already pending, that cipher's key exchange type.
        pending <- getHState ctx
        let params =
                CurrentParams
                    { cParamsVersion = ver
                    , cParamsKeyXchgType =
                        cipherKeyExchange <$> (pending >>= hstPendingCipher)
                    }
        usingState ctx $ do
            -- Pick up any message left incomplete by an earlier record,
            -- clear the saved state, then parse as many handshake
            -- messages as this record holds.
            (mCont, prior) <- gets stHandshakeRecordCont12
            modify' (\st -> st{stHandshakeRecordCont12 = (Nothing, [])})
            (msgs, raws) <- unzip <$> parseRecords params mCont prior payload
            return $ Handshake msgs raws

    -- Split @chunk@ into handshake messages. @mCont@ resumes a message cut
    -- short by an earlier record. @prior@ holds that earlier record's bytes,
    -- newest first.
    parseRecords params mCont prior chunk =
        case fromMaybe decodeHandshakeRecord mCont chunk of
            GotError err -> throwError err
            GotPartial cont -> do
                -- Input ran out mid-message: remember where we stopped.
                modify' $ \st ->
                    st{stHandshakeRecordCont12 = (Just cont, chunk : prior)}
                return []
            GotSuccess (ty, content) -> do
                msg <- decodeMsg ty content
                return [(msg, reverse (chunk : prior))]
            GotSuccessRemaining (ty, content) rest -> do
                msg <- decodeMsg ty content
                more <- parseRecords params Nothing [] rest
                -- Only the part of @chunk@ in front of @rest@ belongs to @msg@.
                let used = BS.take (BS.length chunk - BS.length rest) chunk
                return ((msg, reverse (used : prior)) : more)
      where
        decodeMsg ty content =
            either throwError return (decodeHandshake params ty content)

-- | Replace the context's receive record state ('ctxRxRecordState') with the
-- pending receive state kept in the handshake state ('hstPendingRxState').
--
-- If no pending receive state has been prepared, this raises
-- @Error_Protocol ... UnexpectedMessage@ through 'throwCore'. That would be,
-- presumably, a ChangeCipherSpec arriving before the new keys were derived.
-- The check happens before 'ctxRxRecordState' is touched, so in that case
-- the current receive state stays as it was.
switchRxEncryption :: Context -> IO ()
switchRxEncryption ctx = do
    pendingRx <- usingHState ctx (gets hstPendingRxState)
    case pendingRx of
        Just rx -> modifyMVar_ (ctxRxRecordState ctx) (\_ -> return rx)
        Nothing ->
            -- Fail eagerly. Storing a lazy 'fromJust Nothing' in the MVar
            -- would only blow up when the receive state is next used, as a
            -- bare ErrorCall far from the cause.
            throwCore $
                Error_Protocol
                    "ChangeCipherSpec received without a pending receive state"
                    UnexpectedMessage

----------------------------------------------------------------

-- | Decode a plaintext record received on a TLS 1.3 connection.
--
-- * A ChangeCipherSpec record that decodes successfully yields
--   'ChangeCipherSpec13'. The receive record state is not changed here.
-- * Application data is returned unchanged in 'AppData13'.
-- * Alert records are decoded with 'decodeAlerts' and wrapped in 'Alert13'.
-- * A handshake record may carry several handshake messages, or only the
--   beginning of one. Every complete message is returned together with the
--   raw bytes it was decoded from. An incomplete trailing message is saved in
--   'stHandshakeRecordCont13' (the continuation plus the bytes seen so far)
--   and resumed when the next handshake record arrives. Decoding errors
--   raised inside 'usingState' come back as 'Left'.
--
-- Any other record type yields @Left (Error_Packet_Parsing ...)@.
decodePacket13 :: Context -> Record Plaintext -> IO (Either TLSError Packet13)
decodePacket13 ctx (Record ptype _ fragment) =
    case ptype of
        ProtocolType_ChangeCipherSpec ->
            return $ ChangeCipherSpec13 <$ decodeChangeCipherSpec payload
        ProtocolType_AppData -> return $ Right $ AppData13 payload
        ProtocolType_Alert -> return $ Alert13 <$> decodeAlerts payload
        ProtocolType_Handshake -> usingState ctx $ do
            -- Pick up any message left incomplete by an earlier record,
            -- clear the saved state, then parse as many handshake messages
            -- as this record holds.
            (mCont, prior) <- gets stHandshakeRecordCont13
            modify' (\st -> st{stHandshakeRecordCont13 = (Nothing, [])})
            (msgs, raws) <- unzip <$> parseRecords mCont prior payload
            return $ Handshake13 msgs raws
        _ -> return $ Left (Error_Packet_Parsing "unknown protocol type")
  where
    payload = fragmentGetBytes fragment

    -- Split @chunk@ into handshake messages. @mCont@ resumes a message cut
    -- short by an earlier record. @prior@ holds that earlier record's bytes,
    -- newest first.
    parseRecords mCont prior chunk =
        case fromMaybe decodeHandshakeRecord13 mCont chunk of
            GotError err -> throwError err
            GotPartial cont -> do
                -- Input ran out mid-message: remember where we stopped.
                modify' $ \st ->
                    st{stHandshakeRecordCont13 = (Just cont, chunk : prior)}
                return []
            GotSuccess (ty, content) -> do
                msg <- decodeMsg ty content
                return [(msg, reverse (chunk : prior))]
            GotSuccessRemaining (ty, content) rest -> do
                msg <- decodeMsg ty content
                more <- parseRecords Nothing [] rest
                -- Only the part of @chunk@ in front of @rest@ belongs to @msg@.
                let used = BS.take (BS.length chunk - BS.length rest) chunk
                return ((msg, reverse (used : prior)) : more)

    decodeMsg ty content =
        either throwError return (decodeHandshake13 ty content)