OTR plugin for Gajim 1.0+
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

proto.py 12KB


  1. # Copyright 2011-2012 Kjell Braden <afflux@pentabarf.de>
  2. #
  3. # This file is part of the python-potr library.
  4. #
  5. # python-potr is free software; you can redistribute it and/or modify
  6. # it under the terms of the GNU Lesser General Public License as published by
  7. # the Free Software Foundation; either version 3 of the License, or
  8. # any later version.
  9. #
  10. # python-potr is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. # GNU Lesser General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU Lesser General Public License
  16. # along with this library. If not, see <http://www.gnu.org/licenses/>.
  17. # some python3 compatibility
  18. from __future__ import unicode_literals
  19. import base64
  20. import struct
  21. from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
  22. OTRTAG = b'?OTR'
  23. MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
  24. MESSAGE_TAGS = {
  25. 1:b' \t \t \t ',
  26. 2:b' \t\t \t ',
  27. 3:b' \t\t \t\t',
  28. }
  29. MSGTYPE_NOTOTR = 0
  30. MSGTYPE_TAGGEDPLAINTEXT = 1
  31. MSGTYPE_QUERY = 2
  32. MSGTYPE_DH_COMMIT = 3
  33. MSGTYPE_DH_KEY = 4
  34. MSGTYPE_REVEALSIG = 5
  35. MSGTYPE_SIGNATURE = 6
  36. MSGTYPE_V1_KEYEXCH = 7
  37. MSGTYPE_DATA = 8
  38. MSGTYPE_ERROR = 9
  39. MSGTYPE_UNKNOWN = -1
  40. MSGFLAGS_IGNORE_UNREADABLE = 1
  41. tlvClasses = {}
  42. messageClasses = {}
  43. hasByteStr = bytes == str
  44. def bytesAndStrings(cls):
  45. if hasByteStr:
  46. cls.__str__ = lambda self: self.__bytes__()
  47. else:
  48. cls.__str__ = lambda self: str(self.__bytes__(), 'utf-8', 'replace')
  49. return cls
  50. def registermessage(cls):
  51. if not hasattr(cls, 'parsePayload'):
  52. raise TypeError('registered message types need parsePayload()')
  53. messageClasses[cls.version, cls.msgtype] = cls
  54. return cls
  55. def registertlv(cls):
  56. if not hasattr(cls, 'parsePayload'):
  57. raise TypeError('registered tlv types need parsePayload()')
  58. if cls.typ is None:
  59. raise TypeError('registered tlv type needs type ID')
  60. tlvClasses[cls.typ] = cls
  61. return cls
  62. def getslots(cls, base):
  63. ''' helper to collect all the message slots from ancestors '''
  64. clss = [cls]
  65. for cls in clss:
  66. if cls == base:
  67. continue
  68. clss.extend(cls.__bases__)
  69. for slot in cls.__slots__:
  70. yield slot
  71. @bytesAndStrings
  72. class OTRMessage(object):
  73. __slots__ = ['payload']
  74. version = 0x0002
  75. msgtype = 0
  76. def __eq__(self, other):
  77. if not isinstance(other, self.__class__):
  78. return False
  79. for slot in getslots(self.__class__, OTRMessage):
  80. if getattr(self, slot) != getattr(other, slot):
  81. return False
  82. return True
  83. def __neq__(self, other):
  84. return not self.__eq__(other)
  85. class Error(OTRMessage):
  86. __slots__ = ['error']
  87. def __init__(self, error):
  88. super(Error, self).__init__()
  89. self.error = error
  90. def __repr__(self):
  91. return '<proto.Error(%r)>' % self.error
  92. def __bytes__(self):
  93. return b'?OTR Error:' + self.error
  94. class Query(OTRMessage):
  95. __slots__ = ['versions']
  96. def __init__(self, versions=set()):
  97. super(Query, self).__init__()
  98. self.versions = versions
  99. @classmethod
  100. def parse(cls, data):
  101. if not isinstance(data, bytes):
  102. raise TypeError('can only parse bytes')
  103. udata = data.decode('ascii', 'replace')
  104. versions = set()
  105. if len(udata) > 0 and udata[0] == '?':
  106. udata = udata[1:]
  107. versions.add(1)
  108. if len(udata) > 0 and udata[0] == 'v':
  109. versions.update(( int(c) for c in udata if c.isdigit() ))
  110. return cls(versions)
  111. def __repr__(self):
  112. return '<proto.Query(versions=%r)>' % (self.versions)
  113. def __bytes__(self):
  114. d = b'?OTR'
  115. if 1 in self.versions:
  116. d += b'?'
  117. d += b'v'
  118. # in python3 there is only int->unicode conversion
  119. # so I convert to unicode and encode it to a byte string
  120. versions = [ '%d' % v for v in self.versions if v != 1 ]
  121. d += ''.join(versions).encode('ascii')
  122. d += b'?'
  123. return d
  124. class TaggedPlaintext(Query):
  125. __slots__ = ['msg']
  126. def __init__(self, msg, versions):
  127. super(TaggedPlaintext, self).__init__(versions)
  128. self.msg = msg
  129. def __bytes__(self):
  130. data = self.msg + MESSAGE_TAG_BASE
  131. for v in self.versions:
  132. data += MESSAGE_TAGS[v]
  133. return data
  134. def __repr__(self):
  135. return '<proto.TaggedPlaintext(versions={versions!r},msg={msg!r})>' \
  136. .format(versions=self.versions, msg=self.msg)
  137. @classmethod
  138. def parse(cls, data):
  139. tagPos = data.find(MESSAGE_TAG_BASE)
  140. if tagPos < 0:
  141. raise TypeError(
  142. 'this is not a tagged plaintext ({0!r:.20})'.format(data))
  143. tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
  144. versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag
  145. in tags ])
  146. return TaggedPlaintext(data[:tagPos], versions)
  147. class GenericOTRMessage(OTRMessage):
  148. __slots__ = ['data']
  149. fields = []
  150. def __init__(self, *args):
  151. super(GenericOTRMessage, self).__init__()
  152. if len(args) != len(self.fields):
  153. raise TypeError('%s needs %d arguments, got %d' %
  154. (self.__class__.__name__, len(self.fields), len(args)))
  155. super(GenericOTRMessage, self).__setattr__('data',
  156. dict(zip((f[0] for f in self.fields), args)))
  157. def __getattr__(self, attr):
  158. if attr in self.data:
  159. return self.data[attr]
  160. raise AttributeError(
  161. "'{t!r}' object has no attribute '{attr!r}'".format(attr=attr,
  162. t=self.__class__.__name__))
  163. def __setattr__(self, attr, val):
  164. if attr in self.__slots__:
  165. super(GenericOTRMessage, self).__setattr__(attr, val)
  166. else:
  167. self.__getattr__(attr) # existence check
  168. self.data[attr] = val
  169. def __bytes__(self):
  170. data = struct.pack(b'!HB', self.version, self.msgtype) \
  171. + self.getPayload()
  172. return b'?OTR:' + base64.b64encode(data) + b'.'
  173. def __repr__(self):
  174. name = self.__class__.__name__
  175. data = ''
  176. for k, _ in self.fields:
  177. data += '%s=%r,' % (k, self.data[k])
  178. return '<proto.%s(%s)>' % (name, data)
  179. @classmethod
  180. def parsePayload(cls, data):
  181. data = base64.b64decode(data)
  182. args = []
  183. for _, ftype in cls.fields:
  184. if ftype == 'data':
  185. value, data = read_data(data)
  186. elif isinstance(ftype, bytes):
  187. value, data = unpack(ftype, data)
  188. elif isinstance(ftype, int):
  189. value, data = data[:ftype], data[ftype:]
  190. args.append(value)
  191. return cls(*args)
  192. def getPayload(self, *ffilter):
  193. payload = b''
  194. for k, ftype in self.fields:
  195. if k in ffilter:
  196. continue
  197. if ftype == 'data':
  198. payload += pack_data(self.data[k])
  199. elif isinstance(ftype, bytes):
  200. payload += struct.pack(ftype, self.data[k])
  201. else:
  202. payload += self.data[k]
  203. return payload
  204. class AKEMessage(GenericOTRMessage):
  205. __slots__ = []
  206. @registermessage
  207. class DHCommit(AKEMessage):
  208. __slots__ = []
  209. msgtype = 0x02
  210. fields = [('encgx', 'data'), ('hashgx', 'data'), ]
  211. @registermessage
  212. class DHKey(AKEMessage):
  213. __slots__ = []
  214. msgtype = 0x0a
  215. fields = [('gy', 'data'), ]
  216. @registermessage
  217. class RevealSig(AKEMessage):
  218. __slots__ = []
  219. msgtype = 0x11
  220. fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),]
  221. def getMacedData(self):
  222. p = self.encsig
  223. return struct.pack(b'!I', len(p)) + p
  224. @registermessage
  225. class Signature(AKEMessage):
  226. __slots__ = []
  227. msgtype = 0x12
  228. fields = [('encsig', 'data'), ('mac', 20)]
  229. def getMacedData(self):
  230. p = self.encsig
  231. return struct.pack(b'!I', len(p)) + p
  232. @registermessage
  233. class DataMessage(GenericOTRMessage):
  234. __slots__ = []
  235. msgtype = 0x03
  236. fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'),
  237. ('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20),
  238. ('oldmacs', 'data'), ]
  239. def getMacedData(self):
  240. return struct.pack(b'!HB', self.version, self.msgtype) + \
  241. self.getPayload('mac', 'oldmacs')
  242. @bytesAndStrings
  243. class TLV(object):
  244. __slots__ = []
  245. typ = None
  246. def getPayload(self):
  247. raise NotImplementedError
  248. def __repr__(self):
  249. val = self.getPayload()
  250. return '<{cls}(typ={t},len={l},val={v!r})>'.format(t=self.typ,
  251. l=len(val), v=val, cls=self.__class__.__name__)
  252. def __bytes__(self):
  253. val = self.getPayload()
  254. return struct.pack(b'!HH', self.typ, len(val)) + val
  255. @classmethod
  256. def parse(cls, data):
  257. if not data:
  258. return []
  259. typ, length, data = unpack(b'!HH', data)
  260. if typ in tlvClasses:
  261. return [tlvClasses[typ].parsePayload(data[:length])] \
  262. + cls.parse(data[length:])
  263. else:
  264. raise UnknownTLV(data)
  265. def __eq__(self, other):
  266. if not isinstance(other, self.__class__):
  267. return False
  268. for slot in getslots(self.__class__, TLV):
  269. if getattr(self, slot) != getattr(other, slot):
  270. return False
  271. return True
  272. def __neq__(self, other):
  273. return not self.__eq__(other)
  274. @registertlv
  275. class PaddingTLV(TLV):
  276. typ = 0
  277. __slots__ = ['padding']
  278. def __init__(self, padding):
  279. super(PaddingTLV, self).__init__()
  280. self.padding = padding
  281. def getPayload(self):
  282. return self.padding
  283. @classmethod
  284. def parsePayload(cls, data):
  285. return cls(data)
  286. @registertlv
  287. class DisconnectTLV(TLV):
  288. typ = 1
  289. def __init__(self):
  290. super(DisconnectTLV, self).__init__()
  291. def getPayload(self):
  292. return b''
  293. @classmethod
  294. def parsePayload(cls, data):
  295. if len(data) > 0:
  296. raise TypeError('DisconnectTLV must not contain data. got {0!r}'
  297. .format(data))
  298. return cls()
  299. class SMPTLV(TLV):
  300. __slots__ = ['mpis']
  301. dlen = None
  302. def __init__(self, mpis=None):
  303. super(SMPTLV, self).__init__()
  304. if mpis is None:
  305. mpis = []
  306. if self.dlen is None:
  307. raise TypeError('no amount of mpis specified in dlen')
  308. if len(mpis) != self.dlen:
  309. raise TypeError('expected {0} mpis, got {1}'
  310. .format(self.dlen, len(mpis)))
  311. self.mpis = mpis
  312. def getPayload(self):
  313. d = struct.pack(b'!I', len(self.mpis))
  314. for n in self.mpis:
  315. d += pack_mpi(n)
  316. return d
  317. @classmethod
  318. def parsePayload(cls, data):
  319. mpis = []
  320. if cls.dlen > 0:
  321. count, data = unpack(b'!I', data)
  322. for _ in range(count):
  323. n, data = read_mpi(data)
  324. mpis.append(n)
  325. if len(data) > 0:
  326. raise TypeError('too much data for {0} mpis'.format(cls.dlen))
  327. return cls(mpis)
  328. @registertlv
  329. class SMP1TLV(SMPTLV):
  330. typ = 2
  331. dlen = 6
  332. @registertlv
  333. class SMP1QTLV(SMPTLV):
  334. typ = 7
  335. dlen = 6
  336. __slots__ = ['msg']
  337. def __init__(self, msg, mpis):
  338. self.msg = msg
  339. super(SMP1QTLV, self).__init__(mpis)
  340. def getPayload(self):
  341. return self.msg + b'\0' + super(SMP1QTLV, self).getPayload()
  342. @classmethod
  343. def parsePayload(cls, data):
  344. msg, data = data.split(b'\0', 1)
  345. mpis = SMP1TLV.parsePayload(data).mpis
  346. return cls(msg, mpis)
  347. @registertlv
  348. class SMP2TLV(SMPTLV):
  349. typ = 3
  350. dlen = 11
  351. @registertlv
  352. class SMP3TLV(SMPTLV):
  353. typ = 4
  354. dlen = 8
  355. @registertlv
  356. class SMP4TLV(SMPTLV):
  357. typ = 5
  358. dlen = 3
  359. @registertlv
  360. class SMPABORTTLV(SMPTLV):
  361. typ = 6
  362. dlen = 0
  363. def getPayload(self):
  364. return b''
  365. @registertlv
  366. class ExtraKeyTLV(TLV):
  367. typ = 8
  368. __slots__ = ['appid', 'appdata']
  369. def __init__(self, appid, appdata):
  370. super(ExtraKeyTLV, self).__init__()
  371. self.appid = appid
  372. self.appdata = appdata
  373. if appdata is None:
  374. self.appdata = b''
  375. def getPayload(self):
  376. return self.appid + self.appdata
  377. @classmethod
  378. def parsePayload(cls, data):
  379. return cls(data[:4], data[4:])
  380. class UnknownTLV(RuntimeError):
  381. pass