{-# LANGUAGE CPP #-}

{- |
Module      : Data.ASN1.Serialize
License     : BSD-style
Copyright   : (c) 2010-2013 Vincent Hanquez <vincent@snarc.org>
Stability   : experimental
Portability : unknown
-}

module Data.ASN1.Serialize
  ( getHeader
  , putHeader
  ) where

import           Control.Monad ( when )
import           Data.ASN1.Get ( Get, getBytes, getWord8 )
import           Data.ASN1.Internal ( bytesOfUInt, putVarEncodingIntegral )
import           Data.ASN1.Types.Lowlevel
                   ( ASN1Class, ASN1Header (..), ASN1Length (..), ASN1Tag )
import           Data.Bits ( (.&.), (.|.), clearBit, shiftL, shiftR, testBit )
import qualified Data.ByteString as B
import           Data.List.NonEmpty ( NonEmpty (..), (<|) )
import qualified Data.List.NonEmpty as NE
import           Data.Word ( Word8 )

-- | Helper function while base < 4.15.0.0 (GHC < 9.0.1) is supported.

singletonNE :: a -> NonEmpty a
#if !MIN_VERSION_base(4,15,0)
-- Older base has no NE.singleton, so build the one-element list directly.
singletonNE = (:| [])
#else
-- base >= 4.15 provides NE.singleton; simply re-use it.
singletonNE = NE.singleton
#endif

-- | Parse an ASN1 header.

getHeader :: Get ASN1Header
getHeader = do
  firstByte <- getWord8
  let (klass, constructed, shortTag) = parseFirstWord firstByte
  -- 0x1f in the low five bits signals that the tag number follows in
  -- long (multi-byte) form.
  tagNumber <- if shortTag == 0x1f then getTagLong else pure shortTag
  len <- getLength
  pure (ASN1Header klass tagNumber constructed len)

-- | Parse the first byte of a header into its class, constructed flag, and
-- low five tag bits.

parseFirstWord :: Word8 -> (ASN1Class, Bool, ASN1Tag)
parseFirstWord byte = (klass, constructed, tagBits)
 where
  -- bits 7-6: the class
  klass :: ASN1Class
  klass = toEnum (fromIntegral (byte `shiftR` 6))
  -- bit 5: the primitive/constructed (P/C) flag
  constructed :: Bool
  constructed = testBit byte 5
  -- bits 4-0: the tag number, or 0x1f when the long form follows
  tagBits :: ASN1Tag
  tagBits = fromIntegral (byte .&. 0x1f)

{- When the low five bits of the first byte are 0x1f, the tag is in long form:
 - subsequent bytes are read while their high bit (bit 7) is set, each one
 - contributing seven bits of the tag number. -}
getTagLong :: Get ASN1Tag
getTagLong = do
  t <- fromIntegral <$> getWord8
  -- A first byte of 0x80 is a leading all-zero 7-bit group, which is not the
  -- minimal (canonical) encoding of the tag number.
  when (t == 0x80) $ fail "non canonical encoding of long tag"
  if testBit t 7
    then loop (clearBit t 7)
    else pure t
 where
  -- Accumulate 7-bit groups, most significant first, until a byte with the
  -- high bit clear ends the tag. The input is untrusted, so fail instead of
  -- letting the Int accumulator silently wrap around on an over-long tag:
  -- n * 128 + 127 <= maxBound exactly when n <= maxBound `shiftR` 7.
  loop :: ASN1Tag -> Get ASN1Tag
  loop n = do
    when (n > maxBound `shiftR` 7) $ fail "long tag does not fit in an Int"
    t <- fromIntegral <$> getWord8
    if testBit t 7
      then loop (n `shiftL` 7 + clearBit t 7)
      else pure (n `shiftL` 7 + t)

{- Get the ASN1 length, which is either in short form if bit 7 is not set,
 - in indefinite form if bit 7 is set and every other bit is clear,
 - or in long form otherwise, where the low seven bits give the number of
 - following bytes that hold the length (big-endian).
 -}
getLength :: Get ASN1Length
getLength = do
  l1 <- fromIntegral <$> getWord8
  if testBit l1 7
    then case clearBit l1 7 of
      0   -> pure LenIndefinite
      len -> do
        lw <- getBytes len
        -- Reject lengths that cannot be represented, rather than letting the
        -- accumulator wrap around to a bogus (possibly negative) length.
        case uintbs lw of
          Just n  -> pure (LenLong len n)
          Nothing -> fail "length does not fit in an Int"
    else
      pure (LenShort l1)
 where
  -- Interpret the bytes as a big-endian unsigned integer, or Nothing if the
  -- value exceeds maxBound for Int. Strict fold: no thunk chain builds up.
  uintbs :: B.ByteString -> Maybe Int
  uintbs = B.foldl' step (Just 0)

  step :: Maybe Int -> Word8 -> Maybe Int
  step acc w = do
    a <- acc
    -- a * 256 + 255 <= maxBound exactly when a <= maxBound `shiftR` 8.
    if a > maxBound `shiftR` 8
      then Nothing
      else Just (a `shiftL` 8 + fromIntegral w)

-- | putHeader encodes an ASN1 header (identifier octets followed by length
-- octets) into a marshalled value.

putHeader :: ASN1Header -> B.ByteString
putHeader (ASN1Header cl tag pc len)
  -- A negative tag would pass the `tag < 0x1f` test and be truncated by
  -- fromIntegral into the class and P/C bits of the first byte, silently
  -- emitting a different header. Reject it, as putLength does for bad lengths.
  | tag < 0   = error "putHeader: tag is negative"
  | otherwise = B.concat [B.singleton word1, tagBS, lenBS]
 where
  -- class in bits 7-6
  cli :: Word8
  cli = shiftL (fromIntegral $ fromEnum cl) 6
  -- primitive/constructed flag in bit 5
  pcval :: Word8
  pcval = shiftL (if pc then 0x1 else 0x0) 5
  -- Tags below 0x1f fit in the low five bits of the first byte; larger tags
  -- put 0x1f there and follow it with the variable-length tag encoding.
  tag0 :: Word8
  tag0 = if tag < 0x1f then fromIntegral tag else 0x1f
  word1 :: Word8
  word1 = cli .|. pcval .|. tag0
  lenBS :: B.ByteString
  lenBS = B.pack $ NE.toList $ putLength len
  tagBS :: B.ByteString
  tagBS = if tag < 0x1f then B.empty else putVarEncodingIntegral tag

-- | putLength encodes a length into ASN1 length octets. See 'getLength' for
-- the encoding rules.

putLength :: ASN1Length -> NonEmpty Word8
putLength l = case l of
  -- A lone 0x80 byte marks the indefinite form.
  LenIndefinite -> singletonNE 0x80
  -- Short form: the length itself in a single byte with the high bit clear.
  LenShort i
    | i >= 0 && i <= 0x7f -> singletonNE (fromIntegral i)
    | otherwise -> error "putLength: short length is not between 0x0 and 0x80"
  -- Long form: a count byte (high bit set, low bits = number of length
  -- bytes) followed by the length bytes. The stored byte count is ignored
  -- and recomputed from the value.
  LenLong _ i
    | i < 0 -> error "putLength: long length is negative"
    | otherwise ->
        let octets :: NonEmpty Word8
            octets = bytesOfUInt (fromIntegral i)
            countByte :: Word8
            countByte = fromIntegral (length octets .|. 0x80)
        in countByte <| octets