#!/usr/bin/env python3
# pylint: disable=C0103,C0114,C0116,W0613
#
# This program is free software; you can redistribute it and/or modify the
# Verilator internals under the terms of either the GNU Lesser General
# Public License Version 3 or the Perl Artistic License Version 2.0.
#
# SPDX-FileCopyrightText: 2003-2026 Wilson Snyder
# SPDX-License-Identifier: LGPL-3.0-only OR Artistic-2.0
######################################################################

import argparse
import os
import subprocess
import sys

try:
    from termcolor import colored
except ModuleNotFoundError:

    def colored(msg, **_ignored_style):
        """Stand-in for termcolor.colored used when termcolor is not installed.

        Any styling keywords (color, attrs, ...) are accepted and discarded;
        the message text is returned unchanged.
        """
        return msg


def cprint(msg="", *, color=None, attrs=None, **kwargs):
    """Print msg, styled via colored().

    color and attrs are handed to colored(); every other keyword argument
    (for example end, file or flush) is passed straight through to print().
    """
    styled = colored(msg, color=color, attrs=attrs)
    print(styled, **kwargs)


parser = argparse.ArgumentParser(
    description='Binary search utility for debugging Verilator with V3DebugBisect',
    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog='''
example:
  %(prog)s DfgPeephole 0 1000 test_regress/t/t_foo.py --no-skip-identical
    ''')

parser.add_argument("--pre", type=str, help='Command to run before each execution')
parser.add_argument('name', help='Name of V3DebugBisect instance')
parser.add_argument('low', type=int, help='Bisection range low value, use 0 by default')
parser.add_argument('high',
                    type=int,
                    help='Bisection range high value, use a sufficiently high number')
parser.add_argument('cmd',
                    nargs=argparse.REMAINDER,
                    help='Discriminator command that should exit non-zero on failure')

args = parser.parse_args()

# Reject arguments the search cannot handle. With low > high the pass/fail
# bounds below would start out crossed and the loop might never terminate.
# An empty command cannot be executed at all.
if args.low > args.high:
    parser.error(f"low ({args.low}) must not be greater than high ({args.high})")
if not args.cmd:
    parser.error("missing discriminator command")

# The value under test is handed to the discriminator via this environment variable.
var = f"VERILATOR_DEBUG_BISECT_{args.name}"

# Search state: 'passing' is the largest value observed to pass and 'failing'
# the smallest value observed to fail. Both start one step outside
# [low, high], meaning "not observed yet", so every midpoint evaluated
# below lies inside [low, high].
passing = args.low - 1
failing = args.high + 1

cprint()
cprint(f"Starting bisection search for {var} in interval [{args.low}, {args.high}]",
       attrs=["bold"])

# Exit status: 0 = first failing value found, 1 = interval endpoints
# inconsistent (rerun suggested), 2 = --pre or discriminator could not run.
while True:
    cprint()

    # Display '?' for a bound that is still its out-of-range starting value.
    passStr = str(passing) if passing >= args.low else '?'
    failStr = str(failing) if failing <= args.high else '?'
    cprint(f"Current step Pass: {passStr} Fail: {failStr}", attrs=["bold"])

    # Stop if found, or exhausted interval without finding both a pass and a fail
    if failing == args.low:
        cprint(f"The low endpoint of the search interval ({args.low}) fails. Suggest rerun with:",
               color="yellow")
        cprint(f"   {sys.argv[0]} {args.name} 0 {args.low} ...", color="yellow")
        sys.exit(1)
    if passing == args.high:
        cprint(
            f"The high endpoint of the search interval ({args.high}) passes. Suggest rerun with:",
            color="yellow")
        cprint(f"   {sys.argv[0]} {args.name} {args.high} {10*args.high} ...", color="yellow")
        sys.exit(1)
    if failing == passing + 1:
        cprint(f"First failing value: {var}={failing}", attrs=["bold"])
        sys.exit(0)

    # Compute middle of interval to evaluate
    mid = (failing + passing) // 2

    # stdout is block-buffered when it is not a terminal (e.g. redirected to a
    # log file). Flush before starting each child process so our status lines
    # stay in order with the children's output.

    # Run pre command if given:
    if args.pre:
        cprint("Running --pre command", attrs=["bold"], flush=True)
        preResult = subprocess.run(args.pre, shell=True, check=False)
        if preResult.returncode != 0:
            cprint("Pre command failed", color="red")
            sys.exit(2)

    # Set up environment variable
    env = os.environ.copy()
    env[var] = str(mid)

    # Run the discriminator command
    cprint(f"Running with {var}={mid}", attrs=["bold"], flush=True)
    try:
        result = subprocess.run(args.cmd, env=env, check=False)
    except OSError as exc:
        # e.g. the command does not exist or is not executable
        cprint(f"Failed to run discriminator command: {exc}", color="red")
        sys.exit(2)

    # Check status, update interval
    if result.returncode != 0:
        cprint(f"Run with {var}={mid}: Fail", color="red")
        failing = mid
    else:
        cprint(f"Run with {var}={mid}: Pass", color="green")
        passing = mid
