main.rs 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. mod pattern;
  2. use walkdir::WalkDir;
  3. #[derive(Debug)]
  4. struct Args {
  5. targets: Vec<String>,
  6. num_threads: u8,
  7. }
  8. impl Args {
  9. fn parse() -> Args {
  10. let matches = clap::App::new("rubygrep")
  11. .version("0.1")
  12. .author("Getty Ritter <rubygrep@infinitenegativeutility.com>")
  13. .about("Search Ruby source trees")
  14. .arg(clap::Arg::with_name("threads")
  15. .short("j")
  16. .value_name("THREADS")
  17. .help("Spawn the specified number of worker threads")
  18. .takes_value(true))
  19. .arg(clap::Arg::with_name("targets")
  20. .help("The Ruby sources and directories to search")
  21. .multiple(true))
  22. .get_matches();
  23. let targets = if let Some(values) = matches.values_of("targets") {
  24. values.map(|x| x.to_owned()).collect::<Vec<String>>()
  25. } else {
  26. vec![".".to_owned()]
  27. };
  28. let num_threads = if let Some(val) = matches.value_of("threads") {
  29. match val.parse() {
  30. Ok(x) => x,
  31. Err(_) => {
  32. panic!("Invalid number of threads: {}", val);
  33. }
  34. }
  35. } else {
  36. 8
  37. };
  38. Args { targets, num_threads }
  39. }
  40. }
  41. fn is_ruby_source(path: &std::path::Path) -> bool {
  42. if let Some(ext) = path.extension() {
  43. return ext == "rb";
  44. }
  45. false
  46. }
  47. fn main() {
  48. let args = Args::parse();
  49. println!("Got args: {:?}", args);
  50. let Args {
  51. targets,
  52. num_threads,
  53. } = args;
  54. let p = pattern::Pattern {
  55. node_type: pattern::NodeType::Send,
  56. name: Some("ActiveRecord".to_owned()),
  57. children: vec![],
  58. rest: false,
  59. };
  60. let (work_send, work_recv) = crossbeam::channel::unbounded();
  61. let (result_send, result_recv) = crossbeam::channel::unbounded();
  62. // produce a thread which populates the channel with work
  63. let producer = std::thread::spawn(move || {
  64. for target in targets {
  65. for entry in WalkDir::new(target).into_iter().filter_map(|e| e.ok()) {
  66. if is_ruby_source(entry.path()) {
  67. work_send.send(entry.into_path()).expect("Unable to send work from producer thread");
  68. }
  69. }
  70. }
  71. });
  72. // produce a set of threads which can grab work to be done
  73. let workers: Vec<std::thread::JoinHandle<()>> = (0..num_threads).map(|id| {
  74. let receiver = work_recv.clone();
  75. let sender = result_send.clone();
  76. let pat = p.clone();
  77. std::thread::spawn(move || {
  78. while let Ok(ref msg) = receiver.recv() {
  79. use std::io::Read;
  80. let mut buf = Vec::new();
  81. {
  82. let mut f = std::fs::File::open(msg).expect("Unable to read file");
  83. f.read_to_end(&mut buf).expect("Unable to read file");
  84. }
  85. let parser = lib_ruby_parser::Parser::new(buf, std::default::Default::default());
  86. let lib_ruby_parser::ParserResult { ast, diagnostics, .. } = parser.do_parse();
  87. if let Some(ast) = ast {
  88. let matches = pat.find_matches(&*ast);
  89. if matches.len() > 0 {
  90. sender.send(format!("{:?}: found matches: {:?}", msg, matches)).unwrap();
  91. }
  92. } else {
  93. // sender.send(format!("Unable to parse {:?}", msg)).unwrap();
  94. }
  95. }
  96. })
  97. }).collect();
  98. drop(result_send);
  99. while let Ok(msg) = result_recv.recv() {
  100. println!("{}", msg);
  101. }
  102. // join all the threads
  103. producer.join().expect("Producer thread panicked!");
  104. for w in workers {
  105. w.join().expect("Worker thread panicked!");
  106. }
  107. }