1--[[--------------------------------------------------------------------------
2
3    This file is part of lunit 0.5.
4
5    For Details about lunit look at: http://www.mroth.net/lunit/
6
7    Author: Michael Roth <[email protected]>
8
9    Copyright (c) 2004, 2006-2010 Michael Roth <[email protected]>
10
11    Permission is hereby granted, free of charge, to any person
12    obtaining a copy of this software and associated documentation
13    files (the "Software"), to deal in the Software without restriction,
14    including without limitation the rights to use, copy, modify, merge,
15    publish, distribute, sublicense, and/or sell copies of the Software,
16    and to permit persons to whom the Software is furnished to do so,
17    subject to the following conditions:
18
19    The above copyright notice and this permission notice shall be
20    included in all copies or substantial portions of the Software.
21
22    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
25    IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
26    CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
27    TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
28    SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
29
30--]]--------------------------------------------------------------------------
31
32
33local orig_assert     = assert
34
35local pairs           = pairs
36local ipairs          = ipairs
37local next            = next
38local type            = type
39local error           = error
40local tostring        = tostring
41local setmetatable    = setmetatable
42local pcall           = pcall
43local xpcall          = xpcall
44local require         = require
45local loadfile        = loadfile
46
47local string_sub      = string.sub
48local string_gsub     = string.gsub
49local string_format   = string.format
50local string_lower    = string.lower
51local string_find     = string.find
52
53local table_concat    = table.concat
54
55local debug_getinfo   = debug.getinfo
56
57local _G = _G
58
-- The module table. After the branch below, every non-local name read or
-- assigned in this chunk resolves through this table, so for example
-- `function fail()` defines lunit.fail and `stats` means lunit.stats.
local lunit

-- Lexicographic string comparison of the version string ("Lua 5.1" < "Lua 5.2").
if _VERSION >= 'Lua 5.2' then

    -- Lua 5.2+: redirect this chunk's global accesses to the module table
    -- by replacing the chunk's _ENV.
    lunit = {}
    _ENV = lunit

else

    -- Lua 5.1: module() makes the "lunit" table the environment of this
    -- chunk. No package.seeall is passed, so standard globals are not
    -- visible by name afterwards (the standard functions used by this file
    -- are cached in locals before this point).
    module("lunit")
    lunit = _M

end
72
73
-- Unique sentinel placed in the `type` field of failure objects, so that
-- assertion failures can be told apart from ordinary runtime errors.
local __failure__ = {}

