--[[--------------------------------------------------------------------------

    This file is part of lunit 0.5.

    For Details about lunit look at: http://www.mroth.net/lunit/

    Author: Michael Roth <[email protected]>

    Copyright (c) 2004, 2006-2010 Michael Roth <[email protected]>

    Permission is hereby granted, free of charge, to any person
    obtaining a copy of this software and associated documentation
    files (the "Software"), to deal in the Software without restriction,
    including without limitation the rights to use, copy, modify, merge,
    publish, distribute, sublicense, and/or sell copies of the Software,
    and to permit persons to whom the Software is furnished to do so,
    subject to the following conditions:

    The above copyright notice and this permission notice shall be
    included in all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
    IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
    CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
    TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
    SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--]]--------------------------------------------------------------------------


-- Everything needed from the global environment must be captured in locals
-- *before* the environment is switched to the lunit table below; afterwards
-- a bare global name resolves to a field of that table.

local orig_assert     = assert

local pairs           = pairs
local ipairs          = ipairs
local next            = next
local type            = type
local error           = error
local tostring        = tostring
local setmetatable    = setmetatable
local pcall           = pcall
local xpcall          = xpcall
local require         = require
local loadfile        = loadfile
local print           = print     -- used by main() for the --help screen

local string_sub      = string.sub
local string_gsub     = string.gsub
local string_format   = string.format
local string_lower    = string.lower
local string_find     = string.find

local table_concat    = table.concat

local debug_getinfo   = debug.getinfo

local _G = _G

local lunit

if _VERSION >= 'Lua 5.2' then

    -- Lua 5.2+: make the lunit table the environment of this chunk.
    lunit = {}
    _ENV = lunit

else

    -- Lua 5.1: module() creates the table and installs it as environment.
    module("lunit")
    lunit = _M

end


local __failure__ = {}    -- Type tag for failed assertions

local typenames = { "nil", "boolean", "number", "string", "table", "function", "thread", "userdata" }


