diff --git a/bandit/core/utils.py b/bandit/core/utils.py index a8cbbbbc..e85ca191 100644 --- a/bandit/core/utils.py +++ b/bandit/core/utils.py @@ -297,3 +297,47 @@ def linerange_fix(node): if delta > 1: return range(start, node.sibling.lineno) return lines + + +def concat_string(node, stop=None): + '''Builds a string from a ast.BinOp chain. + + This will build a string from a series of ast.Str nodes wrapped in + ast.BinOp nodes. Somthing like "a" + "b" + "c" or "a %s" % val etc. + The provided node can be any participant in the BinOp chain. + + :param node: (ast.Str or ast.BinOp) The node to process + :param stop: (ast.Str or ast.BinOp) Optional base node to stop at + :returns: (Tuple) the root node of the expression, the string value + ''' + def _get(node, bits, stop=None): + if node != stop: + bits.append( + _get(node.left, bits, stop) + if isinstance(node.left, ast.BinOp) + else node.left) + bits.append( + _get(node.right, bits, stop) + if isinstance(node.right, ast.BinOp) + else node.right) + + bits = [node] + while isinstance(node.parent, ast.BinOp): + node = node.parent + if isinstance(node, ast.BinOp): + _get(node, bits, stop) + return (node, " ".join([x.s for x in bits if isinstance(x, ast.Str)])) + + +def get_called_name(node): + '''Get a function name from an ast.Call node. + + An ast.Call node representing a method call with present differently to one + wrapping a function call: thing.call() vs call(). This helper will grab the + unqualified call name correctly in either case. + + :param node: (ast.Call) the call node + :returns: (String) the function name + ''' + func = node.func + return (func.attr if isinstance(func, ast.Attribute) else func.id) diff --git a/bandit/plugins/injection_sql.py b/bandit/plugins/injection_sql.py index 0cd49d6a..e3d1fc8d 100644 --- a/bandit/plugins/injection_sql.py +++ b/bandit/plugins/injection_sql.py @@ -14,75 +14,40 @@ # License for the specific language governing permissions and limitations # under the License. +import ast + import bandit -from bandit.core.test_properties import * +from bandit.core import test_properties as test from bandit.core import utils -def _ast_build_string(data): - # used to return a string representation of AST data - - if isinstance(data, ast.Str): - # Already a string, just return the value - return utils.safe_str(data.s) - - if isinstance(data, ast.BinOp): - # need to build the string from a binary operation - return _ast_binop_stringify(data) - - if isinstance(data, ast.Name): - # a variable, stringify the variable name - return "[[" + utils.safe_str(data.id) + "]]" - - return "XXX" # placeholder for unaccounted for values +def _check_string(data): + val = data.lower() + return ((val.startswith('select ') and ' from ' in val) or + val.startswith('insert into') or + (val.startswith('update ') and ' set ' in val) or + val.startswith('delete from ')) -def _ast_binop_stringify(data): - # used to recursively build a string from a binary operation - left = data.left - right = data.right +def _evaluate_ast(node): + if not isinstance(node.parent, ast.BinOp): + return (False, "") - return _ast_build_string(left) + _ast_build_string(right) + out = utils.concat_string(node, node.parent) + if isinstance(out[0].parent, ast.Call): # wrapped in "execute" call? + names = ['execute', 'executemany'] + name = utils.get_called_name(out[0].parent) + return (name in names, out[1]) + return (False, out[1]) -@checks('Str') +@test.checks('Str') def hardcoded_sql_expressions(context): - statement = context.node.parent - if isinstance(statement, ast.Assign): - test_str = _ast_build_string(statement.value).lower() - - elif isinstance(statement, ast.Expr): - test_str = "" - if isinstance(statement.value, ast.Call): - ctx_str = context.string_val.lower() - for arg in statement.value.args: - temp_str = _ast_build_string(arg).lower() - if ctx_str in temp_str: - test_str = temp_str - else: - test_str = context.string_val.lower() - - if ( - (test_str.startswith('select ') and ' from ' in test_str) or - test_str.startswith('insert into') or - (test_str.startswith('update ') and ' set ' in test_str) or - test_str.startswith('delete from ') - ): - # if sqlalchemy is not imported and it looks like they are using SQL - # statements, mark it as a medium severity issue - if not context.is_module_imported_like("sqlalchemy"): - return bandit.Issue( - severity=bandit.MEDIUM, - confidence=bandit.LOW, - text="Possible SQL injection vector through string-based " - "query construction, without SQLAlchemy use." - ) - - # otherwise, if sqlalchemy is being used, mark it as low severity - else: - return bandit.Issue( - severity=bandit.LOW, - confidence=bandit.LOW, - text="Possible SQL injection vector through string-based " - "query construction." - ) + val = _evaluate_ast(context.node) + if _check_string(val[1]): + return bandit.Issue( + severity=bandit.MEDIUM, + confidence=bandit.MEDIUM if val[0] else bandit.LOW, + text="Possible SQL injection vector through string-based " + "query construction." + ) diff --git a/examples/sql_statements.py b/examples/sql_statements.py new file mode 100644 index 00000000..5447270a --- /dev/null +++ b/examples/sql_statements.py @@ -0,0 +1,29 @@ +import sqlalchemy + +# bad +query = "SELECT * FROM foo WHERE id = '%s'" % identifier +query = "INSERT INTO foo VALUES ('a', 'b', '%s')" % value +query = "DELETE FROM foo WHERE id = '%s'" % identifier +query = "UPDATE foo SET value = 'b' WHERE id = '%s'" % identifier + +# bad +cur.execute("SELECT * FROM foo WHERE id = '%s'" % identifier) +cur.execute("INSERT INTO foo VALUES ('a', 'b', '%s')" % value) +cur.execute("DELETE FROM foo WHERE id = '%s'" % identifier) +cur.execute("UPDATE foo SET value = 'b' WHERE id = '%s'" % identifier) + +# good +cur.execute("SELECT * FROM foo WHERE id = '%s'", identifier) +cur.execute("INSERT INTO foo VALUES ('a', 'b', '%s')", value) +cur.execute("DELETE FROM foo WHERE id = '%s'", identifier) +cur.execute("UPDATE foo SET value = 'b' WHERE id = '%s'", identifier) + +# bad +query = "SELECT " + val + " FROM " + val +" WHERE id = " + val + +# bad +cur.execute("SELECT " + val + " FROM " + val +" WHERE id = " + val) + +# real world false positives +choices=[('server_list', _("Select from active instances"))] +print("delete from the cache as the first argument") diff --git a/examples/sql_statements_with_sqlalchemy.py b/examples/sql_statements_with_sqlalchemy.py deleted file mode 100644 index 4d0fbb16..00000000 --- a/examples/sql_statements_with_sqlalchemy.py +++ /dev/null @@ -1,6 +0,0 @@ -import sqlalchemy - -query = "SELECT * FROM foo WHERE id = '%s'" % identifier -query = "INSERT INTO foo VALUES ('a', 'b', '%s')" % value -query = "DELETE FROM foo WHERE id = '%s'" % identifier -query = "UPDATE foo SET value = 'b' WHERE id = '%s'" % identifier diff --git a/examples/sql_statements_without_sql_alchemy.py b/examples/sql_statements_without_sql_alchemy.py deleted file mode 100644 index bcc51d07..00000000 --- a/examples/sql_statements_without_sql_alchemy.py +++ /dev/null @@ -1,4 +0,0 @@ -query = "SELECT * FROM foo WHERE id = '%s'" % identifier -query = "INSERT INTO foo VALUES ('a', 'b', '%s')" % value -query = "DELETE FROM foo WHERE id = '%s'" % identifier -query = "UPDATE foo SET value = 'b' WHERE id = '%s'" % identifier diff --git a/tests/test_functional.py b/tests/test_functional.py index a90fb885..7ef04b2b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -249,15 +249,12 @@ class FunctionalTests(unittest.TestCase): expect = {'SEVERITY': {'LOW': 5}, 'CONFIDENCE': {'HIGH': 5}} self.check_example('skip.py', expect) - def test_sql_statements_with_sqlalchemy(self): + def test_sql_statements(self): '''Test for SQL injection through string building.''' - expect = {'SEVERITY': {'LOW': 4}, 'CONFIDENCE': {'LOW': 4}} - self.check_example('sql_statements_with_sqlalchemy.py', expect) - - def test_sql_statements_without_sql_alchemy(self): - '''Test for SQL injection without SQLAlchemy.''' - expect = {'SEVERITY': {'MEDIUM': 4}, 'CONFIDENCE': {'LOW': 4}} - self.check_example('sql_statements_without_sql_alchemy.py', expect) + expect = { + 'SEVERITY': {'MEDIUM': 10}, + 'CONFIDENCE': {'LOW': 5, 'MEDIUM': 5}} + self.check_example('sql_statements.py', expect) def test_ssl_insecure_version(self): '''Test for insecure SSL protocol versions.'''