-- All Lua type names; drives the generated is_<type>, assert_<type> and
-- assert_not_<type> functions.
local typenames = {
  "nil", "boolean", "number", "string",
  "table", "function", "thread", "userdata",
}
77
78
local traceback_hide      -- Traceback function which hides lunit internals
local mypcall             -- Protected call to a function with own traceback
do
  -- Functions omitted from error tracebacks. Weak keys: being registered
  -- here never keeps a function alive.
  local _tb_hide = setmetatable( {}, {__mode="k"} )

  -- Registers `func` as an lunit internal that tracebacks should skip.
  function traceback_hide(func)
    _tb_hide[func] = true
  end

  -- xpcall message handler.
  --  * lunit failure objects (type == __failure__) are returned with an
  --    added `where` field: "file:line" of the code that called the assertion.
  --  * Any other error value is replaced by { msg = tostring(value), tb = {...} }
  --    where tb holds one formatted line per stack frame, skipping functions
  --    registered with traceback_hide.
  -- The stack levels below are relative to this function, so the
  -- debug_getinfo calls must stay directly inside it.
  local function my_traceback(errobj)
    -- type() rather than is_table: the is_* locals are declared later in
    -- the file and are not visible here.
    if type(errobj) == "table" and errobj.type == __failure__ then
      -- Expected stack when raised through failure(): 1 = my_traceback,
      -- 2 = error, 3 = failure, 4 = the assertion, 5 = its caller.
      -- FIXME: Hardcoded integers are bad...
      local info = debug_getinfo(5, "Sl")
      if info then
        errobj.where = string_format( "%s:%d", info.short_src, info.currentline)
      else
        -- Shallower stack than expected: record an unknown location instead
        -- of raising inside the error handler.
        errobj.where = "?"
      end
    else
      errobj = { msg = tostring(errobj), tb = {} }
      local level = 2   -- level 1 is my_traceback itself
      while true do
        local info = debug_getinfo(level, "Snlf")
        if type(info) ~= "table" then
          break
        end
        if not _tb_hide[info.func] then
          local line = {}       -- Ripped from ldblib.c...
          line[#line+1] = string_format("%s:", info.short_src)
          if info.currentline > 0 then
            line[#line+1] = string_format("%d:", info.currentline)
          end
          if info.namewhat ~= "" then
            line[#line+1] = string_format(" in function '%s'", info.name)
          else
            if info.what == "main" then
              line[#line+1] = " in main chunk"
            elseif info.what == "C" or info.what == "tail" then
              line[#line+1] = " ?"
            else
              line[#line+1] = string_format(" in function <%s:%d>", info.short_src, info.linedefined)
            end
          end
          errobj.tb[#errobj.tb+1] = table_concat(line)
        end
        level = level + 1
      end
    end
    return errobj
  end

  -- Calls `func` with no arguments. Returns nothing on success, or the
  -- error object produced by my_traceback on error.
  function mypcall(func)
    orig_assert( type(func) == "function" )
    local ok, errobj = xpcall(func, my_traceback)
    if not ok then
      return errobj
    end
  end
  traceback_hide(mypcall)
end
135
136
137-- Type check functions
138
-- Generate lunit.is_nil, lunit.is_boolean, ...: each returns true exactly
-- when its argument has the corresponding Lua type.
for idx = 1, #typenames do
  local tname = typenames[idx]
  lunit["is_" .. tname] = function(value)
    return type(value) == tname
  end
end

-- Local aliases used throughout the rest of the file.
local is_nil, is_boolean, is_number, is_string =
      lunit.is_nil, lunit.is_boolean, lunit.is_number, lunit.is_string
local is_table, is_function, is_thread, is_userdata =
      lunit.is_table, lunit.is_function, lunit.is_thread, lunit.is_userdata
153
154
-- Raises an lunit failure object.
--   name       : name of the assertion that failed
--   usermsg    : optional message supplied by the test
--   defaultmsg : string.format template, filled from the remaining arguments
-- The location of the failing assertion is attached later by mypcall's
-- error handler (field `where`).
local function failure(name, usermsg, defaultmsg, ...)
  local formatted = string_format(defaultmsg, ...)
  error({
    type    = __failure__,
    name    = name,
    msg     = formatted,
    usermsg = usermsg,
  }, 0)
end
traceback_hide( failure )
165
166
-- Renders a value for a failure message: strings are single-quoted,
-- numbers/booleans/nil are shown bare, anything else is bracketed.
local function format_arg(arg)
  local kind = type(arg)
  if kind == "string" then
    return "'" .. arg .. "'"
  end
  if kind == "number" or kind == "boolean" or kind == "nil" then
    return tostring(arg)
  end
  return "[" .. tostring(arg) .. "]"
end
177
178
-- Decides whether the test called `name` should run.
--   map : nil (run everything), or a table whose array part holds lunit
--         wildcard patterns ("*" = any run of characters, "?" = any single
--         character, everything else literal) and which may also map exact
--         test names to true.
-- Returns true when `name` is listed with true or any pattern matches
-- somewhere in it (patterns are not anchored), false otherwise.
--
-- The wildcard translation lives here because lunitpat2luapat and
-- in_patternmap are locals declared further down the file. They are not
-- visible at this point, so a reference would resolve to a nil module field.
local selected
do
  local wildcard_to_lua = {
    ["^"] = "%^", ["$"] = "%$", ["("] = "%(", [")"] = "%)",
    ["%"] = "%%", ["."] = "%.", ["["] = "%[", ["]"] = "%]",
    ["+"] = "%+", ["-"] = "%-",
    ["?"] = ".",  ["*"] = ".*",
  }

  function selected(map, name)
    if not map then
      return true
    end
    if map[name] == true then
      return true
    end
    for _, pattern in ipairs(map) do
      -- Parenthesized: gsub also returns a count, which string.find would
      -- otherwise receive as its `init` argument.
      local luapat = (string_gsub(pattern, "%W", wildcard_to_lua))
      if string_find(name, luapat) then
        return true
      end
    end
    return false
  end
end
190
191
-- Unconditionally fails the running test with the optional message `msg`.
-- Counts as one assertion.
function fail(msg)
  local counters = stats
  counters.assertions = counters.assertions + 1
  failure( "fail", msg, "failure" )
end
traceback_hide( fail )
197
198
-- lunit's assert: fails the test when `assertion` is nil or false and
-- otherwise returns it unchanged.
function assert(assertion, msg)
  stats.assertions = stats.assertions + 1
  if assertion then
    return assertion
  end
  failure( "assert", msg, "assertion failed" )
end
traceback_hide( assert )
207
208
-- Passes only for the boolean true itself (other truthy values fail).
-- Returns `actual`.
function assert_true(actual, msg)
  stats.assertions = stats.assertions + 1
  if actual == true then
    return actual
  end
  failure( "assert_true", msg, "true expected but was %s", format_arg(actual) )
end
traceback_hide( assert_true )
217
218
-- Passes only for the boolean false itself (nil fails). Returns `actual`.
function assert_false(actual, msg)
  stats.assertions = stats.assertions + 1
  if actual == false then
    return actual
  end
  failure( "assert_false", msg, "false expected but was %s", format_arg(actual) )
end
traceback_hide( assert_false )
227
228
-- Fails unless `expected == actual` (Lua's == operator: tables compare by
-- identity unless an __eq metamethod applies). Returns `actual`.
function assert_equal(expected, actual, msg)
  stats.assertions = stats.assertions + 1
  if expected == actual then
    return actual
  end
  failure( "assert_equal", msg, "expected %s but was %s", format_arg(expected), format_arg(actual) )
end
traceback_hide( assert_equal )
237
238
-- Fails when `unexpected == actual`. Returns `actual`.
function assert_not_equal(unexpected, actual, msg)
  stats.assertions = stats.assertions + 1
  if unexpected ~= actual then
    return actual
  end
  failure( "assert_not_equal", msg, "%s not expected but was one", format_arg(unexpected) )
end
traceback_hide( assert_not_equal )
247
248
-- Fails unless `actual` is a string in which the Lua pattern `pattern`
-- matches (string.find semantics, unanchored). Returns `actual`.
function assert_match(pattern, actual, msg)
  local me = "assert_match"
  stats.assertions = stats.assertions + 1
  if not is_string(pattern) then
    failure( me, msg, "expected a string as pattern but was %s", format_arg(pattern) )
  end
  if not is_string(actual) then
    failure( me, msg, "expected a string to match pattern '%s' but was a %s", pattern, format_arg(actual) )
  end
  local found = string_find(actual, pattern)
  if not found then
    failure( me, msg, "expected '%s' to match pattern '%s' but doesn't", actual, pattern )
  end
  return actual
end
traceback_hide( assert_match )
263
264
-- Fails unless `actual` is a string in which the Lua pattern `pattern`
-- does NOT match anywhere. Returns `actual`.
function assert_not_match(pattern, actual, msg)
  local me = "assert_not_match"
  stats.assertions = stats.assertions + 1
  if not is_string(pattern) then
    failure( me, msg, "expected a string as pattern but was %s", format_arg(pattern) )
  end
  if not is_string(actual) then
    failure( me, msg, "expected a string to not match pattern '%s' but was %s", pattern, format_arg(actual) )
  end
  local found = string_find(actual, pattern)
  if found then
    failure( me, msg, "expected '%s' to not match pattern '%s' but it does", actual, pattern )
  end
  return actual
end
traceback_hide( assert_not_match )
279
280
-- Fails unless calling `func` raises an error. Accepts (func) or
-- (msg, func); the error value itself is not inspected.
function assert_error(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    -- Single-argument form: the only argument is the function.
    msg, func = nil, msg
  end
  if not is_function(func) then
    failure( "assert_error", msg, "expected a function as last argument but was %s", format_arg(func) )
  end
  if pcall(func) then
    failure( "assert_error", msg, "error expected but no error occurred" )
  end
end
traceback_hide( assert_error )
295
296
-- Fails unless calling `func` raises a string error in which the Lua
-- pattern `pattern` matches. Accepts (pattern, func) or (msg, pattern, func).
function assert_error_match(msg, pattern, func)
  local me = "assert_error_match"
  stats.assertions = stats.assertions + 1
  if func == nil then
    -- Two-argument form: shift the arguments one slot to the right.
    func, pattern, msg = pattern, msg, nil
  end
  if not is_string(pattern) then
    failure( me, msg, "expected the pattern as a string but was %s", format_arg(pattern) )
  end
  if not is_function(func) then
    failure( me, msg, "expected a function as last argument but was %s", format_arg(func) )
  end
  local succeeded, err = pcall(func)
  if succeeded then
    failure( me, msg, "error expected but no error occurred" )
  end
  if not is_string(err) then
    failure( me, msg, "error as string expected but was %s", format_arg(err) )
  end
  if not string_find(err, pattern) then
    failure( me, msg, "expected error '%s' to match pattern '%s' but doesn't", err, pattern )
  end
end
traceback_hide( assert_error_match )
320
321
-- Fails if calling `func` raises an error. Accepts (func) or (msg, func).
function assert_pass(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    func, msg = msg, nil
  end
  if type(func) ~= "function" then
    failure( "assert_pass", msg, "expected a function as last argument but was %s", format_arg(func) )
  end
  local ok, errmsg = pcall(func)
  if not ok then
    -- A nested lunit assertion raises a failure object; show its message.
    if is_table(errmsg) and errmsg.type == __failure__ then
      errmsg = errmsg.msg
    end
    -- tostring(): the error value may be any Lua value (nil, a table, ...)
    -- and Lua 5.1's string.format("%s") accepts only strings and numbers,
    -- so passing it raw would replace this failure with a format error.
    failure( "assert_pass", msg, "no error expected but error was: '%s'", tostring(errmsg) )
  end
end
traceback_hide( assert_pass )
336
337
338-- lunit.assert_typename functions
339
-- Generates lunit.assert_nil, lunit.assert_boolean, ...: each fails unless
-- its argument has the named Lua type and returns the argument otherwise.
for idx = 1, #typenames do
  local tname = typenames[idx]
  local fname = "assert_" .. tname
  local function check(actual, msg)
    stats.assertions = stats.assertions + 1
    if type(actual) == tname then
      return actual
    end
    failure( fname, msg, "%s expected but was %s", tname, format_arg(actual) )
  end
  lunit[fname] = check
  traceback_hide( check )
end
351
352
353-- lunit.assert_not_typename functions
354
-- Generates lunit.assert_not_nil, lunit.assert_not_boolean, ...: each fails
-- when its argument has the named Lua type. Otherwise it returns the
-- argument, consistent with assert_<type> and assert_not_equal
-- (e.g. `local f = assert_not_nil(io.open(path))`).
for _, typename in ipairs(typenames) do
  local assert_not_typename = "assert_not_"..typename
  lunit[assert_not_typename] = function(actual, msg)
    stats.assertions = stats.assertions + 1
    if type(actual) == typename then
      -- Pass the type name as data rather than splicing it into the template.
      failure( assert_not_typename, msg, "%s not expected but was one", typename )
    end
    return actual
  end
  traceback_hide( lunit[assert_not_typename] )
end
365
366
-- Installs a fresh lunit.stats table with the counters assertions, passed,
-- failed and errors all set to zero.
function lunit.clearstats()
  local counters = {}
  for _, field in ipairs{ "assertions", "passed", "failed", "errors" } do
    counters[field] = 0
  end
  stats = counters
end
375
376
local report, reporterrobj
do
  -- The active test runner: a table whose optional function fields receive
  -- the events reported below ("begin", "run", "pass", "fail", "err", "done"),
  -- or nil.
  local testrunner

  -- Installs `newrunner` (a table or nil) and returns the previous runner.
  function lunit.setrunner(newrunner)
    if not ( is_table(newrunner) or is_nil(newrunner) ) then
      return error("lunit.setrunner: Invalid argument", 0)
    end
    local oldrunner = testrunner
    testrunner = newrunner
    return oldrunner
  end

  -- require()s module `name` and installs its result as the runner.
  function lunit.loadrunner(name)
    if not is_string(name) then
      return error("lunit.loadrunner: Invalid argument", 0)
    end
    local ok, runner = pcall( require, name )
    if not ok then
      -- tostring(): a runner module may raise a non-string error while
      -- loading; concatenating it directly would raise an unrelated
      -- "attempt to concatenate" error and hide the real cause.
      return error("lunit.loadrunner: Can't load test runner: "..tostring(runner), 0)
    end
    return setrunner(runner)
  end

  -- Returns the currently installed runner (or nil).
  function lunit.getrunner()
    return testrunner
  end

  -- Calls the runner's handler for `event`, if it has one, with the
  -- remaining arguments. Errors raised by the handler are caught by pcall
  -- and ignored.
  function report(event, ...)
    local f = testrunner and testrunner[event]
    if is_function(f) then
      pcall(f, ...)
    end
  end

  -- Counts and reports an error object returned by mypcall for one phase
  -- (`context`: "setup", "test" or "teardown") of test `testname` in
  -- testcase `tcname`. Failure objects go to "fail", anything else to "err".
  function reporterrobj(context, tcname, testname, errobj)
    local fullname = tcname .. "." .. testname
    if context == "setup" then
      fullname = fullname .. ":" .. setupname(tcname, testname)
    elseif context == "teardown" then
      fullname = fullname .. ":" .. teardownname(tcname, testname)
    end
    if errobj.type == __failure__ then
      stats.failed = stats.failed + 1
      report("fail", fullname, errobj.where, errobj.msg, errobj.usermsg)
    else
      stats.errors = stats.errors + 1
      report("err", fullname, errobj.msg, errobj.tb)
    end
  end
end
428
429
430
-- Stateless iterator function that yields only the keys of `t` (next()
-- alone would also yield the values).
local function key_iter(t, k)
  local following = next(t, k)
  return following
end
434
435
local testcase
do
  -- Registered testcases, keyed by their _NAME.
  local registry = {}

  -- Marks module table `m` as a testcase: registers it under m._NAME and
  -- copies `lunit`, `fail` and every assert*/is_* function into it so test
  -- code can call them unqualified. In Lua 5.1 it can be passed to
  -- module(): module("xyz", lunit.testcase).
  function lunit.testcase(m)
    orig_assert( is_table(m) )
    orig_assert( is_string(m._NAME) )

    registry[m._NAME] = m

    m.lunit = lunit
    m.fail = lunit.fail
    for funcname, func in pairs(lunit) do
      local exported = string_sub(funcname, 1, 6) == "assert"
                    or string_sub(funcname, 1, 3) == "is_"
      if exported then
        m[funcname] = func
      end
    end
  end

  -- Creates, registers and returns a new testcase table named `name`.
  -- With seeall == "seeall", missing names fall back to _G.
  function lunit.module(name, seeall)
    local m
    if seeall == "seeall" then
      m = setmetatable({}, { __index = _G })
    else
      m = {}
    end
    m._NAME = name
    lunit.testcase(m)
    return m
  end

  -- Iterator over the names of all registered testcases. Walks a snapshot
  -- so testcases defined during the iteration do not confuse it.
  function lunit.testcases()
    local snapshot = {}
    for name in pairs(registry) do
      snapshot[name] = true
    end
    return key_iter, snapshot, nil
  end

  -- Returns the testcase table registered as `tcname`, or nil.
  function testcase(tcname)
    return registry[tcname]
  end
end
487
488
do
  -- Returns the key of a function-valued entry of testcase `tcname` whose
  -- lower-cased name equals `name` (case-insensitive lookup).
  local function findfuncname(tcname, name)
    local tc = testcase(tcname)
    for key, value in pairs(tc) do
      if is_string(key) and is_function(value) and string_lower(key) == name then
        return key
      end
    end
  end

  -- Name of the testcase's setup function, if any.
  function lunit.setupname(tcname)
    return findfuncname(tcname, "setup")
  end

  -- Name of the testcase's teardown function, if any.
  function lunit.teardownname(tcname)
    return findfuncname(tcname, "teardown")
  end

  -- Iterator over the test names of testcase `tcname`: function-valued
  -- string keys whose lower-cased name starts or ends with "test". Names
  -- are collected up front because a test function may create a new
  -- global in the testcase table, which would disturb a live traversal.
  function lunit.tests(tcname)
    local found = {}
    for key, value in pairs(testcase(tcname)) do
      if is_string(key) and is_function(value) then
        local lowered = string_lower(key)
        if string_find(lowered, "^test") or string_find(lowered, "test$") then
          found[key] = true
        end
      end
    end
    return key_iter, found, nil
  end
end
523
524
525
526
-- Runs one test: the testcase's setup (if defined), the test function, then
-- teardown (if defined) whenever setup succeeded. Failures and errors are
-- reported to the runner and counted in lunit.stats. The test counts as
-- passed only if every phase succeeded. Loads the "console" runner when
-- none is installed.
function lunit.runtest(tcname, testname)
  orig_assert( is_string(tcname) )
  orig_assert( is_string(testname) )

  if (not getrunner()) then
    loadrunner("console")
  end

  -- Runs one phase under mypcall. A nil func (no setup/teardown defined)
  -- counts as success; an error is reported and yields false.
  local function callit(context, func)
    if func then
      local err = mypcall(func)
      if err then
        reporterrobj(context, tcname, testname, err)
        return false
      end
    end
    return true
  end
  traceback_hide(callit)

  report("run", tcname, testname)

  local tc   = testcase(tcname)
  local test = tc and tc[testname]

  -- callit() treats a missing function as success. That is right for the
  -- optional setup/teardown, but would report a nonexistent test as passed
  -- (or crash on an unknown testcase). Report it as an error instead.
  if not is_function(test) then
    reporterrobj("test", tcname, testname, {
      msg = string_format("lunit.runtest: no test function '%s' in testcase '%s'", testname, tcname),
      tb  = {},
    })
    return
  end

  local setup       = tc[setupname(tcname)]
  local teardown    = tc[teardownname(tcname)]

  local setup_ok    =              callit( "setup", setup )
  local test_ok     = setup_ok and callit( "test", test )
  local teardown_ok = setup_ok and callit( "teardown", teardown )

  if setup_ok and test_ok and teardown_ok then
    stats.passed = stats.passed + 1
    report("pass", tcname, testname)
  end
end
traceback_hide(runtest)
564
565
566
-- Runs every registered test whose name is accepted by selected() for
-- `testpatterns` (nil = all tests) and returns lunit.stats. The runner gets
-- "begin" before the first test and "done" after the last.
function lunit.run(testpatterns)
  lunit.clearstats()
  report("begin")
  for tcname in lunit.testcases() do
    for testname in lunit.tests(tcname) do
      if selected(testpatterns, testname) then
        lunit.runtest(tcname, testname)
      end
    end
  end
  report("done")
  return lunit.stats
end
traceback_hide(run)
582
583
-- Reports an empty run ("begin" then "done") without executing any tests,
-- as used by main() for --loadonly. Returns the freshly reset lunit.stats.
function lunit.loadonly()
  lunit.clearstats()
  for _, event in ipairs{ "begin", "done" } do
    report(event)
  end
  return lunit.stats
end
590
591
592
593
594
595
596
597
598
local lunitpat2luapat
do
  -- Replacement for each non-alphanumeric character of an lunit pattern:
  -- the wildcards "?" and "*" become "." and ".*", the remaining Lua magic
  -- characters are escaped with "%". Unlisted characters (e.g. "_") are
  -- kept as they are.
  local replacements = {}
  for _, magic in ipairs{ "^", "$", "(", ")", "%", ".", "[", "]", "+", "-" } do
    replacements[magic] = "%" .. magic
  end
  replacements["?"] = "."
  replacements["*"] = ".*"

  -- Converts lunit pattern `str` to an unanchored Lua pattern. Like
  -- string.gsub, also returns the match count as a second value.
  function lunitpat2luapat(str)
    -- Deliberately not wrapped in "^...$": "-t rss" should select every
    -- test whose name contains "rss" without typing "-t *rss*".
    return string_gsub(str, "%W", replacements)
  end
end
622
623
624
-- Returns true if `name` is a key of `map` with value true, or if any Lua
-- pattern in the array part of `map` matches somewhere in `name`;
-- false otherwise.
local function in_patternmap(map, name)
  local listed = (map[name] == true)
  if not listed then
    for _, luapat in ipairs(map) do
      if string_find(name, luapat) then
        listed = true
        break
      end
    end
  end
  return listed
end
637
638
639
640
641
642
643
644
645-- Called from 'lunit' shell script.
646
-- Command-line entry point. `argv` is a list of argument strings: options
-- (see the help text) and test scripts, which are loaded in order. Then
-- either the selected tests are run (returns lunit.stats) or, with
-- --loadonly, only "begin"/"done" are reported.
function main(argv)
  argv = argv or {}

  -- This chunk's environment is the lunit table, so standard globals such
  -- as print are not visible by name; take it from the real _G.
  local print = _G.print

  -- FIXME: Error handling and error messages aren't nice.

  local function checkarg(optname, arg)
    if not is_string(arg) then
      return error("lunit.main: option "..optname..": argument missing.", 0)
    end
  end

  -- Loads and runs the test script `filename` (which registers testcases).
  local function loadtestcase(filename)
    if not is_string(filename) then
      return error("lunit.main: invalid argument", 0)
    end
    local chunk, err = loadfile(filename)
    if not chunk then
      -- Level 0: err already names the file and line; don't prefix
      -- lunit's own position.
      return error(err, 0)
    end
    chunk()
  end

  local testpatterns = nil
  local doloadonly = false

  local i = 0
  while i < #argv do
    i = i + 1
    local arg = argv[i]
    if arg == "--loadonly" then
      doloadonly = true
    elseif arg == "--runner" or arg == "-r" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      loadrunner(arg)
    elseif arg == "--test" or arg == "-t" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      testpatterns = testpatterns or {}
      testpatterns[#testpatterns+1] = arg
    elseif arg == "--help" or arg == "-h" then
      print[[
lunit 0.5
Copyright (c) 2004-2009 Michael Roth <[email protected]>
This program comes WITHOUT WARRANTY OF ANY KIND.

Usage: lua test [OPTIONS] [--] scripts

Options:

  -r, --runner RUNNER         Testrunner to use, defaults to 'lunit-console'.
  -t, --test PATTERN          Which tests to run, may contain * or ? wildcards.
      --loadonly              Only load the tests.
  -h, --help                  Print this help screen.

Please report bugs to <[email protected]>.
]]
      return
    elseif arg == "--" then
      -- Everything after "--" is a test script, even if it looks like an option.
      while i < #argv do
        i = i + 1; arg = argv[i]
        loadtestcase(arg)
      end
    else
      loadtestcase(arg)
    end
  end

  if doloadonly then
    return loadonly()
  else
    return run(testpatterns)
  end
end
722
-- Start with zeroed counters so assertions can update lunit.stats even
-- before the first lunit.run().
lunit.clearstats()

return lunit
726