local traceback_hide      -- Traceback function which hides lunit internals
local mypcall             -- Protected call to a function with own traceback
do
  -- Functions registered here are left out of generated tracebacks.
  -- Weak keys, so registering a function does not keep it alive.
  local _tb_hide = setmetatable( {}, {__mode="k"} )

  -- Marks func as a lunit internal that tracebacks should skip.
  function traceback_hide(func)
    _tb_hide[func] = true
  end

  -- Message handler for xpcall. For assertion failures it records the
  -- source position in errobj.where; any other error value is wrapped in a
  -- table { msg = <string>, tb = <array of traceback lines> }.
  -- Note: is_table is looked up in the module environment at call time; the
  -- is_* functions are defined further down, before any test can run.
  local function my_traceback(errobj)
    if is_table(errobj) and errobj.type == __failure__ then
      local info = debug_getinfo(5, "Sl")   -- FIXME: Hardcoded integers are bad...
      errobj.where = string_format( "%s:%d", info.short_src, info.currentline)
    else
      errobj = { msg = tostring(errobj) }
      errobj.tb = {}
      local i = 2
      while true do
        local info = debug_getinfo(i, "Snlf")
        if not is_table(info) then
          break
        end
        if not _tb_hide[info.func] then
          local line = {}       -- Ripped from ldblib.c...
          line[#line+1] = string_format("%s:", info.short_src)
          if info.currentline > 0 then
            line[#line+1] = string_format("%d:", info.currentline)
          end
          if info.namewhat ~= "" then
            line[#line+1] = string_format(" in function '%s'", info.name)
          else
            if info.what == "main" then
              line[#line+1] = " in main chunk"
            elseif info.what == "C" or info.what == "tail" then
              line[#line+1] = " ?"
            else
              line[#line+1] = string_format(" in function <%s:%d>", info.short_src, info.linedefined)
            end
          end
          errobj.tb[#errobj.tb+1] = table_concat(line)
        end
        i = i + 1
      end
    end
    return errobj
  end

  -- Calls func in protected mode. Returns nothing on success and the error
  -- object produced by my_traceback on failure.
  function mypcall(func)
    orig_assert( is_function(func) )
    local ok, errobj = xpcall(func, my_traceback)
    if not ok then
      return errobj
    end
  end
  traceback_hide(mypcall)
end


-- Type check functions: lunit.is_nil, lunit.is_boolean, ...

for _, typename in ipairs(typenames) do
  lunit["is_"..typename] = function(x)
    return type(x) == typename
  end
end

local is_nil      = is_nil
local is_boolean  = is_boolean
local is_number   = is_number
local is_string   = is_string
local is_table    = is_table
local is_function = is_function
local is_thread   = is_thread
local is_userdata = is_userdata


-- Raises an assertion failure.
--   name        name of the failing assertion function
--   usermsg     optional message supplied by the test author
--   defaultmsg  string.format template for the standard message
--   ...         arguments for the template
local function failure(name, usermsg, defaultmsg, ...)
  local errobj = {
    type    = __failure__,
    name    = name,
    msg     = string_format(defaultmsg,...),
    usermsg = usermsg
  }
  error(errobj, 0)
end
traceback_hide( failure )


-- Renders a value for use in failure messages: strings are quoted, numbers,
-- booleans and nil are printed plainly, everything else is bracketed.
local function format_arg(arg)
  local argtype = type(arg)
  if argtype == "string" then
    return "'"..arg.."'"
  elseif argtype == "number" or argtype == "boolean" or argtype == "nil" then
    return tostring(arg)
  else
    return "["..tostring(arg).."]"
  end
end


-- Converts a lunit test pattern ('?' = one character, '*' = any run of
-- characters) into a Lua pattern; all other magic characters are escaped.
-- The result is intentionally not anchored, so "rss" matches "test_rss_feed".
local lunitpat2luapat
do
  local conv = {
    ["^"] = "%^",
    ["$"] = "%$",
    ["("] = "%(",
    [")"] = "%)",
    ["%"] = "%%",
    ["."] = "%.",
    ["["] = "%[",
    ["]"] = "%]",
    ["+"] = "%+",
    ["-"] = "%-",
    ["?"] = ".",
    ["*"] = ".*"
  }
  function lunitpat2luapat(str)
    --return "^" .. string.gsub(str, "%W", conv) .. "$"
    -- Above was very annoying, if I want to run all the tests having to do with
    -- RSS, I want to be able to do "-t rss" not "-t \*rss\*".
    -- The parentheses drop gsub's second result (the substitution count).
    return (string_gsub(str, "%W", conv))
  end
end


-- Returns true if map[name] is exactly true or if name matches one of the
-- Lua patterns stored in the array part of map.
local function in_patternmap(map, name)
  if map[name] == true then
    return true
  else
    for _, pat in ipairs(map) do
      if string_find(name, pat) then
        return true
      end
    end
  end
  return false
end


-- Decides whether the test called name should run.
--   map   nil/false selects every test; otherwise a table of lunit patterns
--         (array part) and/or name = true entries.
-- The pattern helpers above must be defined before this function: a local
-- declared later in the file would not be visible here and the name would
-- resolve to a (nil) field of the module environment instead.
local function selected(map, name)
  if not map then
    return true
  end

  local m = {}
  for k,v in pairs(map) do
    if type(v) == "string" then
      m[k] = lunitpat2luapat(v)
    else
      m[k] = v      -- e.g. name = true entries are passed through untouched
    end
  end
  return in_patternmap(m, name)
end


-- Unconditionally fails the running test with the optional message msg.
function fail(msg)
  stats.assertions = stats.assertions + 1
  failure( "fail", msg, "failure" )
end
traceback_hide( fail )


-- Fails unless assertion is truthy; returns assertion.
function assert(assertion, msg)
  stats.assertions = stats.assertions + 1
  if not assertion then
    failure( "assert", msg, "assertion failed" )
  end
  return assertion
end
traceback_hide( assert )


-- Fails unless actual is exactly true; returns actual.
function assert_true(actual, msg)
  stats.assertions = stats.assertions + 1
  if actual ~= true then
    failure( "assert_true", msg, "true expected but was %s", format_arg(actual) )
  end
  return actual
end
traceback_hide( assert_true )


-- Fails unless actual is exactly false; returns actual.
function assert_false(actual, msg)
  stats.assertions = stats.assertions + 1
  if actual ~= false then
    failure( "assert_false", msg, "false expected but was %s", format_arg(actual) )
  end
  return actual
end
traceback_hide( assert_false )


-- Fails unless expected == actual; returns actual.
function assert_equal(expected, actual, msg)
  stats.assertions = stats.assertions + 1
  if expected ~= actual then
    failure( "assert_equal", msg, "expected %s but was %s", format_arg(expected), format_arg(actual) )
  end
  return actual
end
traceback_hide( assert_equal )


-- Fails if unexpected == actual; returns actual.
function assert_not_equal(unexpected, actual, msg)
  stats.assertions = stats.assertions + 1
  if unexpected == actual then
    failure( "assert_not_equal", msg, "%s not expected but was one", format_arg(unexpected) )
  end
  return actual
end
traceback_hide( assert_not_equal )


-- Fails unless the string actual matches the Lua pattern; returns actual.
function assert_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  if type(pattern) ~= "string" then
    failure( "assert_match", msg, "expected a string as pattern but was %s", format_arg(pattern) )
  end
  if type(actual) ~= "string" then
    failure( "assert_match", msg, "expected a string to match pattern '%s' but was a %s", pattern, format_arg(actual) )
  end
  if not string_find(actual, pattern) then
    failure( "assert_match", msg, "expected '%s' to match pattern '%s' but doesn't", actual, pattern )
  end
  return actual
end
traceback_hide( assert_match )


-- Fails if the string actual matches the Lua pattern; returns actual.
function assert_not_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  if type(pattern) ~= "string" then
    failure( "assert_not_match", msg, "expected a string as pattern but was %s", format_arg(pattern) )
  end
  if type(actual) ~= "string" then
    failure( "assert_not_match", msg, "expected a string to not match pattern '%s' but was %s", pattern, format_arg(actual) )
  end
  if string_find(actual, pattern) then
    failure( "assert_not_match", msg, "expected '%s' to not match pattern '%s' but it does", actual, pattern )
  end
  return actual
end
traceback_hide( assert_not_match )


-- Fails unless calling func raises an error.
-- May be called as assert_error(func) or assert_error(msg, func).
function assert_error(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    func, msg = msg, nil
  end
  if type(func) ~= "function" then
    failure( "assert_error", msg, "expected a function as last argument but was %s", format_arg(func) )
  end
  local ok = pcall(func)
  if ok then
    failure( "assert_error", msg, "error expected but no error occurred" )
  end
end
traceback_hide( assert_error )


-- Fails unless calling func raises a string error matching pattern.
-- May be called as assert_error_match(pattern, func) or
-- assert_error_match(msg, pattern, func).
function assert_error_match(msg, pattern, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    msg, pattern, func = nil, msg, pattern
  end
  if type(pattern) ~= "string" then
    failure( "assert_error_match", msg, "expected the pattern as a string but was %s", format_arg(pattern) )
  end
  if type(func) ~= "function" then
    failure( "assert_error_match", msg, "expected a function as last argument but was %s", format_arg(func) )
  end
  local ok, errmsg = pcall(func)
  if ok then
    failure( "assert_error_match", msg, "error expected but no error occurred" )
  end
  if type(errmsg) ~= "string" then
    failure( "assert_error_match", msg, "error as string expected but was %s", format_arg(errmsg) )
  end
  if not string_find(errmsg, pattern) then
    failure( "assert_error_match", msg, "expected error '%s' to match pattern '%s' but doesn't", errmsg, pattern )
  end
end
traceback_hide( assert_error_match )


-- Fails if calling func raises an error.
-- May be called as assert_pass(func) or assert_pass(msg, func).
function assert_pass(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    func, msg = msg, nil
  end
  if type(func) ~= "function" then
    failure( "assert_pass", msg, "expected a function as last argument but was %s", format_arg(func) )
  end
  local ok, errmsg = pcall(func)
  if not ok then
    -- The error value may be any Lua value; Lua 5.1's %s accepts only
    -- strings and numbers, so convert explicitly.
    failure( "assert_pass", msg, "no error expected but error was: '%s'", tostring(errmsg) )
  end
end
traceback_hide( assert_pass )


-- lunit.assert_typename functions
-- Each fails unless type(actual) is the named type; returns actual.

for _, typename in ipairs(typenames) do
  local assert_typename = "assert_"..typename
  lunit[assert_typename] = function(actual, msg)
    stats.assertions = stats.assertions + 1
    if type(actual) ~= typename then
      failure( assert_typename, msg, "%s expected but was %s", typename, format_arg(actual) )
    end
    return actual
  end
  traceback_hide( lunit[assert_typename] )
end


-- lunit.assert_not_typename functions
-- Each fails if type(actual) is the named type; returns actual, like the
-- other value assertions.

for _, typename in ipairs(typenames) do
  local assert_not_typename = "assert_not_"..typename
  lunit[assert_not_typename] = function(actual, msg)
    stats.assertions = stats.assertions + 1
    if type(actual) == typename then
      failure( assert_not_typename, msg, typename.." not expected but was one" )
    end
    return actual
  end
  traceback_hide( lunit[assert_not_typename] )
end


-- Resets the counters in lunit.stats to zero.
function lunit.clearstats()
  stats = {
    assertions  = 0;
    passed      = 0;
    failed      = 0;
    errors      = 0;
  }
end


local report, reporterrobj
do
  -- Currently installed test runner: a table of event handlers, or nil.
  local testrunner

  -- Installs newrunner (a table or nil) and returns the previous runner.
  function lunit.setrunner(newrunner)
    if not ( is_table(newrunner) or is_nil(newrunner) ) then
      return error("lunit.setrunner: Invalid argument", 0)
    end
    local oldrunner = testrunner
    testrunner = newrunner
    return oldrunner
  end

  -- Loads the runner module called name via require and installs it.
  -- Returns the previously installed runner.
  function lunit.loadrunner(name)
    if not is_string(name) then
      return error("lunit.loadrunner: Invalid argument", 0)
    end
    local ok, runner = pcall( require, name )
    if not ok then
      return error("lunit.loadrunner: Can't load test runner: "..tostring(runner), 0)
    end
    return setrunner(runner)
  end

  -- Returns the currently installed runner (may be nil).
  function lunit.getrunner()
    return testrunner
  end

  -- Forwards event and its arguments to the runner's handler of that name,
  -- if there is one. Errors raised by the handler are deliberately ignored.
  function report(event, ...)
    local f = testrunner and testrunner[event]
    if is_function(f) then
      pcall(f, ...)
    end
  end

  -- Updates the statistics for a failed or errored call and reports it.
  --   context   "setup", "test" or "teardown"
  --   errobj    error object as produced by mypcall
  function reporterrobj(context, tcname, testname, errobj)
    local fullname = tcname .. "." .. testname
    if context == "setup" then
      fullname = fullname .. ":" .. setupname(tcname, testname)
    elseif context == "teardown" then
      fullname = fullname .. ":" .. teardownname(tcname, testname)
    end
    if errobj.type == __failure__ then
      stats.failed = stats.failed + 1
      report("fail", fullname, errobj.where, errobj.msg, errobj.usermsg)
    else
      stats.errors = stats.errors + 1
      report("err", fullname, errobj.msg, errobj.tb)
    end
  end
end



-- Stateless iterator function yielding only the keys of t.
local function key_iter(t, k)
    return (next(t,k))
end


local testcase
do
  -- Map of all registered testcases, keyed by name
  local _testcases = {}

  -- Marks a module as a testcase.
  -- Applied over a module from module("xyz", lunit.testcase).
  function lunit.testcase(m)
    orig_assert( is_table(m) )
    --orig_assert( m._M == m )
    orig_assert( is_string(m._NAME) )
    --orig_assert( is_string(m._PACKAGE) )

    -- Register the module as a testcase
    _testcases[m._NAME] = m

    -- Import lunit, fail, assert* and is_* function to the module/testcase
    m.lunit = lunit
    m.fail = lunit.fail
    for funcname, func in pairs(lunit) do
      if "assert" == string_sub(funcname, 1, 6) or "is_" == string_sub(funcname, 1, 3) then
        m[funcname] = func
      end
    end
  end

  -- Creates, registers and returns a new testcase table called name.
  -- With seeall == "seeall" the table falls back to the global environment.
  function lunit.module(name,seeall)
    local m = {}
    if seeall == "seeall" then
      setmetatable(m, { __index = _G })
    end
    m._NAME = name
    lunit.testcase(m)
    return m
  end

  -- Iterator (testcasename) over all Testcases
  function lunit.testcases()
    -- Make a copy of testcases to prevent confusing the iterator when
    -- new testcase are defined
    local _testcases2 = {}
    for k,v in pairs(_testcases) do
        _testcases2[k] = true
    end
    return key_iter, _testcases2, nil
  end

  -- Returns the testcase table registered under tcname, or nil.
  function testcase(tcname)
    return _testcases[tcname]
  end
end


do
  -- Finds a function in a testcase case insensitive
  local function findfuncname(tcname, name)
    for key, value in pairs(testcase(tcname)) do
      if is_string(key) and is_function(value) and string_lower(key) == name then
        return key
      end
    end
  end

  -- Returns the actual key of the testcase's setup function, or nil.
  function lunit.setupname(tcname)
    return findfuncname(tcname, "setup")
  end

  -- Returns the actual key of the testcase's teardown function, or nil.
  function lunit.teardownname(tcname)
    return findfuncname(tcname, "teardown")
  end

  -- Iterator over all test names in a testcase.
  -- Have to collect the names first in case one of the test
  -- functions creates a new global and throws off the iteration.
  function lunit.tests(tcname)
    local testnames = {}
    for key, value in pairs(testcase(tcname)) do
      if is_string(key) and is_function(value) then
        local lfn = string_lower(key)
        if string_sub(lfn, 1, 4) == "test" or string_sub(lfn, -4) == "test" then
          testnames[key] = true
        end
      end
    end
    return key_iter, testnames, nil
  end
end




-- Runs one test: setup, the test itself, then teardown. The test and the
-- teardown only run when setup succeeded. Loads the "console" runner first
-- if no runner is installed.
function lunit.runtest(tcname, testname)
  orig_assert( is_string(tcname) )
  orig_assert( is_string(testname) )

  if (not getrunner()) then
    loadrunner("console")
  end

  -- Calls func (if any) protected; reports a problem and returns false on
  -- error, true otherwise.
  local function callit(context, func)
    if func then
      local err = mypcall(func)
      if err then
        reporterrobj(context, tcname, testname, err)
        return false
      end
    end
    return true
  end
  traceback_hide(callit)

  report("run", tcname, testname)

  local tc          = testcase(tcname)
  local setup       = tc[setupname(tcname)]
  local test        = tc[testname]
  local teardown    = tc[teardownname(tcname)]

  local setup_ok    =              callit( "setup", setup )
  local test_ok     = setup_ok and callit( "test", test )
  local teardown_ok = setup_ok and callit( "teardown", teardown )

  if setup_ok and test_ok and teardown_ok then
    stats.passed = stats.passed + 1
    report("pass", tcname, testname)
  end
end
traceback_hide(runtest)



-- Runs every registered test whose name is selected by testpatterns
-- (nil runs all tests) and returns the statistics table.
function lunit.run(testpatterns)
  clearstats()
  report("begin")
  for testcasename in lunit.testcases() do
    -- Run tests in the testcases
    for testname in lunit.tests(testcasename) do
      if selected(testpatterns, testname) then
        runtest(testcasename, testname)
      end
    end
  end
  report("done")
  return stats
end
traceback_hide(run)


-- Reports an empty run without executing any test; returns the statistics.
function lunit.loadonly()
  clearstats()
  report("begin")
  report("done")
  return stats
end




-- Called from 'lunit' shell script.
--   argv  array of command line arguments (defaults to an empty table)
-- Returns the statistics table, or nothing after printing the help screen.

function main(argv)
  argv = argv or {}

  -- FIXME: Error handling and error messages aren't nice.

  -- Raises an error unless the argument of option optname is present.
  local function checkarg(optname, arg)
    if not is_string(arg) then
      return error("lunit.main: option "..optname..": argument missing.", 0)
    end
  end

  -- Loads and executes a test script.
  local function loadtestcase(filename)
    if not is_string(filename) then
      return error("lunit.main: invalid argument")
    end
    local chunk, err = loadfile(filename)
    if err then
      return error(err)
    else
      chunk()
    end
  end

  local testpatterns = nil
  local doloadonly = false

  local i = 0
  while i < #argv do
    i = i + 1
    local arg = argv[i]
    if arg == "--loadonly" then
      doloadonly = true
    elseif arg == "--runner" or arg == "-r" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      loadrunner(arg)
    elseif arg == "--test" or arg == "-t" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      testpatterns = testpatterns or {}
      testpatterns[#testpatterns+1] = arg
    elseif arg == "--help" or arg == "-h" then
      print[[
lunit 0.5
Copyright (c) 2004-2009 Michael Roth <[email protected]>
This program comes WITHOUT WARRANTY OF ANY KIND.

Usage: lua test [OPTIONS] [--] scripts

Options:

  -r, --runner RUNNER         Testrunner to use, defaults to 'lunit-console'.
  -t, --test PATTERN          Which tests to run, may contain * or ? wildcards.
      --loadonly              Only load the tests.
  -h, --help                  Print this help screen.

Please report bugs to <[email protected]>.
]]
      return
    elseif arg == "--" then
      -- Everything after "--" is a script name, even if it looks like an option.
      while i < #argv do
        i = i + 1; arg = argv[i]
        loadtestcase(arg)
      end
    else
      loadtestcase(arg)
    end
  end

  if doloadonly then
    return loadonly()
  else
    return run(testpatterns)
  end
end

clearstats()

return lunit