1 package dev.sympho.modular_commands.impl.context;
2
3 import java.util.HashMap;
4 import java.util.Map;
5 import java.util.Map.Entry;
6 import java.util.concurrent.atomic.AtomicBoolean;
7 import java.util.function.Function;
8
9 import org.checkerframework.checker.interning.qual.FindDistinct;
10 import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
11 import org.checkerframework.checker.nullness.qual.NonNull;
12 import org.checkerframework.checker.nullness.qual.Nullable;
13 import org.checkerframework.dataflow.qual.Pure;
14 import org.checkerframework.dataflow.qual.SideEffectFree;
15 import org.slf4j.Logger;
16 import org.slf4j.LoggerFactory;
17
18 import dev.sympho.bot_utils.access.AccessManager;
19 import dev.sympho.bot_utils.event.AbstractRepliableContext;
20 import dev.sympho.bot_utils.event.reply.ReplyManager;
21 import dev.sympho.modular_commands.api.command.Command;
22 import dev.sympho.modular_commands.api.command.Invocation;
23 import dev.sympho.modular_commands.api.command.parameter.Parameter;
24 import dev.sympho.modular_commands.api.command.parameter.parse.ArgumentParser;
25 import dev.sympho.modular_commands.api.command.parameter.parse.AttachmentParser;
26 import dev.sympho.modular_commands.api.command.parameter.parse.BooleanParser;
27 import dev.sympho.modular_commands.api.command.parameter.parse.ChannelArgumentParser;
28 import dev.sympho.modular_commands.api.command.parameter.parse.FloatParser;
29 import dev.sympho.modular_commands.api.command.parameter.parse.IntegerParser;
30 import dev.sympho.modular_commands.api.command.parameter.parse.InvalidArgumentException;
31 import dev.sympho.modular_commands.api.command.parameter.parse.MessageArgumentParser;
32 import dev.sympho.modular_commands.api.command.parameter.parse.RoleArgumentParser;
33 import dev.sympho.modular_commands.api.command.parameter.parse.SnowflakeParser;
34 import dev.sympho.modular_commands.api.command.parameter.parse.StringParser;
35 import dev.sympho.modular_commands.api.command.parameter.parse.UserArgumentParser;
36 import dev.sympho.modular_commands.api.command.result.CommandFailureArgumentInvalid;
37 import dev.sympho.modular_commands.api.command.result.CommandFailureArgumentMissing;
38 import dev.sympho.modular_commands.api.command.result.CommandResult;
39 import dev.sympho.modular_commands.api.command.result.Results;
40 import dev.sympho.modular_commands.api.exception.ResultException;
41 import dev.sympho.modular_commands.execute.InstrumentedContext;
42 import dev.sympho.modular_commands.execute.LazyContext;
43 import dev.sympho.modular_commands.execute.Metrics;
44 import dev.sympho.modular_commands.utils.parse.ParseUtils;
45 import dev.sympho.reactor_utils.concurrent.ReactiveLatch;
46 import discord4j.common.util.Snowflake;
47 import discord4j.core.event.domain.Event;
48 import discord4j.core.object.entity.Attachment;
49 import discord4j.core.object.entity.Message;
50 import discord4j.core.object.entity.Role;
51 import discord4j.core.object.entity.User;
52 import discord4j.core.object.entity.channel.Channel;
53 import io.micrometer.observation.ObservationRegistry;
54 import reactor.core.observability.micrometer.Micrometer;
55 import reactor.core.publisher.Flux;
56 import reactor.core.publisher.Mono;
57
58
59
60
61
62
63
64
65
66 @SuppressWarnings( "MultipleStringLiterals" )
67 abstract class ContextImpl<A extends @NonNull Object, E extends @NonNull Event>
68 extends AbstractRepliableContext<E>
69 implements LazyContext, InstrumentedContext {
70
71
72 public static final String METRIC_NAME_PREFIX = "context";
73
74 public static final String METRIC_NAME_PREFIX_ARGUMENT = "argument";
75
76
77 private static final Logger LOGGER = LoggerFactory.getLogger( ContextImpl.class );
78
79
80 private static final String METRIC_NAME_INITIALIZE = Metrics.name( METRIC_NAME_PREFIX, "init" );
81
82 private static final String METRIC_NAME_LOAD = Metrics.name( METRIC_NAME_PREFIX, "load" );
83
84 private static final String METRIC_NAME_ARGUMENT_INIT = Metrics.name( METRIC_NAME_PREFIX,
85 METRIC_NAME_PREFIX_ARGUMENT, "init" );
86
87 private static final String METRIC_NAME_ARGUMENT_PARSE_ALL = Metrics.name( METRIC_NAME_PREFIX,
88 METRIC_NAME_PREFIX_ARGUMENT, "all" );
89
90 private static final String METRIC_NAME_ARGUMENT_PARSE_ONE = Metrics.name( METRIC_NAME_PREFIX,
91 METRIC_NAME_PREFIX_ARGUMENT, "one" );
92
93
94 private static final String METRIC_TAG_PARAMETER = Metrics.name( "parameter" );
95
96
97 protected final Command<?> command;
98
99
100 private final Invocation invocation;
101
102
103 private final Map<String, @Nullable Object> context;
104
105
106 private @MonotonicNonNull Map<String, ? extends Argument<?>> arguments;
107
108
109 private final AtomicBoolean initialized;
110
111
112 private final ReactiveLatch initializeLatch;
113
114
115 private @MonotonicNonNull Mono<CommandResult> loadResult;
116
117
118
119
120
121
122
123
124
125
126 protected ContextImpl(
127 final E event,
128 final Invocation invocation, final Command<?> command,
129 final AccessManager accessManager, final ReplyManager replyManager
130 ) {
131
132 super( event, accessManager, replyManager );
133
134 this.command = command;
135 this.invocation = invocation;
136
137 this.context = new HashMap<>();
138
139 this.arguments = null;
140
141 this.initialized = new AtomicBoolean( false );
142 this.initializeLatch = new ReactiveLatch();
143 this.loadResult = null;
144
145 }
146
147
148
149
150
151
152
153
154
155
156 @SideEffectFree
157 protected abstract Mono<String> getStringArgument( String name )
158 throws InvalidArgumentException;
159
160
161
162
163
164
165
166
167
168 @SideEffectFree
169 protected abstract Mono<Boolean> getBooleanArgument( String name )
170 throws InvalidArgumentException;
171
172
173
174
175
176
177
178
179
180 @SideEffectFree
181 protected abstract Mono<Long> getIntegerArgument( String name )
182 throws InvalidArgumentException;
183
184
185
186
187
188
189
190
191
192 @SideEffectFree
193 protected abstract Mono<Double> getFloatArgument( String name )
194 throws InvalidArgumentException;
195
196
197
198
199
200
201
202
203
204
205 @SideEffectFree
206 protected abstract Mono<Snowflake> getSnowflakeArgument( String name,
207 SnowflakeParser.Type type ) throws InvalidArgumentException;
208
209
210
211
212
213
214
215
216 @SideEffectFree
217 protected abstract Mono<User> getUserArgument( String name );
218
219
220
221
222
223
224
225
226 @SideEffectFree
227 protected abstract Mono<Role> getRoleArgument( String name );
228
229
230
231
232
233
234
235
236
237
238 @SideEffectFree
239 protected abstract <C extends @NonNull Channel> Mono<C> getChannelArgument( String name,
240 Class<C> type );
241
242
243
244
245
246
247
248
249
250 @SideEffectFree
251 protected Mono<Message> getMessageArgument( final String name ) {
252
253 return getStringArgument( name )
254 .flatMap( raw -> ParseUtils.MESSAGE.parse( this, raw ) );
255
256 }
257
258
259
260
261
262
263
264 @SideEffectFree
265 protected abstract Mono<Attachment> getAttachmentArgument( String name );
266
267
268
269
270
271
272
273
274
275
276
277
278
279 @SideEffectFree
280 private <R extends @NonNull Object> Mono<R> handleMissingArgument(
281 final Parameter<?> parameter ) {
282
283 if ( parameter.required() ) {
284 return Mono.error( () -> new ResultException(
285 new CommandFailureArgumentMissing( parameter )
286 ) );
287 } else {
288 return Mono.empty();
289 }
290
291 }
292
293
294
295
296
297
298
299
300 @SideEffectFree
301 private ResultException wrapInvalidParam( final Parameter<?> parameter,
302 final InvalidArgumentException exception ) {
303
304 LOGGER.trace( "Invalid argument for parameter {}: {}", parameter, exception.getMessage() );
305 final var error = exception.getMessage();
306 final var result = new CommandFailureArgumentInvalid( parameter, error );
307 return new ResultException( result );
308
309 }
310
311
312
313
314
315
316
317
318 @SideEffectFree
319 private ResultException wrapParamError( final Parameter<?> parameter, final Throwable error ) {
320
321
322 if ( error instanceof ResultException res ) {
323 LOGGER.warn( "Result exception would be wrapped" );
324 return res;
325 }
326
327 LOGGER.error( "Error while parsing parameter {}: {}", parameter, error.getMessage() );
328 return new ResultException( Results.exceptionR( error ) );
329
330 }
331
332
333
334
335
336
337
338
339
340
341
342
343
344 @SideEffectFree
345 @SuppressWarnings( { "conditional", "return" } )
346 private <R extends @NonNull Object, T extends @NonNull Object> Mono<T> parseArgument(
347 final Parameter<T> parameter,
348 final Function<String, Mono<R>> getter,
349 final ArgumentParser<R, T> parser
350 ) {
351
352 return getter.apply( parameter.name() )
353 .switchIfEmpty( handleMissingArgument( parameter ) )
354 .map( parser::validateRaw )
355
356 .flatMap( raw -> parser.parse( this, raw ) )
357 .switchIfEmpty( Mono.justOrEmpty( parameter.defaultValue() ) )
358 .onErrorMap( InvalidArgumentException.class,
359 e -> wrapInvalidParam( parameter, e )
360 )
361 .onErrorMap( e -> !( e instanceof ResultException ),
362 e -> wrapParamError( parameter, e )
363 );
364
365 }
366
367
368
369
370
371
372
373
374
375
376
377
378
379 private <C extends @NonNull Channel, T extends @NonNull Object> Mono<T> parseArgument(
380 final Parameter<T> parameter,
381 final ChannelArgumentParser<C, T> parser
382 ) {
383
384 return parseArgument(
385 parameter,
386 name -> getChannelArgument( name, parser.type() ),
387 parser
388 );
389
390 }
391
392
393
394
395
396
397
398
399
400
401 @SideEffectFree
402 @SuppressWarnings( { "JavadocMethod", "unchecked" } )
403 private <T extends @NonNull Object> Mono<T> parseArgument( final Parameter<T> parameter ) {
404
405
406
407
408
409 final var parser = parameter.parser();
410 if ( parser instanceof AttachmentParser<?> p ) {
411 return parseArgument( parameter, this::getAttachmentArgument,
412 ( AttachmentParser<T> ) p );
413 } else if ( parser instanceof StringParser<?> p ) {
414 return parseArgument( parameter, this::getStringArgument, ( StringParser<T> ) p );
415 } else if ( parser instanceof BooleanParser<?> p ) {
416 return parseArgument( parameter, this::getBooleanArgument, ( BooleanParser<T> ) p );
417 } else if ( parser instanceof IntegerParser<?> p ) {
418 return parseArgument( parameter, this::getIntegerArgument, ( IntegerParser<T> ) p );
419 } else if ( parser instanceof FloatParser<?> p ) {
420 return parseArgument( parameter, this::getFloatArgument, ( FloatParser<T> ) p );
421 } else if ( parser instanceof SnowflakeParser<?> p ) {
422 return parseArgument( parameter, n -> getSnowflakeArgument( n, p.type() ),
423 ( SnowflakeParser<T> ) p );
424 } else if ( parser instanceof UserArgumentParser<?> p ) {
425 return parseArgument( parameter, this::getUserArgument,
426 ( UserArgumentParser<T> ) p );
427 } else if ( parser instanceof RoleArgumentParser<?> p ) {
428 return parseArgument( parameter, this::getRoleArgument,
429 ( RoleArgumentParser<T> ) p );
430 } else if ( parser instanceof ChannelArgumentParser<?, ?> p ) {
431 return parseArgument( parameter, ( ChannelArgumentParser<?, T> ) p );
432 } else if ( parser instanceof MessageArgumentParser<?> p ) {
433 return parseArgument( parameter, this::getMessageArgument,
434 ( MessageArgumentParser<T> ) p );
435 } else {
436 throw new IllegalArgumentException( "Unrecognized parser type: " + parser.getClass() );
437 }
438
439 }
440
441
442
443
444
445
446
447
448
449
450 @SideEffectFree
451 @SuppressWarnings( {
452 "return",
453 "optional.parameter"
454 } )
455 private <T extends @NonNull Object> Mono<? extends Entry<String, ? extends Argument<T>>>
456 processArgument( final Parameter<T> parameter ) {
457
458 return parseArgument( parameter )
459 .singleOptional()
460 .map( v -> new Argument<>( parameter, v.orElse( null ) ) )
461 .map( a -> Map.entry( parameter.name(), a ) )
462 .doOnError( t -> {
463 if ( t instanceof ResultException ex ) {
464 LOGGER.trace( "Arg processing aborted: {}", ex.getResult() );
465 } else {
466 LOGGER.error( "Failed to process argument", t );
467 }
468 } );
469
470 }
471
472
473
474
475
476
477
478
479
480 protected abstract Mono<Void> initArgs();
481
482
483
484 @Override
485 public String getCommandId() {
486
487 return command.id();
488
489 }
490
491 @Override
492 public Invocation invocation() {
493
494 return invocation;
495
496 }
497
498 @Override
499 public Invocation commandInvocation() {
500
501 return command.invocation();
502
503 }
504
505
506
507
508
509
510
511
512
513 @Pure
514 private Argument<?> getArgument( final String name )
515 throws IllegalStateException, IllegalArgumentException {
516
517 if ( arguments == null ) {
518 throw new IllegalStateException( "Context not loaded yet" );
519 }
520
521 final var arg = arguments.get( name );
522 if ( arg == null ) {
523 throw new IllegalArgumentException( String.format( "No parameter named '%s'", name ) );
524 } else {
525 return arg;
526 }
527
528 }
529
530 @Override
531 public <T extends @NonNull Object> @Nullable T getArgument(
532 final String name, final Class<T> argumentType )
533 throws IllegalArgumentException, ClassCastException {
534
535 return getArgument( name ).getValue( argumentType );
536
537 }
538
539 @Override
540 public <T extends @NonNull Object> @Nullable T getArgument(
541 final @FindDistinct Parameter<? extends T> parameter ) throws IllegalArgumentException {
542
543 final var argument = getArgument( parameter.name() );
544 if ( argument.parameter() == parameter ) {
545
546
547 @SuppressWarnings( "unchecked" )
548 final Argument<T> arg = ( Argument<T> ) argument;
549 return arg.value();
550 } else {
551 throw new IllegalArgumentException(
552 "Parameter does not match definition: " + parameter );
553 }
554
555 }
556
557 @Override
558 public boolean setContext( final String key, final @Nullable Object obj,
559 final boolean replace ) {
560
561 if ( !replace && context.containsKey( key ) ) {
562 return false;
563 } else {
564 context.put( key, obj );
565 return true;
566 }
567
568 }
569
570 @Override
571 @SuppressWarnings( "signedness:return" )
572 public <T> @Nullable T getContext( final String key, final Class<? extends T> type )
573 throws IllegalArgumentException, ClassCastException {
574
575 if ( !context.containsKey( key ) ) {
576 throw new IllegalArgumentException( String.format(
577 "No context under key '%s'.", key ) );
578 }
579
580 return type.cast( context.get( key ) );
581
582 }
583
584
585
586
587
588
589 private void doInitialize( final ObservationRegistry observations ) {
590
591 LOGGER.trace( "Initializing context" );
592
593
594
595 this.loadResult = Mono.just( observations ).flatMap( this::doLoad ).cache();
596
597 LOGGER.trace( "Context initialized" );
598
599 }
600
601 @Override
602 public Mono<Void> initialize( final ObservationRegistry observations ) {
603
604 if ( initialized.getAndSet( true ) ) {
605 return initializeLatch.await();
606 }
607
608 LOGGER.trace( "Initializing context" );
609
610
611
612 this.loadResult = Mono.just( observations ).flatMap( this::doLoad ).cache();
613
614 return Mono.fromRunnable( () -> doInitialize( observations ) )
615 .then()
616 .doOnSuccess( v -> initializeLatch.countDown() )
617 .doOnError( initializeLatch::fail )
618 .doOnError( t -> LOGGER.error( "Failed to initialize", t ) )
619 .checkpoint( METRIC_NAME_INITIALIZE )
620 .name( METRIC_NAME_INITIALIZE )
621 .transform( this::addTags )
622 .tap( Micrometer.observation( observations ) )
623 .cache();
624
625 }
626
627
628
629
630
631
632
633
634 @SuppressWarnings( "assignment" )
635 public Mono<CommandResult> doLoad( final ObservationRegistry observations ) {
636
637 LOGGER.trace( "Loading context" );
638
639 final var init = Mono.defer( () -> initArgs()
640 .checkpoint( METRIC_NAME_ARGUMENT_INIT )
641 .name( METRIC_NAME_ARGUMENT_INIT )
642 .transform( this::addTags )
643 .tap( Micrometer.observation( observations ) )
644 );
645
646 final var parse = Mono.defer( () -> Flux.fromIterable( command.parameters() )
647 .flatMap( p -> processArgument( p )
648 .checkpoint( METRIC_NAME_ARGUMENT_PARSE_ONE )
649 .name( METRIC_NAME_ARGUMENT_PARSE_ONE )
650 .transform( this::addTags )
651 .tag( METRIC_TAG_PARAMETER, p.name() )
652 .tap( Micrometer.observation( observations ) )
653 )
654 .collectMap( Entry::getKey, Entry::getValue )
655 .doOnNext( args -> {
656 this.arguments = args;
657 } )
658 .checkpoint( METRIC_NAME_ARGUMENT_PARSE_ALL )
659 .name( METRIC_NAME_ARGUMENT_PARSE_ALL )
660 .transform( this::addTags )
661 .tap( Micrometer.observation( observations ) )
662 );
663
664 return init.then( parse )
665 .checkpoint( METRIC_NAME_LOAD )
666 .name( METRIC_NAME_LOAD )
667 .transform( this::addTags )
668 .tap( Micrometer.observation( observations ) )
669 .doOnSuccess( v -> LOGGER.trace( "Context loaded" ) )
670 .doOnError( t -> {
671 if ( t instanceof ResultException ex ) {
672 LOGGER.trace( "Load aborted: {}", ex.getResult() );
673 } else {
674 LOGGER.error( "Failed to load", t );
675 }
676 } )
677 .then( Mono.empty().cast( CommandResult.class ) )
678 .onErrorResume( ResultException.class, e -> Mono.just( e.getResult() ) );
679
680 }
681
682 @Override
683 public Mono<CommandResult> load() {
684
685 if ( loadResult == null ) {
686 throw new IllegalStateException( "Called load() before initialize()" );
687 } else {
688 return loadResult;
689 }
690
691 }
692
693
694
695
696
697
698
699
700
701
702 private record Argument<T extends @NonNull Object>(
703 Parameter<T> parameter,
704 @Nullable T value
705 ) {
706
707
708
709
710
711
712
713
714
715 @SuppressWarnings( "signedness:return" )
716 public <E> @Nullable E getValue( final Class<E> argumentType ) throws ClassCastException {
717
718 return argumentType.cast( value );
719
720 }
721
722 }
723
724 }